From ca538a08ce7998c6ab2e17e212cd78af79fae71e Mon Sep 17 00:00:00 2001 From: Don Brown Date: Sun, 23 Feb 2014 11:22:33 -0800 Subject: [PATCH] Fix reconnect, add auth on pool creation Support pool creation from url including db, username, and password --- asyncio_mongo/auth.py | 17 +++++++++++++++++ asyncio_mongo/connection.py | 29 ++++++++++++++++++++--------- asyncio_mongo/pool.py | 31 +++++++++++++++++++++++-------- asyncio_mongo/protocol.py | 16 ++++++++++++++-- 4 files changed, 74 insertions(+), 19 deletions(-) create mode 100644 asyncio_mongo/auth.py diff --git a/asyncio_mongo/auth.py b/asyncio_mongo/auth.py new file mode 100644 index 0000000..b89c516 --- /dev/null +++ b/asyncio_mongo/auth.py @@ -0,0 +1,17 @@ +import asyncio + + +class Authenticator: + + def authenticate(self, db): + raise NotImplementedError() + + +class CredentialsAuthenticator(): + def __init__(self, username, password): + self.username = username + self.password = password + + @asyncio.coroutine + def authenticate(self, db): + yield from db.authenticate(self.username, self.password) \ No newline at end of file diff --git a/asyncio_mongo/connection.py b/asyncio_mongo/connection.py index 241d66b..88b3ebd 100644 --- a/asyncio_mongo/connection.py +++ b/asyncio_mongo/connection.py @@ -1,4 +1,5 @@ import logging +from asyncio_mongo.auth import CredentialsAuthenticator from asyncio_mongo.log import logger from asyncio_mongo.database import Database from .protocol import MongoProtocol @@ -24,7 +25,8 @@ class Connection: @classmethod @asyncio.coroutine - def create(cls, host='localhost', port=27017, loop=None, auto_reconnect=True): + def create(cls, host='localhost', port=27017, db=None, username=None, password=None, loop=None, + auto_reconnect=True): connection = cls() connection.host = host @@ -33,15 +35,17 @@ class Connection: connection._retry_interval = .5 # Create protocol instance - protocol_factory = type('MongoProtocol', (cls.protocol,), {}) + def connection_lost(exc): + if auto_reconnect: + logger.info("Connection lost, attempting to reconnect") + asyncio.Task(connection._reconnect()) - if auto_reconnect: - class protocol_factory(protocol_factory): - def connection_lost(self, exc): - super().connection_lost(exc) - asyncio.Task(connection._reconnect()) + connection._authenticators = {} + connection._db_name = db + if db and username: + connection._authenticators[db] = CredentialsAuthenticator(username, password) - connection.protocol = protocol_factory() + connection.protocol = MongoProtocol(connection_lost_callback=connection_lost) # Connect yield from connection._reconnect() @@ -53,6 +57,10 @@ class Connection: if self.transport: return self.transport.close() + @property + def default_database(self): + return self[self._db_name] + @property def transport(self): """ The transport instance that the protocol is currently using. """ @@ -77,9 +85,12 @@ class Connection: loop = self._loop or asyncio.get_event_loop() while True: try: - logger.log(logging.INFO, 'Connecting to mongo') + logger.log(logging.INFO, 'Connecting to mongo at {host}:{port}'.format(host=self.host, port=self.port)) yield from loop.create_connection(lambda: self.protocol, self.host, self.port) self._reset_retry_interval() + for name, auth in self._authenticators.items(): + yield from auth(Database(self, name)) + logger.log(logging.INFO, 'Authenticated to database {name}'.format(name=name)) return except OSError: # Sleep and try again diff --git a/asyncio_mongo/pool.py b/asyncio_mongo/pool.py index 5e6fc09..1bd3c00 100644 --- a/asyncio_mongo/pool.py +++ b/asyncio_mongo/pool.py @@ -1,3 +1,4 @@ +from urllib.parse import urlparse from .connection import Connection from .exceptions import NoAvailableConnectionsInPoolError from .protocol import MongoProtocol @@ -38,13 +39,29 @@ class Pool: @classmethod @asyncio.coroutine - def create(cls, host='localhost', port=27017, loop=None, poolsize=1, auto_reconnect=True): + def create(cls, host='localhost', port=27017, db=None, username=None, password=None, url=None, loop=None, + poolsize=1, auto_reconnect=True): """ Create a new pool instance. """ self = cls() - self._host = host - self._port = port + + if url: + url = urlparse(url) + + try: + db = url.path.replace('/', '') + except (AttributeError, ValueError): + raise Exception("Missing database name in URI") + + self._host = url.hostname + self._port = url.port or 27017 + username = url.username + password = url.password + else: + self._host = host + self._port = port + self._pool_size = poolsize # Create connections @@ -52,7 +69,8 @@ class Pool: for i in range(poolsize): connection_class = cls.get_connection_class() - connection = yield from connection_class.create(host=host, port=port, loop=loop, + connection = yield from connection_class.create(host=host, port=port, db=db, username=username, + password=password, loop=loop, auto_reconnect=auto_reconnect) self._connections.append(connection) @@ -101,13 +119,10 @@ class Pool: busy in a blocking request or transaction.) """ - if 'close' == name: - return self.close - connection = self._get_free_connection() if connection: return getattr(connection, name) else: raise NoAvailableConnectionsInPoolError('No available connections in the pool: size=%s, connected=%s' % ( - self.pool_size, self.connections_connected)) + self._pool_size, self.connections_connected)) diff --git a/asyncio_mongo/protocol.py b/asyncio_mongo/protocol.py index 513d35e..1993ae2 100644 --- a/asyncio_mongo/protocol.py +++ b/asyncio_mongo/protocol.py @@ -16,6 +16,7 @@ import logging import struct import asyncio +from asyncio_mongo.database import Database from asyncio_mongo.exceptions import ConnectionLostError import asyncio_mongo._bson as bson from asyncio_mongo.log import logger @@ -27,7 +28,7 @@ _ZERO = b"\x00\x00\x00\x00" class _MongoQuery(object): - def __init__(self, id, collection, limit): + def __init__(self, id, collection, limit): self.id = id self.limit = limit self.collection = collection @@ -36,22 +37,30 @@ class _MongoQuery(object): class MongoProtocol(asyncio.Protocol): - def __init__(self): + def __init__(self, connection_lost_callback=None, authenticators=None): self.__id = 0 self.__buffer = b"" self.__queries = {} self.__datalen = None self.__response = 0 self.__waiting_header = True + self.__connection_lost_callback = connection_lost_callback self._pipelined_calls = set() # Set of all the pipelined calls. self.transport = None self._is_connected = False + self.__authenticators = authenticators or {} + def connection_made(self, transport): self.transport = transport + self._is_connected = True logger.log(logging.INFO, 'Mongo connection made') + # for name, auth in self.__authenticators.iter_items(): + # yield from auth(Database(self, name)) + # logger.log(logging.INFO, 'Authenticated to database {name}'.format(name=name)) + def connection_lost(self, exc): self._is_connected = False self.transport = None @@ -62,6 +71,9 @@ class MongoProtocol(asyncio.Protocol): logger.log(logging.INFO, 'Mongo connection lost') + if self.__connection_lost_callback: + self.__connection_lost_callback(exec) + @property def is_connected(self): """ True when the underlying transport is connected. """