Fix reconnect, add auth on pool creation
Support pool creation from url including db, username, and password
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from asyncio_mongo.auth import CredentialsAuthenticator
|
||||
from asyncio_mongo.log import logger
|
||||
from asyncio_mongo.database import Database
|
||||
from .protocol import MongoProtocol
|
||||
@@ -24,7 +25,8 @@ class Connection:
|
||||
|
||||
@classmethod
|
||||
@asyncio.coroutine
|
||||
def create(cls, host='localhost', port=27017, loop=None, auto_reconnect=True):
|
||||
def create(cls, host='localhost', port=27017, db=None, username=None, password=None, loop=None,
|
||||
auto_reconnect=True):
|
||||
connection = cls()
|
||||
|
||||
connection.host = host
|
||||
@@ -33,15 +35,17 @@ class Connection:
|
||||
connection._retry_interval = .5
|
||||
|
||||
# Create protocol instance
|
||||
protocol_factory = type('MongoProtocol', (cls.protocol,), {})
|
||||
def connection_lost(exc):
|
||||
if auto_reconnect:
|
||||
logger.info("Connection lost, attempting to reconnect")
|
||||
asyncio.Task(connection._reconnect())
|
||||
|
||||
if auto_reconnect:
|
||||
class protocol_factory(protocol_factory):
|
||||
def connection_lost(self, exc):
|
||||
super().connection_lost(exc)
|
||||
asyncio.Task(connection._reconnect())
|
||||
connection._authenticators = {}
|
||||
connection._db_name = db
|
||||
if db and username:
|
||||
connection._authenticators[db] = CredentialsAuthenticator(username, password)
|
||||
|
||||
connection.protocol = protocol_factory()
|
||||
connection.protocol = MongoProtocol(connection_lost_callback=connection_lost)
|
||||
|
||||
# Connect
|
||||
yield from connection._reconnect()
|
||||
@@ -53,6 +57,10 @@ class Connection:
|
||||
if self.transport:
|
||||
return self.transport.close()
|
||||
|
||||
@property
|
||||
def default_database(self):
|
||||
return self[self._db_name]
|
||||
|
||||
@property
|
||||
def transport(self):
|
||||
""" The transport instance that the protocol is currently using. """
|
||||
@@ -77,9 +85,12 @@ class Connection:
|
||||
loop = self._loop or asyncio.get_event_loop()
|
||||
while True:
|
||||
try:
|
||||
logger.log(logging.INFO, 'Connecting to mongo')
|
||||
logger.log(logging.INFO, 'Connecting to mongo at {host}:{port}'.format(host=self.host, port=self.port))
|
||||
yield from loop.create_connection(lambda: self.protocol, self.host, self.port)
|
||||
self._reset_retry_interval()
|
||||
for name, auth in self._authenticators.items():
|
||||
yield from auth(Database(self, name))
|
||||
logger.log(logging.INFO, 'Authenticated to database {name}'.format(name=name))
|
||||
return
|
||||
except OSError:
|
||||
# Sleep and try again
|
||||
|
||||
Reference in New Issue
Block a user