Fix reconnect, add auth on pool creation

Support pool creation from url including db, username, and password
This commit is contained in:
Don Brown
2014-02-23 11:22:33 -08:00
parent c7ad1b1ba0
commit ca538a08ce
4 changed files with 74 additions and 19 deletions

17
asyncio_mongo/auth.py Normal file
View File

@@ -0,0 +1,17 @@
import asyncio
class Authenticator:
def authenticate(self, db):
raise NotImplementedError()
class CredentialsAuthenticator():
def __init__(self, username, password):
self.username = username
self.password = password
@asyncio.coroutine
def authenticate(self, db):
yield from db.authenticate(self.username, self.password)

View File

@@ -1,4 +1,5 @@
import logging import logging
from asyncio_mongo.auth import CredentialsAuthenticator
from asyncio_mongo.log import logger from asyncio_mongo.log import logger
from asyncio_mongo.database import Database from asyncio_mongo.database import Database
from .protocol import MongoProtocol from .protocol import MongoProtocol
@@ -24,7 +25,8 @@ class Connection:
@classmethod @classmethod
@asyncio.coroutine @asyncio.coroutine
def create(cls, host='localhost', port=27017, loop=None, auto_reconnect=True): def create(cls, host='localhost', port=27017, db=None, username=None, password=None, loop=None,
auto_reconnect=True):
connection = cls() connection = cls()
connection.host = host connection.host = host
@@ -33,15 +35,17 @@ class Connection:
connection._retry_interval = .5 connection._retry_interval = .5
# Create protocol instance # Create protocol instance
protocol_factory = type('MongoProtocol', (cls.protocol,), {}) def connection_lost(exc):
if auto_reconnect:
logger.info("Connection lost, attempting to reconnect")
asyncio.Task(connection._reconnect())
if auto_reconnect: connection._authenticators = {}
class protocol_factory(protocol_factory): connection._db_name = db
def connection_lost(self, exc): if db and username:
super().connection_lost(exc) connection._authenticators[db] = CredentialsAuthenticator(username, password)
asyncio.Task(connection._reconnect())
connection.protocol = protocol_factory() connection.protocol = MongoProtocol(connection_lost_callback=connection_lost)
# Connect # Connect
yield from connection._reconnect() yield from connection._reconnect()
@@ -53,6 +57,10 @@ class Connection:
if self.transport: if self.transport:
return self.transport.close() return self.transport.close()
@property
def default_database(self):
return self[self._db_name]
@property @property
def transport(self): def transport(self):
""" The transport instance that the protocol is currently using. """ """ The transport instance that the protocol is currently using. """
@@ -77,9 +85,12 @@ class Connection:
loop = self._loop or asyncio.get_event_loop() loop = self._loop or asyncio.get_event_loop()
while True: while True:
try: try:
logger.log(logging.INFO, 'Connecting to mongo') logger.log(logging.INFO, 'Connecting to mongo at {host}:{port}'.format(host=self.host, port=self.port))
yield from loop.create_connection(lambda: self.protocol, self.host, self.port) yield from loop.create_connection(lambda: self.protocol, self.host, self.port)
self._reset_retry_interval() self._reset_retry_interval()
for name, auth in self._authenticators.items():
yield from auth(Database(self, name))
logger.log(logging.INFO, 'Authenticated to database {name}'.format(name=name))
return return
except OSError: except OSError:
# Sleep and try again # Sleep and try again

View File

@@ -1,3 +1,4 @@
from urllib.parse import urlparse
from .connection import Connection from .connection import Connection
from .exceptions import NoAvailableConnectionsInPoolError from .exceptions import NoAvailableConnectionsInPoolError
from .protocol import MongoProtocol from .protocol import MongoProtocol
@@ -38,13 +39,29 @@ class Pool:
@classmethod @classmethod
@asyncio.coroutine @asyncio.coroutine
def create(cls, host='localhost', port=27017, loop=None, poolsize=1, auto_reconnect=True): def create(cls, host='localhost', port=27017, db=None, username=None, password=None, url=None, loop=None,
poolsize=1, auto_reconnect=True):
""" """
Create a new pool instance. Create a new pool instance.
""" """
self = cls() self = cls()
self._host = host
self._port = port if url:
url = urlparse(url)
try:
db = url.path.replace('/', '')
except (AttributeError, ValueError):
raise Exception("Missing database name in URI")
self._host = url.hostname
self._port = url.port or 27017
username = url.username
password = url.password
else:
self._host = host
self._port = port
self._pool_size = poolsize self._pool_size = poolsize
# Create connections # Create connections
@@ -52,7 +69,8 @@ class Pool:
for i in range(poolsize): for i in range(poolsize):
connection_class = cls.get_connection_class() connection_class = cls.get_connection_class()
connection = yield from connection_class.create(host=host, port=port, loop=loop, connection = yield from connection_class.create(host=host, port=port, db=db, username=username,
password=password, loop=loop,
auto_reconnect=auto_reconnect) auto_reconnect=auto_reconnect)
self._connections.append(connection) self._connections.append(connection)
@@ -101,13 +119,10 @@ class Pool:
busy in a blocking request or transaction.) busy in a blocking request or transaction.)
""" """
if 'close' == name:
return self.close
connection = self._get_free_connection() connection = self._get_free_connection()
if connection: if connection:
return getattr(connection, name) return getattr(connection, name)
else: else:
raise NoAvailableConnectionsInPoolError('No available connections in the pool: size=%s, connected=%s' % ( raise NoAvailableConnectionsInPoolError('No available connections in the pool: size=%s, connected=%s' % (
self.pool_size, self.connections_connected)) self._pool_size, self.connections_connected))

View File

@@ -16,6 +16,7 @@ import logging
import struct import struct
import asyncio import asyncio
from asyncio_mongo.database import Database
from asyncio_mongo.exceptions import ConnectionLostError from asyncio_mongo.exceptions import ConnectionLostError
import asyncio_mongo._bson as bson import asyncio_mongo._bson as bson
from asyncio_mongo.log import logger from asyncio_mongo.log import logger
@@ -27,7 +28,7 @@ _ZERO = b"\x00\x00\x00\x00"
class _MongoQuery(object): class _MongoQuery(object):
def __init__(self, id, collection, limit): def __init__(self, id, collection, limit):
self.id = id self.id = id
self.limit = limit self.limit = limit
self.collection = collection self.collection = collection
@@ -36,22 +37,30 @@ class _MongoQuery(object):
class MongoProtocol(asyncio.Protocol): class MongoProtocol(asyncio.Protocol):
def __init__(self): def __init__(self, connection_lost_callback=None, authenticators=None):
self.__id = 0 self.__id = 0
self.__buffer = b"" self.__buffer = b""
self.__queries = {} self.__queries = {}
self.__datalen = None self.__datalen = None
self.__response = 0 self.__response = 0
self.__waiting_header = True self.__waiting_header = True
self.__connection_lost_callback = connection_lost_callback
self._pipelined_calls = set() # Set of all the pipelined calls. self._pipelined_calls = set() # Set of all the pipelined calls.
self.transport = None self.transport = None
self._is_connected = False self._is_connected = False
self.__authenticators = authenticators or {}
def connection_made(self, transport): def connection_made(self, transport):
self.transport = transport self.transport = transport
self._is_connected = True self._is_connected = True
logger.log(logging.INFO, 'Mongo connection made') logger.log(logging.INFO, 'Mongo connection made')
# for name, auth in self.__authenticators.iter_items():
# yield from auth(Database(self, name))
# logger.log(logging.INFO, 'Authenticated to database {name}'.format(name=name))
def connection_lost(self, exc): def connection_lost(self, exc):
self._is_connected = False self._is_connected = False
self.transport = None self.transport = None
@@ -62,6 +71,9 @@ class MongoProtocol(asyncio.Protocol):
logger.log(logging.INFO, 'Mongo connection lost') logger.log(logging.INFO, 'Mongo connection lost')
if self.__connection_lost_callback:
self.__connection_lost_callback(exec)
@property @property
def is_connected(self): def is_connected(self):
""" True when the underlying transport is connected. """ """ True when the underlying transport is connected. """