Fix reconnect, add auth on pool creation
Support pool creation from url including db, username, and password
This commit is contained in:
17
asyncio_mongo/auth.py
Normal file
17
asyncio_mongo/auth.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
class Authenticator:
|
||||||
|
|
||||||
|
def authenticate(self, db):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialsAuthenticator():
|
||||||
|
def __init__(self, username, password):
|
||||||
|
self.username = username
|
||||||
|
self.password = password
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def authenticate(self, db):
|
||||||
|
yield from db.authenticate(self.username, self.password)
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from asyncio_mongo.auth import CredentialsAuthenticator
|
||||||
from asyncio_mongo.log import logger
|
from asyncio_mongo.log import logger
|
||||||
from asyncio_mongo.database import Database
|
from asyncio_mongo.database import Database
|
||||||
from .protocol import MongoProtocol
|
from .protocol import MongoProtocol
|
||||||
@@ -24,7 +25,8 @@ class Connection:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def create(cls, host='localhost', port=27017, loop=None, auto_reconnect=True):
|
def create(cls, host='localhost', port=27017, db=None, username=None, password=None, loop=None,
|
||||||
|
auto_reconnect=True):
|
||||||
connection = cls()
|
connection = cls()
|
||||||
|
|
||||||
connection.host = host
|
connection.host = host
|
||||||
@@ -33,15 +35,17 @@ class Connection:
|
|||||||
connection._retry_interval = .5
|
connection._retry_interval = .5
|
||||||
|
|
||||||
# Create protocol instance
|
# Create protocol instance
|
||||||
protocol_factory = type('MongoProtocol', (cls.protocol,), {})
|
def connection_lost(exc):
|
||||||
|
|
||||||
if auto_reconnect:
|
if auto_reconnect:
|
||||||
class protocol_factory(protocol_factory):
|
logger.info("Connection lost, attempting to reconnect")
|
||||||
def connection_lost(self, exc):
|
|
||||||
super().connection_lost(exc)
|
|
||||||
asyncio.Task(connection._reconnect())
|
asyncio.Task(connection._reconnect())
|
||||||
|
|
||||||
connection.protocol = protocol_factory()
|
connection._authenticators = {}
|
||||||
|
connection._db_name = db
|
||||||
|
if db and username:
|
||||||
|
connection._authenticators[db] = CredentialsAuthenticator(username, password)
|
||||||
|
|
||||||
|
connection.protocol = MongoProtocol(connection_lost_callback=connection_lost)
|
||||||
|
|
||||||
# Connect
|
# Connect
|
||||||
yield from connection._reconnect()
|
yield from connection._reconnect()
|
||||||
@@ -53,6 +57,10 @@ class Connection:
|
|||||||
if self.transport:
|
if self.transport:
|
||||||
return self.transport.close()
|
return self.transport.close()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_database(self):
|
||||||
|
return self[self._db_name]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def transport(self):
|
def transport(self):
|
||||||
""" The transport instance that the protocol is currently using. """
|
""" The transport instance that the protocol is currently using. """
|
||||||
@@ -77,9 +85,12 @@ class Connection:
|
|||||||
loop = self._loop or asyncio.get_event_loop()
|
loop = self._loop or asyncio.get_event_loop()
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
logger.log(logging.INFO, 'Connecting to mongo')
|
logger.log(logging.INFO, 'Connecting to mongo at {host}:{port}'.format(host=self.host, port=self.port))
|
||||||
yield from loop.create_connection(lambda: self.protocol, self.host, self.port)
|
yield from loop.create_connection(lambda: self.protocol, self.host, self.port)
|
||||||
self._reset_retry_interval()
|
self._reset_retry_interval()
|
||||||
|
for name, auth in self._authenticators.items():
|
||||||
|
yield from auth(Database(self, name))
|
||||||
|
logger.log(logging.INFO, 'Authenticated to database {name}'.format(name=name))
|
||||||
return
|
return
|
||||||
except OSError:
|
except OSError:
|
||||||
# Sleep and try again
|
# Sleep and try again
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from urllib.parse import urlparse
|
||||||
from .connection import Connection
|
from .connection import Connection
|
||||||
from .exceptions import NoAvailableConnectionsInPoolError
|
from .exceptions import NoAvailableConnectionsInPoolError
|
||||||
from .protocol import MongoProtocol
|
from .protocol import MongoProtocol
|
||||||
@@ -38,13 +39,29 @@ class Pool:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def create(cls, host='localhost', port=27017, loop=None, poolsize=1, auto_reconnect=True):
|
def create(cls, host='localhost', port=27017, db=None, username=None, password=None, url=None, loop=None,
|
||||||
|
poolsize=1, auto_reconnect=True):
|
||||||
"""
|
"""
|
||||||
Create a new pool instance.
|
Create a new pool instance.
|
||||||
"""
|
"""
|
||||||
self = cls()
|
self = cls()
|
||||||
|
|
||||||
|
if url:
|
||||||
|
url = urlparse(url)
|
||||||
|
|
||||||
|
try:
|
||||||
|
db = url.path.replace('/', '')
|
||||||
|
except (AttributeError, ValueError):
|
||||||
|
raise Exception("Missing database name in URI")
|
||||||
|
|
||||||
|
self._host = url.hostname
|
||||||
|
self._port = url.port or 27017
|
||||||
|
username = url.username
|
||||||
|
password = url.password
|
||||||
|
else:
|
||||||
self._host = host
|
self._host = host
|
||||||
self._port = port
|
self._port = port
|
||||||
|
|
||||||
self._pool_size = poolsize
|
self._pool_size = poolsize
|
||||||
|
|
||||||
# Create connections
|
# Create connections
|
||||||
@@ -52,7 +69,8 @@ class Pool:
|
|||||||
|
|
||||||
for i in range(poolsize):
|
for i in range(poolsize):
|
||||||
connection_class = cls.get_connection_class()
|
connection_class = cls.get_connection_class()
|
||||||
connection = yield from connection_class.create(host=host, port=port, loop=loop,
|
connection = yield from connection_class.create(host=host, port=port, db=db, username=username,
|
||||||
|
password=password, loop=loop,
|
||||||
auto_reconnect=auto_reconnect)
|
auto_reconnect=auto_reconnect)
|
||||||
self._connections.append(connection)
|
self._connections.append(connection)
|
||||||
|
|
||||||
@@ -101,13 +119,10 @@ class Pool:
|
|||||||
busy in a blocking request or transaction.)
|
busy in a blocking request or transaction.)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if 'close' == name:
|
|
||||||
return self.close
|
|
||||||
|
|
||||||
connection = self._get_free_connection()
|
connection = self._get_free_connection()
|
||||||
|
|
||||||
if connection:
|
if connection:
|
||||||
return getattr(connection, name)
|
return getattr(connection, name)
|
||||||
else:
|
else:
|
||||||
raise NoAvailableConnectionsInPoolError('No available connections in the pool: size=%s, connected=%s' % (
|
raise NoAvailableConnectionsInPoolError('No available connections in the pool: size=%s, connected=%s' % (
|
||||||
self.pool_size, self.connections_connected))
|
self._pool_size, self.connections_connected))
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import logging
|
|||||||
|
|
||||||
import struct
|
import struct
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from asyncio_mongo.database import Database
|
||||||
from asyncio_mongo.exceptions import ConnectionLostError
|
from asyncio_mongo.exceptions import ConnectionLostError
|
||||||
import asyncio_mongo._bson as bson
|
import asyncio_mongo._bson as bson
|
||||||
from asyncio_mongo.log import logger
|
from asyncio_mongo.log import logger
|
||||||
@@ -36,22 +37,30 @@ class _MongoQuery(object):
|
|||||||
|
|
||||||
|
|
||||||
class MongoProtocol(asyncio.Protocol):
|
class MongoProtocol(asyncio.Protocol):
|
||||||
def __init__(self):
|
def __init__(self, connection_lost_callback=None, authenticators=None):
|
||||||
self.__id = 0
|
self.__id = 0
|
||||||
self.__buffer = b""
|
self.__buffer = b""
|
||||||
self.__queries = {}
|
self.__queries = {}
|
||||||
self.__datalen = None
|
self.__datalen = None
|
||||||
self.__response = 0
|
self.__response = 0
|
||||||
self.__waiting_header = True
|
self.__waiting_header = True
|
||||||
|
self.__connection_lost_callback = connection_lost_callback
|
||||||
self._pipelined_calls = set() # Set of all the pipelined calls.
|
self._pipelined_calls = set() # Set of all the pipelined calls.
|
||||||
self.transport = None
|
self.transport = None
|
||||||
self._is_connected = False
|
self._is_connected = False
|
||||||
|
|
||||||
|
self.__authenticators = authenticators or {}
|
||||||
|
|
||||||
def connection_made(self, transport):
|
def connection_made(self, transport):
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
|
|
||||||
self._is_connected = True
|
self._is_connected = True
|
||||||
logger.log(logging.INFO, 'Mongo connection made')
|
logger.log(logging.INFO, 'Mongo connection made')
|
||||||
|
|
||||||
|
# for name, auth in self.__authenticators.iter_items():
|
||||||
|
# yield from auth(Database(self, name))
|
||||||
|
# logger.log(logging.INFO, 'Authenticated to database {name}'.format(name=name))
|
||||||
|
|
||||||
def connection_lost(self, exc):
|
def connection_lost(self, exc):
|
||||||
self._is_connected = False
|
self._is_connected = False
|
||||||
self.transport = None
|
self.transport = None
|
||||||
@@ -62,6 +71,9 @@ class MongoProtocol(asyncio.Protocol):
|
|||||||
|
|
||||||
logger.log(logging.INFO, 'Mongo connection lost')
|
logger.log(logging.INFO, 'Mongo connection lost')
|
||||||
|
|
||||||
|
if self.__connection_lost_callback:
|
||||||
|
self.__connection_lost_callback(exec)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self):
|
def is_connected(self):
|
||||||
""" True when the underlying transport is connected. """
|
""" True when the underlying transport is connected. """
|
||||||
|
|||||||
Reference in New Issue
Block a user