commit ae154dbc7231cb534b6a242b2d0c435d08ad4c2e
Author: Don Brown
Date:   Fri Jan 24 16:19:17 2014 -0700

    Initial commit

    * Examples work
    * setup.py kinda updated
    * Fork of txmongo but with new pymongo embedded

diff --git a/.gitignore b/.gitignore
new file mode 100755
index 0000000..a281b97
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,13 @@
+.DS_Store
+*.swp
+._*
+*.pyc
+*.cache
+_trial_temp/
+*.pid
+*.iml
+*.ipr
+*.iws
+.idea
+venv
+
diff --git a/asyncio_mongo/__init__.py b/asyncio_mongo/__init__.py
new file mode 100644
index 0000000..b384bf0
--- /dev/null
+++ b/asyncio_mongo/__init__.py
@@ -0,0 +1,9 @@
+from .connection import *
+from .exceptions import *
+from .pool import *
+from .protocol import *
+
+__doc__ = \
+"""
+MongoDB protocol implementation for asyncio (PEP 3156)
+"""
diff --git a/asyncio_mongo/_bson/__init__.py b/asyncio_mongo/_bson/__init__.py
new file mode 100644
index 0000000..848082b
--- /dev/null
+++ b/asyncio_mongo/_bson/__init__.py
@@ -0,0 +1,616 @@
+# Copyright 2009-2012 10gen, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BSON (Binary JSON) encoding and decoding.
+"""
+
+import calendar
+import datetime
+import re
+import struct
+import sys
+
+from asyncio_mongo._bson.binary import (Binary, OLD_UUID_SUBTYPE,
+                                        JAVA_LEGACY, CSHARP_LEGACY)
+from asyncio_mongo._bson.code import Code
+from asyncio_mongo._bson.dbref import DBRef
+from asyncio_mongo._bson.errors import (InvalidBSON,
+                                        InvalidDocument,
+                                        InvalidStringData)
+from asyncio_mongo._bson.max_key import MaxKey
+from asyncio_mongo._bson.min_key import MinKey
+from asyncio_mongo._bson.objectid import ObjectId
+from asyncio_mongo._bson.py3compat import b, binary_type
+from asyncio_mongo._bson.son import SON, RE_TYPE
+from asyncio_mongo._bson.timestamp import Timestamp
+from asyncio_mongo._bson.tz_util import utc
+
+
+try:
+    from asyncio_mongo._bson import _cbson
+    _use_c = True
+except ImportError:
+    _use_c = False
+
+try:
+    import uuid
+    _use_uuid = True
+except ImportError:
+    _use_uuid = False
+
+PY3 = sys.version_info[0] == 3
+
+
+MAX_INT32 = 2147483647
+MIN_INT32 = -2147483648
+MAX_INT64 = 9223372036854775807
+MIN_INT64 = -9223372036854775808
+
+EPOCH_AWARE = datetime.datetime.fromtimestamp(0, utc)
+EPOCH_NAIVE = datetime.datetime.utcfromtimestamp(0)
+
+# Create constants compatible with all versions of
+# python from 2.4 forward. In 2.x b("foo") is just
+# "foo". In 3.x it becomes b"foo".
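+#
+# An added illustration (not part of the original source) of what b()
+# yields on this Python 3 build:
+#
+#     >>> from asyncio_mongo._bson.py3compat import b
+#     >>> b("\x01")
+#     b'\x01'
+#     >>> b("foo") + b("\x00")
+#     b'foo\x00'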
+EMPTY = b("") +ZERO = b("\x00") +ONE = b("\x01") + +BSONNUM = b("\x01") # Floating point +BSONSTR = b("\x02") # UTF-8 string +BSONOBJ = b("\x03") # Embedded document +BSONARR = b("\x04") # Array +BSONBIN = b("\x05") # Binary +BSONUND = b("\x06") # Undefined +BSONOID = b("\x07") # ObjectId +BSONBOO = b("\x08") # Boolean +BSONDAT = b("\x09") # UTC Datetime +BSONNUL = b("\x0A") # Null +BSONRGX = b("\x0B") # Regex +BSONREF = b("\x0C") # DBRef +BSONCOD = b("\x0D") # Javascript code +BSONSYM = b("\x0E") # Symbol +BSONCWS = b("\x0F") # Javascript code with scope +BSONINT = b("\x10") # 32bit int +BSONTIM = b("\x11") # Timestamp +BSONLON = b("\x12") # 64bit int +BSONMIN = b("\xFF") # Min key +BSONMAX = b("\x7F") # Max key + + +def _get_int(data, position, as_class=None, + tz_aware=False, uuid_subtype=OLD_UUID_SUBTYPE, unsigned=False): + format = unsigned and "I" or "i" + try: + value = struct.unpack("<%s" % format, data[position:position + 4])[0] + except struct.error: + raise InvalidBSON() + position += 4 + return value, position + + +def _get_c_string(data, position, length=None): + if length is None: + try: + end = data.index(ZERO, position) + except ValueError: + raise InvalidBSON() + else: + end = position + length + value = data[position:end].decode("utf-8") + position = end + 1 + + return value, position + + +def _make_c_string(string, check_null=False): + if isinstance(string, str): + if check_null and "\x00" in string: + raise InvalidDocument("BSON keys / regex patterns must not " + "contain a NULL character") + return string.encode("utf-8") + ZERO + else: + if check_null and ZERO in string: + raise InvalidDocument("BSON keys / regex patterns must not " + "contain a NULL character") + try: + string.decode("utf-8") + return string + ZERO + except UnicodeError: + raise InvalidStringData("strings in documents must be valid " + "UTF-8: %r" % string) + + +def _get_number(data, position, as_class, tz_aware, uuid_subtype): + num = struct.unpack(" MAX_INT64 or value < MIN_INT64: + raise OverflowError("BSON can only handle up to 8-byte ints") + if value > MAX_INT32 or value < MIN_INT32: + return BSONLON + name + struct.pack(" MAX_INT64 or value < MIN_INT64: + raise OverflowError("BSON can only handle up to 8-byte ints") + return BSONLON + name + struct.pack("`_ + to use + """ + + def __new__(cls, data, subtype=BINARY_SUBTYPE): + if not isinstance(data, binary_type): + raise TypeError("data must be an " + "instance of %s" % (binary_type.__name__,)) + if not isinstance(subtype, int): + raise TypeError("subtype must be an instance of int") + if subtype >= 256 or subtype < 0: + raise ValueError("subtype must be contained in [0, 256)") + self = binary_type.__new__(cls, data) + self.__subtype = subtype + return self + + @property + def subtype(self): + """Subtype of this binary data. + """ + return self.__subtype + + def __getnewargs__(self): + # Work around http://bugs.python.org/issue7382 + data = super(Binary, self).__getnewargs__()[0] + if PY3 and not isinstance(data, binary_type): + data = data.encode('latin-1') + return data, self.__subtype + + def __eq__(self, other): + if isinstance(other, Binary): + return ((self.__subtype, binary_type(self)) == + (other.subtype, binary_type(other))) + # We don't return NotImplemented here because if we did then + # Binary("foo") == "foo" would return True, since Binary is a + # subclass of str... 
+ return False + + def __ne__(self, other): + return not self == other + + def __repr__(self): + return "Binary(%s, %s)" % (binary_type.__repr__(self), self.__subtype) + + +class UUIDLegacy(Binary): + """UUID wrapper to support working with UUIDs stored as legacy + BSON binary subtype 3. + + .. doctest:: + + >>> import uuid + >>> from asyncio_mongo._bson.binary import Binary, UUIDLegacy, UUID_SUBTYPE + >>> my_uuid = uuid.uuid4() + >>> coll = db.test + >>> coll.uuid_subtype = UUID_SUBTYPE + >>> coll.insert({'uuid': Binary(my_uuid.bytes, 3)}) + ObjectId('...') + >>> coll.find({'uuid': my_uuid}).count() + 0 + >>> coll.find({'uuid': UUIDLegacy(my_uuid)}).count() + 1 + >>> coll.find({'uuid': UUIDLegacy(my_uuid)})[0]['uuid'] + UUID('...') + >>> + >>> # Convert from subtype 3 to subtype 4 + >>> doc = coll.find_one({'uuid': UUIDLegacy(my_uuid)}) + >>> coll.save(doc) + ObjectId('...') + >>> coll.find({'uuid': UUIDLegacy(my_uuid)}).count() + 0 + >>> coll.find({'uuid': {'$in': [UUIDLegacy(my_uuid), my_uuid]}}).count() + 1 + >>> coll.find_one({'uuid': my_uuid})['uuid'] + UUID('...') + + Raises TypeError if `obj` is not an instance of :class:`~uuid.UUID`. + + :Parameters: + - `obj`: An instance of :class:`~uuid.UUID`. + """ + + def __new__(cls, obj): + if not isinstance(obj, UUID): + raise TypeError("obj must be an instance of uuid.UUID") + # Python 3.0(.1) returns a bytearray instance for bytes (3.1 and + # newer just return a bytes instance). Convert that to binary_type + # for compatibility. + self = Binary.__new__(cls, binary_type(obj.bytes), OLD_UUID_SUBTYPE) + self.__uuid = obj + return self + + def __getnewargs__(self): + # Support copy and deepcopy + return (self.__uuid,) + + @property + def uuid(self): + """UUID instance wrapped by this UUIDLegacy instance. + """ + return self.__uuid + + def __repr__(self): + return "UUIDLegacy('%s')" % self.__uuid diff --git a/asyncio_mongo/_bson/code.py b/asyncio_mongo/_bson/code.py new file mode 100644 index 0000000..194a78a --- /dev/null +++ b/asyncio_mongo/_bson/code.py @@ -0,0 +1,78 @@ +# Copyright 2009-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for representing JavaScript code in BSON. +""" + +class Code(str): + """BSON's JavaScript code type. + + Raises :class:`TypeError` if `code` is not an instance of + :class:`basestring` (:class:`str` in python 3) or `scope` + is not ``None`` or an instance of :class:`dict`. + + Scope variables can be set by passing a dictionary as the `scope` + argument or by using keyword arguments. If a variable is set as a + keyword argument it will override any setting for that variable in + the `scope` dictionary. + + :Parameters: + - `code`: string containing JavaScript code to be evaluated + - `scope` (optional): dictionary representing the scope in which + `code` should be evaluated - a mapping from identifiers (as + strings) to values + - `**kwargs` (optional): scope variables can also be passed as + keyword arguments + + .. 
versionadded:: 1.9 + Ability to pass scope values using keyword arguments. + """ + + def __new__(cls, code, scope=None, **kwargs): + if not isinstance(code, str): + raise TypeError("code must be an " + "instance of %s" % (str.__name__,)) + + self = str.__new__(cls, code) + + try: + self.__scope = code.scope + except AttributeError: + self.__scope = {} + + if scope is not None: + if not isinstance(scope, dict): + raise TypeError("scope must be an instance of dict") + self.__scope.update(scope) + + self.__scope.update(kwargs) + + return self + + @property + def scope(self): + """Scope dictionary for this instance. + """ + return self.__scope + + def __repr__(self): + return "Code(%s, %r)" % (str.__repr__(self), self.__scope) + + def __eq__(self, other): + if isinstance(other, Code): + return (self.__scope, str(self)) == (other.__scope, str(other)) + return False + + def __ne__(self, other): + return not self == other diff --git a/asyncio_mongo/_bson/dbref.py b/asyncio_mongo/_bson/dbref.py new file mode 100644 index 0000000..03b0ea0 --- /dev/null +++ b/asyncio_mongo/_bson/dbref.py @@ -0,0 +1,144 @@ +# Copyright 2009-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for manipulating DBRefs (references to MongoDB documents).""" + +from copy import deepcopy + +from asyncio_mongo._bson.son import SON + + +class DBRef(object): + """A reference to a document stored in MongoDB. + """ + + def __init__(self, collection, id, database=None, _extra={}, **kwargs): + """Initialize a new :class:`DBRef`. + + Raises :class:`TypeError` if `collection` or `database` is not + an instance of :class:`basestring` (:class:`str` in python 3). + `database` is optional and allows references to documents to work + across databases. Any additional keyword arguments will create + additional fields in the resultant embedded document. + + :Parameters: + - `collection`: name of the collection the document is stored in + - `id`: the value of the document's ``"_id"`` field + - `database` (optional): name of the database to reference + - `**kwargs` (optional): additional keyword arguments will + create additional, custom fields + + .. versionchanged:: 1.8 + Now takes keyword arguments to specify additional fields. + .. versionadded:: 1.1.1 + The `database` parameter. + + .. mongodoc:: dbrefs + """ + if not isinstance(collection, str): + raise TypeError("collection must be an " + "instance of %s" % (str.__name__,)) + if database is not None and not isinstance(database, str): + raise TypeError("database must be an " + "instance of %s" % (str.__name__,)) + + self.__collection = collection + self.__id = id + self.__database = database + kwargs.update(_extra) + self.__kwargs = kwargs + + @property + def collection(self): + """Get the name of this DBRef's collection as unicode. + """ + return self.__collection + + @property + def id(self): + """Get this DBRef's _id. + """ + return self.__id + + @property + def database(self): + """Get the name of this DBRef's database. 
+ + Returns None if this DBRef doesn't specify a database. + + .. versionadded:: 1.1.1 + """ + return self.__database + + def __getattr__(self, key): + try: + return self.__kwargs[key] + except KeyError: + raise AttributeError(key) + + # Have to provide __setstate__ to avoid + # infinite recursion since we override + # __getattr__. + def __setstate__(self, state): + self.__dict__.update(state) + + def as_doc(self): + """Get the SON document representation of this DBRef. + + Generally not needed by application developers + """ + doc = SON([("$ref", self.collection), + ("$id", self.id)]) + if self.database is not None: + doc["$db"] = self.database + doc.update(self.__kwargs) + return doc + + def __repr__(self): + extra = "".join([", %s=%r" % (k, v) + for k, v in self.__kwargs.items()]) + if self.database is None: + return "DBRef(%r, %r%s)" % (self.collection, self.id, extra) + return "DBRef(%r, %r, %r%s)" % (self.collection, self.id, + self.database, extra) + + def __eq__(self, other): + if isinstance(other, DBRef): + us = (self.__database, self.__collection, + self.__id, self.__kwargs) + them = (other.__database, other.__collection, + other.__id, other.__kwargs) + return us == them + return NotImplemented + + def __ne__(self, other): + return not self == other + + def __hash__(self): + """Get a hash value for this :class:`DBRef`. + + .. versionadded:: 1.1 + """ + return hash((self.__collection, self.__id, self.__database, + tuple(sorted(self.__kwargs.items())))) + + def __deepcopy__(self, memo): + """Support function for `copy.deepcopy()`. + + .. versionadded:: 1.10 + """ + return DBRef(deepcopy(self.__collection, memo), + deepcopy(self.__id, memo), + deepcopy(self.__database, memo), + deepcopy(self.__kwargs, memo)) diff --git a/asyncio_mongo/_bson/errors.py b/asyncio_mongo/_bson/errors.py new file mode 100644 index 0000000..9501bf4 --- /dev/null +++ b/asyncio_mongo/_bson/errors.py @@ -0,0 +1,40 @@ +# Copyright 2009-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Exceptions raised by the BSON package.""" + + +class BSONError(Exception): + """Base class for all BSON exceptions. + """ + + +class InvalidBSON(BSONError): + """Raised when trying to create a BSON object from invalid data. + """ + + +class InvalidStringData(BSONError): + """Raised when trying to encode a string containing non-UTF8 data. + """ + + +class InvalidDocument(BSONError): + """Raised when trying to create a BSON object from an invalid document. + """ + + +class InvalidId(BSONError): + """Raised when trying to create an ObjectId from invalid data. + """ diff --git a/asyncio_mongo/_bson/json_util.py b/asyncio_mongo/_bson/json_util.py new file mode 100644 index 0000000..2044a3c --- /dev/null +++ b/asyncio_mongo/_bson/json_util.py @@ -0,0 +1,220 @@ +# Copyright 2009-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tools for using Python's :mod:`json` module with BSON documents.
+
+This module provides two helper methods `dumps` and `loads` that wrap the
+native :mod:`json` methods and provide explicit BSON conversion to and from
+json. This allows for specialized encoding and decoding of BSON documents
+into `Mongo Extended JSON
+<http://www.mongodb.org/display/DOCS/Mongo+Extended+JSON>`_'s *Strict*
+mode. This lets you encode / decode BSON documents to JSON even when
+they use special BSON types.
+
+Example usage (serialization)::
+
+.. doctest::
+
+   >>> from asyncio_mongo._bson import Binary, Code
+   >>> from asyncio_mongo._bson.json_util import dumps
+   >>> dumps([{'foo': [1, 2]},
+   ...        {'bar': {'hello': 'world'}},
+   ...        {'code': Code("function x() { return 1; }")},
+   ...        {'bin': Binary("\x00\x01\x02\x03\x04")}])
+   '[{"foo": [1, 2]}, {"bar": {"hello": "world"}}, {"code": {"$scope": {}, "$code": "function x() { return 1; }"}}, {"bin": {"$type": "00", "$binary": "AAECAwQ="}}]'
+
+Example usage (deserialization)::
+
+.. doctest::
+
+   >>> from asyncio_mongo._bson.json_util import loads
+   >>> loads('[{"foo": [1, 2]}, {"bar": {"hello": "world"}}, {"code": {"$scope": {}, "$code": "function x() { return 1; }"}}, {"bin": {"$type": "00", "$binary": "AAECAwQ="}}]')
+   [{u'foo': [1, 2]}, {u'bar': {u'hello': u'world'}}, {u'code': Code('function x() { return 1; }', {})}, {u'bin': Binary('\x00\x01\x02\x03\x04', 0)}]
+
+Alternatively, you can manually pass the `default` to :func:`json.dumps`.
+It won't handle :class:`~bson.binary.Binary` and :class:`~bson.code.Code`
+instances (as they are extended strings you can't provide custom defaults),
+but it will be faster as there is less recursion.
+
+.. versionchanged:: 2.3
+   Added dumps and loads helpers to automatically handle conversion to and
+   from json and supports :class:`~bson.binary.Binary` and
+   :class:`~bson.code.Code`
+
+.. versionchanged:: 1.9
+   Handle :class:`uuid.UUID` instances, whenever possible.
+
+.. versionchanged:: 1.8
+   Handle timezone aware datetime instances on encode, decode to
+   timezone aware datetime instances.
+
+.. versionchanged:: 1.8
+   Added support for encoding/decoding :class:`~bson.max_key.MaxKey`
+   and :class:`~bson.min_key.MinKey`, and for encoding
+   :class:`~bson.timestamp.Timestamp`.
+
+.. versionchanged:: 1.2
+   Added support for encoding/decoding datetimes and regular expressions.
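+
+A minimal round-trip sketch (added for illustration; the ObjectId hex value
+is made up)::
+
+   >>> from asyncio_mongo._bson.objectid import ObjectId
+   >>> from asyncio_mongo._bson.json_util import dumps, loads
+   >>> oid = ObjectId('4e3c8aadd3d91e485b000000')
+   >>> loads(dumps({'_id': oid}))['_id'] == oid
+   True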
+""" + +import base64 +import calendar +import datetime +import re + +json_lib = True +try: + import json +except ImportError: + try: + import simplejson as json + except ImportError: + json_lib = False + +import asyncio_mongo._bson as bson +from asyncio_mongo._bson import EPOCH_AWARE, RE_TYPE +from asyncio_mongo._bson.binary import Binary +from asyncio_mongo._bson.code import Code +from asyncio_mongo._bson.dbref import DBRef +from asyncio_mongo._bson.max_key import MaxKey +from asyncio_mongo._bson.min_key import MinKey +from asyncio_mongo._bson.objectid import ObjectId +from asyncio_mongo._bson.timestamp import Timestamp + +from asyncio_mongo._bson.py3compat import PY3, binary_type, string_types + + +_RE_OPT_TABLE = { + "i": re.I, + "l": re.L, + "m": re.M, + "s": re.S, + "u": re.U, + "x": re.X, +} + + +def dumps(obj, *args, **kwargs): + """Helper function that wraps :class:`json.dumps`. + + Recursive function that handles all BSON types including + :class:`~bson.binary.Binary` and :class:`~bson.code.Code`. + """ + if not json_lib: + raise Exception("No json library available") + return json.dumps(_json_convert(obj), *args, **kwargs) + + +def loads(s, *args, **kwargs): + """Helper function that wraps :class:`json.loads`. + + Automatically passes the object_hook for BSON type conversion. + """ + if not json_lib: + raise Exception("No json library available") + kwargs['object_hook'] = object_hook + return json.loads(s, *args, **kwargs) + + +def _json_convert(obj): + """Recursive helper method that converts BSON types so they can be + converted into json. + """ + if hasattr(obj, 'iteritems') or hasattr(obj, 'items'): # PY3 support + return dict(((k, _json_convert(v)) for k, v in obj.items())) + elif hasattr(obj, '__iter__') and not isinstance(obj, string_types): + return list((_json_convert(v) for v in obj)) + try: + return default(obj) + except TypeError: + return obj + + +def object_hook(dct): + if "$oid" in dct: + return ObjectId(str(dct["$oid"])) + if "$ref" in dct: + return DBRef(dct["$ref"], dct["$id"], dct.get("$db", None)) + if "$date" in dct: + secs = float(dct["$date"]) / 1000.0 + return EPOCH_AWARE + datetime.timedelta(seconds=secs) + if "$regex" in dct: + flags = 0 + # PyMongo always adds $options but some other tools may not. + for opt in dct.get("$options", ""): + flags |= _RE_OPT_TABLE.get(opt, 0) + return re.compile(dct["$regex"], flags) + if "$minKey" in dct: + return MinKey() + if "$maxKey" in dct: + return MaxKey() + if "$binary" in dct: + if isinstance(dct["$type"], int): + dct["$type"] = "%02x" % dct["$type"] + subtype = int(dct["$type"], 16) + if subtype >= 0xffffff80: # Handle mongoexport values + subtype = int(dct["$type"][6:], 16) + return Binary(base64.b64decode(dct["$binary"].encode()), subtype) + if "$code" in dct: + return Code(dct["$code"], dct.get("$scope")) + if bson.has_uuid() and "$uuid" in dct: + return bson.uuid.UUID(dct["$uuid"]) + return dct + + +def default(obj): + if isinstance(obj, ObjectId): + return {"$oid": str(obj)} + if isinstance(obj, DBRef): + return _json_convert(obj.as_doc()) + if isinstance(obj, datetime.datetime): + # TODO share this code w/ bson.py? 
+ if obj.utcoffset() is not None: + obj = obj - obj.utcoffset() + millis = int(calendar.timegm(obj.timetuple()) * 1000 + + obj.microsecond / 1000) + return {"$date": millis} + if isinstance(obj, RE_TYPE): + flags = "" + if obj.flags & re.IGNORECASE: + flags += "i" + if obj.flags & re.LOCALE: + flags += "l" + if obj.flags & re.MULTILINE: + flags += "m" + if obj.flags & re.DOTALL: + flags += "s" + if obj.flags & re.UNICODE: + flags += "u" + if obj.flags & re.VERBOSE: + flags += "x" + return {"$regex": obj.pattern, + "$options": flags} + if isinstance(obj, MinKey): + return {"$minKey": 1} + if isinstance(obj, MaxKey): + return {"$maxKey": 1} + if isinstance(obj, Timestamp): + return {"t": obj.time, "i": obj.inc} + if isinstance(obj, Code): + return {'$code': "%s" % obj, '$scope': obj.scope} + if isinstance(obj, Binary): + return {'$binary': base64.b64encode(obj).decode(), + '$type': "%02x" % obj.subtype} + if PY3 and isinstance(obj, binary_type): + return {'$binary': base64.b64encode(obj).decode(), + '$type': "00"} + if bson.has_uuid() and isinstance(obj, bson.uuid.UUID): + return {"$uuid": obj.hex} + raise TypeError("%r is not JSON serializable" % obj) diff --git a/asyncio_mongo/_bson/max_key.py b/asyncio_mongo/_bson/max_key.py new file mode 100644 index 0000000..a758535 --- /dev/null +++ b/asyncio_mongo/_bson/max_key.py @@ -0,0 +1,32 @@ +# Copyright 2010-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Representation for the MongoDB internal MaxKey type. +""" + + +class MaxKey(object): + """MongoDB internal MaxKey type. + """ + + def __eq__(self, other): + if isinstance(other, MaxKey): + return True + return NotImplemented + + def __ne__(self, other): + return not self == other + + def __repr__(self): + return "MaxKey()" diff --git a/asyncio_mongo/_bson/min_key.py b/asyncio_mongo/_bson/min_key.py new file mode 100644 index 0000000..9047128 --- /dev/null +++ b/asyncio_mongo/_bson/min_key.py @@ -0,0 +1,32 @@ +# Copyright 2010-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Representation for the MongoDB internal MinKey type. +""" + + +class MinKey(object): + """MongoDB internal MinKey type. 
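+
+    Illustrative equality behavior (added example; any two instances
+    compare equal)::
+
+        >>> MinKey() == MinKey()
+        True
+        >>> MinKey() != MinKey()
+        False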
+ """ + + def __eq__(self, other): + if isinstance(other, MinKey): + return True + return NotImplemented + + def __ne__(self, other): + return not self == other + + def __repr__(self): + return "MinKey()" diff --git a/asyncio_mongo/_bson/objectid.py b/asyncio_mongo/_bson/objectid.py new file mode 100644 index 0000000..400f204 --- /dev/null +++ b/asyncio_mongo/_bson/objectid.py @@ -0,0 +1,289 @@ +# Copyright 2009-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for working with MongoDB `ObjectIds +`_. +""" + +import binascii +import calendar +import datetime +try: + import hashlib + _md5func = hashlib.md5 +except ImportError: # for Python < 2.5 + import md5 + _md5func = md5.new +import os +import random +import socket +import struct +import threading +import time + +from asyncio_mongo._bson.errors import InvalidId +from asyncio_mongo._bson.py3compat import (PY3, b, binary_type, text_type, + bytes_from_hex, string_types) +from asyncio_mongo._bson.tz_util import utc + +EMPTY = b("") +ZERO = b("\x00") + +def _machine_bytes(): + """Get the machine portion of an ObjectId. + """ + machine_hash = _md5func() + if PY3: + # gethostname() returns a unicode string in python 3.x + # while update() requires a byte string. + machine_hash.update(socket.gethostname().encode()) + else: + # Calling encode() here will fail with non-ascii hostnames + machine_hash.update(socket.gethostname()) + return machine_hash.digest()[0:3] + + +class ObjectId(object): + """A MongoDB ObjectId. + """ + + _inc = random.randint(0, 0xFFFFFF) + _inc_lock = threading.Lock() + + _machine_bytes = _machine_bytes() + + __slots__ = ('__id') + + def __init__(self, oid=None): + """Initialize a new ObjectId. + + If `oid` is ``None``, create a new (unique) ObjectId. If `oid` + is an instance of (:class:`basestring` (:class:`str` or :class:`bytes` + in python 3), :class:`ObjectId`) validate it and use that. Otherwise, + a :class:`TypeError` is raised. If `oid` is invalid, + :class:`~bson.errors.InvalidId` is raised. + + :Parameters: + - `oid` (optional): a valid ObjectId (12 byte binary or 24 character + hex string) + + .. versionadded:: 1.2.1 + The `oid` parameter can be a ``unicode`` instance (that contains + only hexadecimal digits). + + .. mongodoc:: objectids + """ + if oid is None: + self.__generate() + else: + self.__validate(oid) + + @classmethod + def from_datetime(cls, generation_time): + """Create a dummy ObjectId instance with a specific generation time. + + This method is useful for doing range queries on a field + containing :class:`ObjectId` instances. + + .. warning:: + It is not safe to insert a document containing an ObjectId + generated using this method. This method deliberately + eliminates the uniqueness guarantee that ObjectIds + generally provide. ObjectIds generated with this method + should be used exclusively in queries. + + `generation_time` will be converted to UTC. Naive datetime + instances will be treated as though they already contain UTC. 
+
+        An example using this helper to get documents where ``"_id"``
+        was generated before January 1, 2010 would be:
+
+        >>> gen_time = datetime.datetime(2010, 1, 1)
+        >>> dummy_id = ObjectId.from_datetime(gen_time)
+        >>> result = collection.find({"_id": {"$lt": dummy_id}})
+
+        :Parameters:
+          - `generation_time`: :class:`~datetime.datetime` to be used
+            as the generation time for the resulting ObjectId.
+
+        .. versionchanged:: 1.8
+           Properly handle timezone aware values for
+           `generation_time`.
+
+        .. versionadded:: 1.6
+        """
+        if generation_time.utcoffset() is not None:
+            generation_time = generation_time - generation_time.utcoffset()
+        ts = calendar.timegm(generation_time.timetuple())
+        oid = struct.pack(">i", int(ts)) + ZERO * 8
+        return cls(oid)
+
+    @classmethod
+    def is_valid(cls, oid):
+        """Checks if an `oid` string is valid or not.
+
+        :Parameters:
+          - `oid`: the object id to validate
+
+        .. versionadded:: 2.3
+        """
+        try:
+            ObjectId(oid)
+            return True
+        except (InvalidId, TypeError):
+            return False
+
+    def __generate(self):
+        """Generate a new value for this ObjectId.
+        """
+        oid = EMPTY
+
+        # 4 bytes current time
+        oid += struct.pack(">i", int(time.time()))
+
+        # 3 bytes machine
+        oid += ObjectId._machine_bytes
+
+        # 2 bytes pid
+        oid += struct.pack(">H", os.getpid() % 0xFFFF)
+
+        # 3 bytes inc
+        ObjectId._inc_lock.acquire()
+        oid += struct.pack(">i", ObjectId._inc)[1:4]
+        ObjectId._inc = (ObjectId._inc + 1) % 0xFFFFFF
+        ObjectId._inc_lock.release()
+
+        self.__id = oid
+
+    def __validate(self, oid):
+        """Validate and use the given id for this ObjectId.
+
+        Raises TypeError if id is not an instance of
+        (:class:`basestring` (:class:`str` or :class:`bytes`
+        in python 3), ObjectId) and InvalidId if it is not a
+        valid ObjectId.
+
+        :Parameters:
+          - `oid`: a valid ObjectId
+        """
+        if isinstance(oid, ObjectId):
+            self.__id = oid.__id
+        elif isinstance(oid, string_types):
+            if len(oid) == 12:
+                if isinstance(oid, binary_type):
+                    self.__id = oid
+                else:
+                    raise InvalidId("%s is not a valid ObjectId" % oid)
+            elif len(oid) == 24:
+                try:
+                    self.__id = bytes_from_hex(oid)
+                except (TypeError, ValueError):
+                    raise InvalidId("%s is not a valid ObjectId" % oid)
+            else:
+                raise InvalidId("%s is not a valid ObjectId" % oid)
+        else:
+            raise TypeError("id must be an instance of (%s, %s, ObjectId), "
+                            "not %s" % (binary_type.__name__,
+                                        text_type.__name__, type(oid)))
+
+    @property
+    def binary(self):
+        """12-byte binary representation of this ObjectId.
+        """
+        return self.__id
+
+    @property
+    def generation_time(self):
+        """A :class:`datetime.datetime` instance representing the time of
+        generation for this :class:`ObjectId`.
+
+        The :class:`datetime.datetime` is timezone aware, and
+        represents the generation time in UTC. It is precise to the
+        second.
+
+        .. versionchanged:: 1.8
+           Now return an aware datetime instead of a naive one.
+
+        .. versionadded:: 1.2
+        """
+        t = struct.unpack(">i", self.__id[0:4])[0]
+        return datetime.datetime.fromtimestamp(t, utc)
+
+    def __getstate__(self):
+        """return value of object for pickling.
+        needed explicitly because __slots__() defined.
+        """
+        return self.__id
+
+    def __setstate__(self, value):
+        """explicit state set from pickling
+        """
+        # Provide backwards compatibility with OIDs
+        # pickled with pymongo-1.9 or older.
+        if isinstance(value, dict):
+            oid = value["_ObjectId__id"]
+        else:
+            oid = value
+        # ObjectIds pickled in python 2.x used `str` for __id.
+        # In python 3.x this has to be converted to `bytes`
+        # by encoding latin-1.
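+        # (Added note: latin-1 maps code points U+0000..U+00FF one-to-one
+        # onto byte values 0..255, so the 12 id bytes round-trip intact.)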
+ if PY3 and isinstance(oid, text_type): + self.__id = oid.encode('latin-1') + else: + self.__id = oid + + def __str__(self): + if PY3: + return binascii.hexlify(self.__id).decode() + return binascii.hexlify(self.__id) + + def __repr__(self): + return "ObjectId('%s')" % (str(self),) + + def __eq__(self, other): + if isinstance(other, ObjectId): + return self.__id == other.__id + return NotImplemented + + def __ne__(self, other): + if isinstance(other, ObjectId): + return self.__id != other.__id + return NotImplemented + + def __lt__(self, other): + if isinstance(other, ObjectId): + return self.__id < other.__id + return NotImplemented + + def __le__(self, other): + if isinstance(other, ObjectId): + return self.__id <= other.__id + return NotImplemented + + def __gt__(self, other): + if isinstance(other, ObjectId): + return self.__id > other.__id + return NotImplemented + + def __ge__(self, other): + if isinstance(other, ObjectId): + return self.__id >= other.__id + return NotImplemented + + def __hash__(self): + """Get a hash value for this :class:`ObjectId`. + + .. versionadded:: 1.1 + """ + return hash(self.__id) diff --git a/asyncio_mongo/_bson/py3compat.py b/asyncio_mongo/_bson/py3compat.py new file mode 100644 index 0000000..24a2e1c --- /dev/null +++ b/asyncio_mongo/_bson/py3compat.py @@ -0,0 +1,60 @@ +# Copyright 2009-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Utility functions and definitions for python3 compatibility.""" + +import sys + +PY3 = sys.version_info[0] == 3 + +if PY3: + import codecs + + from io import BytesIO as StringIO + + def b(s): + # BSON and socket operations deal in binary data. In + # python 3 that means instances of `bytes`. In python + # 2.6 and 2.7 you can create an alias for `bytes` using + # the b prefix (e.g. b'foo'). Python 2.4 and 2.5 don't + # provide this marker so we provide this compat function. + # In python 3.x b('foo') results in b'foo'. + # See http://python3porting.com/problems.html#nicer-solutions + return codecs.latin_1_encode(s)[0] + + def bytes_from_hex(h): + return bytes.fromhex(h) + + binary_type = bytes + text_type = str + +else: + try: + from io import StringIO + except ImportError: + from io import StringIO + + def b(s): + # See comments above. In python 2.x b('foo') is just 'foo'. + return s + + def bytes_from_hex(h): + return h.decode('hex') + + binary_type = str + # 2to3 will convert this to "str". That's okay + # since we won't ever get here under python3. + text_type = str + +string_types = (binary_type, text_type) diff --git a/asyncio_mongo/_bson/son.py b/asyncio_mongo/_bson/son.py new file mode 100644 index 0000000..57b0b76 --- /dev/null +++ b/asyncio_mongo/_bson/son.py @@ -0,0 +1,243 @@ +# Copyright 2009-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tools for creating and manipulating SON, the Serialized Document Notation.
+
+Regular dictionaries can be used instead of SON objects, but not when the order
+of keys is important. A SON object can be used just like a normal Python
+dictionary."""
+
+import copy
+import re
+
+# This sort of sucks, but seems to be as good as it gets...
+# This is essentially the same as re._pattern_type
+RE_TYPE = type(re.compile(""))
+
+
+class SON(dict):
+    """SON data.
+
+    A subclass of dict that maintains ordering of keys and provides a
+    few extra niceties for dealing with SON. SON objects can be
+    converted to and from asyncio_mongo._bson.
+
+    The mapping from Python types to BSON types is as follows:
+
+    =================================== ============= ===================
+    Python Type                         BSON Type     Supported Direction
+    =================================== ============= ===================
+    None                                null          both
+    bool                                boolean       both
+    int [#int]_                         int32 / int64 py -> bson
+    long                                int64         both
+    float                               number (real) both
+    string                              string        py -> bson
+    unicode                             string        both
+    list                                array         both
+    dict / `SON`                        object        both
+    datetime.datetime [#dt]_ [#dt2]_    date          both
+    compiled re                         regex         both
+    `bson.binary.Binary`                binary        both
+    `bson.objectid.ObjectId`            oid           both
+    `bson.dbref.DBRef`                  dbref         both
+    None                                undefined     bson -> py
+    unicode                             code          bson -> py
+    `bson.code.Code`                    code          py -> bson
+    unicode                             symbol        bson -> py
+    bytes (Python 3) [#bytes]_          binary        both
+    =================================== ============= ===================
+
+    Note that to save binary data it must be wrapped as an instance of
+    `bson.binary.Binary`. Otherwise it will be saved as a BSON string
+    and retrieved as unicode.
+
+    .. [#int] A Python int will be saved as a BSON int32 or BSON int64 depending
+       on its size. A BSON int32 will always decode to a Python int. In Python 2.x
+       a BSON int64 will always decode to a Python long. In Python 3.x a BSON
+       int64 will decode to a Python int since there is no longer a long type.
+    .. [#dt] datetime.datetime instances will be rounded to the nearest
+       millisecond when saved
+    .. [#dt2] all datetime.datetime instances are treated as *naive*. clients
+       should always use UTC.
+    .. [#bytes] The bytes type from Python 3.x is encoded as BSON binary with
+       subtype 0. In Python 3.x it will be decoded back to bytes. In Python 2.x
+       it will be decoded to an instance of :class:`~bson.binary.Binary` with
+       subtype 0.
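+
+    A short illustration of the ordering behavior (added example, not in
+    the upstream docstring)::
+
+        >>> s = SON([("b", 2), ("a", 1)])
+        >>> list(s.keys())
+        ['b', 'a']
+        >>> s == {"a": 1, "b": 2}  # vs. a plain dict: order-insensitive
+        True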
+ """ + + def __init__(self, data=None, **kwargs): + self.__keys = [] + dict.__init__(self) + self.update(data) + self.update(kwargs) + + def __new__(cls, *args, **kwargs): + instance = super(SON, cls).__new__(cls, *args, **kwargs) + instance.__keys = [] + return instance + + def __repr__(self): + result = [] + for key in self.__keys: + result.append("(%r, %r)" % (key, self[key])) + return "SON([%s])" % ", ".join(result) + + def __setitem__(self, key, value): + if key not in self: + self.__keys.append(key) + dict.__setitem__(self, key, value) + + def __delitem__(self, key): + self.__keys.remove(key) + dict.__delitem__(self, key) + + def keys(self): + return list(self.__keys) + + def copy(self): + other = SON() + other.update(self) + return other + + # TODO this is all from UserDict.DictMixin. it could probably be made more + # efficient. + # second level definitions support higher levels + def __iter__(self): + for k in list(self.keys()): + yield k + + def has_key(self, key): + return key in list(self.keys()) + + def __contains__(self, key): + return key in list(self.keys()) + + # third level takes advantage of second level definitions + def iteritems(self): + for k in self: + yield (k, self[k]) + + def iterkeys(self): + return self.__iter__() + + # fourth level uses definitions from lower levels + def itervalues(self): + for _, v in self.items(): + yield v + + def values(self): + return [v for _, v in self.items()] + + def items(self): + return [(key, self[key]) for key in self] + + def clear(self): + for key in list(self.keys()): + del self[key] + + def setdefault(self, key, default=None): + try: + return self[key] + except KeyError: + self[key] = default + return default + + def pop(self, key, *args): + if len(args) > 1: + raise TypeError("pop expected at most 2 arguments, got "\ + + repr(1 + len(args))) + try: + value = self[key] + except KeyError: + if args: + return args[0] + raise + del self[key] + return value + + def popitem(self): + try: + k, v = next(iter(self.items())) + except StopIteration: + raise KeyError('container is empty') + del self[k] + return (k, v) + + def update(self, other=None, **kwargs): + # Make progressively weaker assumptions about "other" + if other is None: + pass + elif hasattr(other, 'iteritems'): # iteritems saves memory and lookups + for k, v in other.items(): + self[k] = v + elif hasattr(other, 'keys'): + for k in list(other.keys()): + self[k] = other[k] + else: + for k, v in other: + self[k] = v + if kwargs: + self.update(kwargs) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def __eq__(self, other): + """Comparison to another SON is order-sensitive while comparison to a + regular dictionary is order-insensitive. + """ + if isinstance(other, SON): + return len(self) == len(other) and list(self.items()) == list(other.items()) + return self.to_dict() == other + + def __ne__(self, other): + return not self == other + + def __len__(self): + return len(list(self.keys())) + + def to_dict(self): + """Convert a SON document to a normal Python dictionary instance. + + This is trickier than just *dict(...)* because it needs to be + recursive. 
+ """ + + def transform_value(value): + if isinstance(value, list): + return [transform_value(v) for v in value] + if isinstance(value, SON): + value = dict(value) + if isinstance(value, dict): + for k, v in value.items(): + value[k] = transform_value(v) + return value + + return transform_value(dict(self)) + + def __deepcopy__(self, memo): + out = SON() + val_id = id(self) + if val_id in memo: + return memo.get(val_id) + memo[val_id] = out + for k, v in self.items(): + if not isinstance(v, RE_TYPE): + v = copy.deepcopy(v, memo) + out[k] = v + return out diff --git a/asyncio_mongo/_bson/timestamp.py b/asyncio_mongo/_bson/timestamp.py new file mode 100644 index 0000000..1b0bd53 --- /dev/null +++ b/asyncio_mongo/_bson/timestamp.py @@ -0,0 +1,97 @@ +# Copyright 2010-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for representing MongoDB internal Timestamps. +""" + +import calendar +import datetime + +from asyncio_mongo._bson.tz_util import utc + +UPPERBOUND = 4294967296 + +class Timestamp(object): + """MongoDB internal timestamps used in the opLog. + """ + + def __init__(self, time, inc): + """Create a new :class:`Timestamp`. + + This class is only for use with the MongoDB opLog. If you need + to store a regular timestamp, please use a + :class:`~datetime.datetime`. + + Raises :class:`TypeError` if `time` is not an instance of + :class: `int` or :class:`~datetime.datetime`, or `inc` is not + an instance of :class:`int`. Raises :class:`ValueError` if + `time` or `inc` is not in [0, 2**32). + + :Parameters: + - `time`: time in seconds since epoch UTC, or a naive UTC + :class:`~datetime.datetime`, or an aware + :class:`~datetime.datetime` + - `inc`: the incrementing counter + + .. versionchanged:: 1.7 + `time` can now be a :class:`~datetime.datetime` instance. + """ + if isinstance(time, datetime.datetime): + if time.utcoffset() is not None: + time = time - time.utcoffset() + time = int(calendar.timegm(time.timetuple())) + if not isinstance(time, int): + raise TypeError("time must be an instance of int") + if not isinstance(inc, int): + raise TypeError("inc must be an instance of int") + if not 0 <= time < UPPERBOUND: + raise ValueError("time must be contained in [0, 2**32)") + if not 0 <= inc < UPPERBOUND: + raise ValueError("inc must be contained in [0, 2**32)") + + self.__time = time + self.__inc = inc + + @property + def time(self): + """Get the time portion of this :class:`Timestamp`. + """ + return self.__time + + @property + def inc(self): + """Get the inc portion of this :class:`Timestamp`. + """ + return self.__inc + + def __eq__(self, other): + if isinstance(other, Timestamp): + return (self.__time == other.time and self.__inc == other.inc) + else: + return NotImplemented + + def __ne__(self, other): + return not self == other + + def __repr__(self): + return "Timestamp(%s, %s)" % (self.__time, self.__inc) + + def as_datetime(self): + """Return a :class:`~datetime.datetime` instance corresponding + to the time portion of this :class:`Timestamp`. + + .. 
versionchanged:: 1.8 + The returned datetime is now timezone aware. + """ + return datetime.datetime.fromtimestamp(self.__time, utc) diff --git a/asyncio_mongo/_bson/tz_util.py b/asyncio_mongo/_bson/tz_util.py new file mode 100644 index 0000000..4437564 --- /dev/null +++ b/asyncio_mongo/_bson/tz_util.py @@ -0,0 +1,52 @@ +# Copyright 2010-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Timezone related utilities for BSON.""" + +from datetime import (timedelta, + tzinfo) + +ZERO = timedelta(0) + + +class FixedOffset(tzinfo): + """Fixed offset timezone, in minutes east from UTC. + + Implementation based from the Python `standard library documentation + `_. + Defining __getinitargs__ enables pickling / copying. + """ + + def __init__(self, offset, name): + if isinstance(offset, timedelta): + self.__offset = offset + else: + self.__offset = timedelta(minutes=offset) + self.__name = name + + def __getinitargs__(self): + return self.__offset, self.__name + + def utcoffset(self, dt): + return self.__offset + + def tzname(self, dt): + return self.__name + + def dst(self, dt): + return ZERO + + +utc = FixedOffset(0, "UTC") +"""Fixed offset timezone representing UTC.""" diff --git a/asyncio_mongo/_gridfs/__init__.py b/asyncio_mongo/_gridfs/__init__.py new file mode 100644 index 0000000..3ceb539 --- /dev/null +++ b/asyncio_mongo/_gridfs/__init__.py @@ -0,0 +1,198 @@ +# Copyright 2009-2010 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GridFS is a specification for storing large objects in Mongo. + +The :mod:`gridfs` package is an implementation of GridFS on top of +:mod:`pymongo`, exposing a file-like interface. + +.. mongodoc:: gridfs +""" +from twisted.python import log +from twisted.internet import defer +from txmongo._gridfs.errors import (NoFile, + UnsupportedAPI) +from txmongo._gridfs.grid_file import (GridIn, + GridOut) +from txmongo import filter +from txmongo.filter import (ASCENDING, + DESCENDING) +from txmongo.database import Database + + +class GridFS(object): + """An instance of GridFS on top of a single Database. + """ + def __init__(self, database, collection="fs"): + """Create a new instance of :class:`GridFS`. + + Raises :class:`TypeError` if `database` is not an instance of + :class:`~pymongo.database.Database`. + + :Parameters: + - `database`: database to use + - `collection` (optional): root collection to use + + .. versionadded:: 1.6 + The `collection` parameter. + + .. 
mongodoc:: gridfs + """ + if not isinstance(database, Database): + raise TypeError("database must be an instance of Database") + + self.__database = database + self.__collection = database[collection] + self.__files = self.__collection.files + self.__chunks = self.__collection.chunks + self.__chunks.create_index(filter.sort(ASCENDING("files_id") + ASCENDING("n")), + unique=True) + + def new_file(self, **kwargs): + """Create a new file in GridFS. + + Returns a new :class:`~gridfs.grid_file.GridIn` instance to + which data can be written. Any keyword arguments will be + passed through to :meth:`~gridfs.grid_file.GridIn`. + + :Parameters: + - `**kwargs` (optional): keyword arguments for file creation + + .. versionadded:: 1.6 + """ + return GridIn(self.__collection, **kwargs) + + def put(self, data, **kwargs): + """Put data in GridFS as a new file. + + Equivalent to doing: + + >>> f = new_file(**kwargs) + >>> try: + >>> f.write(data) + >>> finally: + >>> f.close() + + `data` can be either an instance of :class:`str` or a + file-like object providing a :meth:`read` method. Any keyword + arguments will be passed through to the created file - see + :meth:`~gridfs.grid_file.GridIn` for possible + arguments. Returns the ``"_id"`` of the created file. + + :Parameters: + - `data`: data to be written as a file. + - `**kwargs` (optional): keyword arguments for file creation + + .. versionadded:: 1.6 + """ + grid_file = GridIn(self.__collection, **kwargs) + try: + grid_file.write(data) + finally: + grid_file.close() + return grid_file._id + + def get(self, file_id): + """Get a file from GridFS by ``"_id"``. + + Returns an instance of :class:`~gridfs.grid_file.GridOut`, + which provides a file-like interface for reading. + + :Parameters: + - `file_id`: ``"_id"`` of the file to get + + .. versionadded:: 1.6 + """ + return GridOut(self.__collection, file_id) + + def get_last_version(self, filename): + """Get a file from GridFS by ``"filename"``. + + Returns the most recently uploaded file in GridFS with the + name `filename` as an instance of + :class:`~gridfs.grid_file.GridOut`. Raises + :class:`~gridfs.errors.NoFile` if no such file exists. + + An index on ``{filename: 1, uploadDate: -1}`` will + automatically be created when this method is called the first + time. + + :Parameters: + - `filename`: ``"filename"`` of the file to get + + .. versionadded:: 1.6 + """ + self.__files.ensure_index(filter.sort(ASCENDING("filename") + \ + DESCENDING("uploadDate"))) + + d = self.__files.find({"filename": filename}, + filter=filter.sort(DESCENDING('uploadDate'))) + d.addCallback(self._cb_get_last_version, filename) + return d +# cursor.limit(-1).sort("uploadDate", -1)#DESCENDING) + + def _cb_get_last_version(self, docs, filename): + try: + grid_file = docs[0] + return GridOut(self.__collection, grid_file) + except IndexError: + raise NoFile("no file in gridfs with filename %r" % filename) + + # TODO add optional safe mode for chunk removal? + def delete(self, file_id): + """Delete a file from GridFS by ``"_id"``. + + Removes all data belonging to the file with ``"_id"``: + `file_id`. + + .. warning:: Any processes/threads reading from the file while + this method is executing will likely see an invalid/corrupt + file. Care should be taken to avoid concurrent reads to a file + while it is being deleted. + + :Parameters: + - `file_id`: ``"_id"`` of the file to delete + + .. 
versionadded:: 1.6 + """ + dl = [] + dl.append(self.__files.remove({"_id": file_id}, safe=True)) + dl.append(self.__chunks.remove({"files_id": file_id})) + return defer.DeferredList(dl) + + def list(self): + """List the names of all files stored in this instance of + :class:`GridFS`. + + .. versionchanged:: 1.6 + Removed the `collection` argument. + """ + return self.__files.distinct("filename") + + def open(self, *args, **kwargs): + """No longer supported. + + .. versionchanged:: 1.6 + The open method is no longer supported. + """ + raise UnsupportedAPI("The open method is no longer supported.") + + def remove(self, *args, **kwargs): + """No longer supported. + + .. versionchanged:: 1.6 + The remove method is no longer supported. + """ + raise UnsupportedAPI("The remove method is no longer supported. " + "Please use the delete method instead.") diff --git a/asyncio_mongo/_gridfs/errors.py b/asyncio_mongo/_gridfs/errors.py new file mode 100644 index 0000000..a0ecb20 --- /dev/null +++ b/asyncio_mongo/_gridfs/errors.py @@ -0,0 +1,47 @@ +# Copyright 2009-2010 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Exceptions raised by the :mod:`gridfs` package""" + + +class GridFSError(Exception): + """Base class for all GridFS exceptions. + + .. versionadded:: 1.5 + """ + + +class CorruptGridFile(GridFSError): + """Raised when a file in :class:`~gridfs.GridFS` is malformed. + """ + + +class NoFile(GridFSError): + """Raised when trying to read from a non-existent file. + + .. versionadded:: 1.6 + """ + + +class UnsupportedAPI(GridFSError): + """Raised when trying to use the old GridFS API. + + In version 1.6 of the PyMongo distribution there were backwards + incompatible changes to the GridFS API. Upgrading shouldn't be + difficult, but the old API is no longer supported (with no + deprecation period). This exception will be raised when attempting + to use unsupported constructs from the old API. + + .. versionadded:: 1.6 + """ diff --git a/asyncio_mongo/_gridfs/grid_file.py b/asyncio_mongo/_gridfs/grid_file.py new file mode 100644 index 0000000..a056d2c --- /dev/null +++ b/asyncio_mongo/_gridfs/grid_file.py @@ -0,0 +1,444 @@ +# Copyright 2009-2010 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tools for representing files stored in GridFS.""" + +import datetime +import math +import os +try: + from cStringIO import StringIO +except ImportError: + from StringIO import StringIO + +from twisted.python import log +from twisted.internet import defer +from txmongo._gridfs.errors import (CorruptGridFile, + NoFile, + UnsupportedAPI) +from txmongo._pymongo.binary import Binary +from txmongo._pymongo.objectid import ObjectId +from txmongo.collection import Collection + +try: + _SEEK_SET = os.SEEK_SET + _SEEK_CUR = os.SEEK_CUR + _SEEK_END = os.SEEK_END +except AttributeError: # before 2.5 + _SEEK_SET = 0 + _SEEK_CUR = 1 + _SEEK_END = 2 + + +"""Default chunk size, in bytes.""" +DEFAULT_CHUNK_SIZE = 256 * 1024 + + +def _create_property(field_name, docstring, + read_only=False, closed_only=False): + """Helper for creating properties to read/write to files. + """ + def getter(self): + if closed_only and not self._closed: + raise AttributeError("can only get %r on a closed file" % + field_name) + return self._file.get(field_name, None) + + def setter(self, value): + if self._closed: + raise AttributeError("cannot set %r on a closed file" % + field_name) + self._file[field_name] = value + + if read_only: + docstring = docstring + "\n\nThis attribute is read-only." + elif not closed_only: + docstring = "%s\n\n%s" % (docstring, "This attribute can only be " + "set before :meth:`close` has been called.") + else: + docstring = "%s\n\n%s" % (docstring, "This attribute is read-only and " + "can only be read after :meth:`close` " + "has been called.") + + if not read_only and not closed_only: + return property(getter, setter, doc=docstring) + return property(getter, doc=docstring) + + +class GridIn(object): + """Class to write data to GridFS. + """ + def __init__(self, root_collection, **kwargs): + """Write a file to GridFS + + Application developers should generally not need to + instantiate this class directly - instead see the methods + provided by :class:`~gridfs.GridFS`. + + Raises :class:`TypeError` if `root_collection` is not an + instance of :class:`~txmongo._pymongo.collection.Collection`. + + Any of the file level options specified in the `GridFS Spec + `_ may be passed as + keyword arguments. Any additional keyword arguments will be + set as additional fields on the file document. 
Valid keyword + arguments include: + + - ``"_id"``: unique ID for this file (default: + :class:`~pymongo.objectid.ObjectId`) + + - ``"filename"``: human name for the file + + - ``"contentType"`` or ``"content_type"``: valid mime-type + for the file + + - ``"chunkSize"`` or ``"chunk_size"``: size of each of the + chunks, in bytes (default: 256 kb) + + :Parameters: + - `root_collection`: root collection to write to + - `**kwargs` (optional): file level options (see above) + """ + if not isinstance(root_collection, Collection): + raise TypeError("root_collection must be an instance of Collection") + + # Handle alternative naming + if "content_type" in kwargs: + kwargs["contentType"] = kwargs.pop("content_type") + if "chunk_size" in kwargs: + kwargs["chunkSize"] = kwargs.pop("chunk_size") + + # Defaults + kwargs["_id"] = kwargs.get("_id", ObjectId()) + kwargs["chunkSize"] = kwargs.get("chunkSize", DEFAULT_CHUNK_SIZE) + + object.__setattr__(self, "_coll", root_collection) + object.__setattr__(self, "_chunks", root_collection.chunks) + object.__setattr__(self, "_file", kwargs) + object.__setattr__(self, "_buffer", StringIO()) + object.__setattr__(self, "_position", 0) + object.__setattr__(self, "_chunk_number", 0) + object.__setattr__(self, "_closed", False) + + @property + def closed(self): + """Is this file closed? + """ + return self._closed + + _id = _create_property("_id", "The ``'_id'`` value for this file.", + read_only=True) + filename = _create_property("filename", "Name of this file.") + content_type = _create_property("contentType", "Mime-type for this file.") + length = _create_property("length", "Length (in bytes) of this file.", + closed_only=True) + chunk_size = _create_property("chunkSize", "Chunk size for this file.", + read_only=True) + upload_date = _create_property("uploadDate", + "Date that this file was uploaded.", + closed_only=True) + md5 = _create_property("md5", "MD5 of the contents of this file " + "(generated on the server).", + closed_only=True) + + def __getattr__(self, name): + if name in self._file: + return self._file[name] + raise AttributeError("GridIn object has no attribute '%s'" % name) + + def __setattr__(self, name, value): + if self._closed: + raise AttributeError("cannot set %r on a closed file" % name) + object.__setattr__(self, name, value) + + @defer.inlineCallbacks + def __flush_data(self, data): + """Flush `data` to a chunk. + """ + if data: + assert(len(data) <= self.chunk_size) + chunk = {"files_id": self._file["_id"], + "n": self._chunk_number, + "data": Binary(data)} + + # Continue writing after the insert completes (non-blocking) + yield self._chunks.insert(chunk) + self._chunk_number += 1 + self._position += len(data) + + @defer.inlineCallbacks + def __flush_buffer(self): + """Flush the buffer contents out to a chunk. + """ + yield self.__flush_data(self._buffer.getvalue()) + self._buffer.close() + self._buffer = StringIO() + + @defer.inlineCallbacks + def __flush(self): + """Flush the file to the database. + """ + yield self.__flush_buffer() + + md5 = yield self._coll.filemd5(self._id) + + self._file["md5"] = md5 + self._file["length"] = self._position + self._file["uploadDate"] = datetime.datetime.utcnow() + yield self._coll.files.insert(self._file) + + @defer.inlineCallbacks + def close(self): + """Flush the file and close it. + + A closed file cannot be written any more. Calling + :meth:`close` more than once is allowed. 
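+
+        Illustrative usage (a sketch; assumes an ``inlineCallbacks``
+        context and an open :class:`GridIn` instance ``grid_in``)::
+
+            yield grid_in.write(data)
+            yield grid_in.close()  # flushes buffered chunks and the file doc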
+ """ + if not self._closed: + yield self.__flush() + self._closed = True + + # TODO should support writing unicode to a file. this means that files will + # need to have an encoding attribute. + def write(self, data): + """Write data to the file. There is no return value. + + `data` can be either a string of bytes or a file-like object + (implementing :meth:`read`). + + Due to buffering, the data may not actually be written to the + database until the :meth:`close` method is called. Raises + :class:`ValueError` if this file is already closed. Raises + :class:`TypeError` if `data` is not an instance of + :class:`str` or a file-like object. + + :Parameters: + - `data`: string of bytes or file-like object to be written + to the file + """ + if self._closed: + raise ValueError("cannot write to a closed file") + + #NC: Reverse the order of string and file-like from asyncio_mongo._pymongo 1.6. + # It is more likely to call write several times when writing + # strings than to write multiple file-like objects to a + # single concatenated file. + + try: # string + while data: + space = self.chunk_size - self._buffer.tell() + if len(data) <= space: + self._buffer.write(data) + break + else: + self._buffer.write(data[:space]) + self.__flush_buffer() + data = data[space:] + except AttributeError: + try: # file-like + if self._buffer.tell() > 0: + space = self.chunk_size - self._buffer.tell() + self._buffer.write(data.read(space)) + self.__flush_buffer() + to_write = data.read(self.chunk_size) + while to_write and len(to_write) == self.chunk_size: + self.__flush_data(to_write) + to_write = data.read(self.chunk_size) + self._buffer.write(to_write) + except AttributeError: + raise TypeError("can only write strings or file-like objects") + + def writelines(self, sequence): + """Write a sequence of strings to the file. + + Does not add separators. + """ + for line in sequence: + self.write(line) + + def __enter__(self): + """Support for the context manager protocol. + """ + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Support for the context manager protocol. + + Close the file and allow exceptions to propogate. + """ + self.close() + return False # untrue will propogate exceptions + + +class GridOut(object): + """Class to read data out of GridFS. + """ + def __init__(self, root_collection, doc): + """Read a file from GridFS + + Application developers should generally not need to + instantiate this class directly - instead see the methods + provided by :class:`~gridfs.GridFS`. + + Raises :class:`TypeError` if `root_collection` is not an instance of + :class:`~txmongo._pymongo.collection.Collection`. 
+
+        :Parameters:
+          - `root_collection`: root collection to read from
+          - `doc`: the file document (from the root collection's ``files``
+            sub-collection) describing the file to read
+        """
+        if not isinstance(root_collection, Collection):
+            raise TypeError("root_collection must be an instance of Collection")
+
+        self.__chunks = root_collection.chunks
+        self._file = doc
+        self.__current_chunk = -1
+        self.__buffer = ''
+        self.__position = 0
+
+    _id = _create_property("_id", "The ``'_id'`` value for this file.", True)
+    name = _create_property("filename", "Name of this file.", True)
+    content_type = _create_property("contentType", "Mime-type for this file.",
+                                    True)
+    length = _create_property("length", "Length (in bytes) of this file.",
+                              True)
+    chunk_size = _create_property("chunkSize", "Chunk size for this file.",
+                                  True)
+    upload_date = _create_property("uploadDate",
+                                   "Date that this file was first uploaded.",
+                                   True)
+    aliases = _create_property("aliases", "List of aliases for this file.",
+                               True)
+    metadata = _create_property("metadata", "Metadata attached to this file.",
+                                True)
+    md5 = _create_property("md5", "MD5 of the contents of this file "
+                           "(generated on the server).", True)
+
+    def __getattr__(self, name):
+        if name in self._file:
+            return self._file[name]
+        raise AttributeError("GridOut object has no attribute '%s'" % name)
+
+    @defer.inlineCallbacks
+    def read(self, size=-1):
+        """Read at most `size` bytes from the file (less if there
+        isn't enough data).
+
+        The bytes are returned as an instance of :class:`str`. If
+        `size` is negative or omitted all data is read.
+
+        :Parameters:
+          - `size` (optional): the number of bytes to read
+        """
+        if size:
+            remainder = int(self.length) - self.__position
+            if size < 0 or size > remainder:
+                size = remainder
+
+        data = self.__buffer
+        # Floor division: chunk numbers ("n") are stored as integers.
+        chunk_number = (len(data) + self.__position) // self.chunk_size
+
+        while len(data) < size:
+            chunk = yield self.__chunks.find_one({"files_id": self._id,
+                                                  "n": chunk_number})
+            if not chunk:
+                raise CorruptGridFile("no chunk #%d" % chunk_number)
+
+            if not data:
+                data += chunk["data"][self.__position % self.chunk_size:]
+            else:
+                data += chunk["data"]
+
+            chunk_number += 1
+
+        self.__position += size
+        to_return = data[:size]
+        self.__buffer = data[size:]
+        defer.returnValue(to_return)
+
+    def tell(self):
+        """Return the current position of this file.
+        """
+        return self.__position
+
+    def seek(self, pos, whence=_SEEK_SET):
+        """Set the current position of this file.
+
+        :Parameters:
+          - `pos`: the position (or offset if using relative
+            positioning) to seek to
+          - `whence` (optional): where to seek
+            from. :attr:`os.SEEK_SET` (``0``) for absolute file
+            positioning, :attr:`os.SEEK_CUR` (``1``) to seek relative
+            to the current position, :attr:`os.SEEK_END` (``2``) to
+            seek relative to the file's end.
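+
+        Example (illustrative; this adjusts only the client-side position,
+        no server round-trip is involved)::
+
+            grid_out.seek(0, os.SEEK_END)  # jump to the end of the file
+            assert grid_out.tell() == int(grid_out.length)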
+ """ + if whence == _SEEK_SET: + new_pos = pos + elif whence == _SEEK_CUR: + new_pos = self.__position + pos + elif whence == _SEEK_END: + new_pos = int(self.length) + pos + else: + raise IOError(22, "Invalid value for `whence`") + + if new_pos < 0: + raise IOError(22, "Invalid value for `pos` - must be positive") + + self.__position = new_pos + + def close(self): + self.__buffer = '' + self.__current_chunk = -1 + + def __iter__(self): + """Deprecated.""" + raise UnsupportedAPI("Iterating is deprecated for iterated reading") + + def __repr__(self): + return str(self._file) + + +class GridOutIterator(object): + def __init__(self, grid_out, chunks): + self.__id = grid_out._id + self.__chunks = chunks + self.__current_chunk = 0 + self.__max_chunk = math.ceil(float(grid_out.length) / + grid_out.chunk_size) + + def __iter__(self): + return self + + @defer.inlineCallbacks + def next(self): + if self.__current_chunk >= self.__max_chunk: + raise StopIteration + chunk = yield self.__chunks.find_one({"files_id": self.__id, + "n": self.__current_chunk}) + if not chunk: + raise CorruptGridFile("no chunk #%d" % self.__current_chunk) + self.__current_chunk += 1 + defer.returnValue(str(chunk["data"])) + + +class GridFile(object): + """No longer supported. + + .. versionchanged:: 1.6 + The GridFile class is no longer supported. + """ + def __init__(self, *args, **kwargs): + raise UnsupportedAPI("The GridFile class is no longer supported. " + "Please use GridIn or GridOut instead.") diff --git a/asyncio_mongo/_pymongo/__init__.py b/asyncio_mongo/_pymongo/__init__.py new file mode 100644 index 0000000..2fb2c31 --- /dev/null +++ b/asyncio_mongo/_pymongo/__init__.py @@ -0,0 +1,95 @@ +# Copyright 2009-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Python driver for MongoDB.""" + + +ASCENDING = 1 +"""Ascending sort order.""" +DESCENDING = -1 +"""Descending sort order.""" + +GEO2D = "2d" +"""Index specifier for a 2-dimensional `geospatial index`_. + +.. versionadded:: 1.5.1 + +.. note:: Geo-spatial indexing requires server version **>= 1.3.3**. + +.. _geospatial index: http://docs.mongodb.org/manual/core/geospatial-indexes/ +""" + +GEOHAYSTACK = "geoHaystack" +"""Index specifier for a 2-dimensional `haystack index`_. + +.. versionadded:: 2.1 + +.. note:: Geo-spatial indexing requires server version **>= 1.5.6**. + +.. _haystack index: http://docs.mongodb.org/manual/core/geospatial-indexes/#haystack-indexes +""" + +GEOSPHERE = "2dsphere" +"""Index specifier for a `spherical geospatial index`_. + +.. versionadded:: 2.5 + +.. note:: 2dsphere indexing requires server version **>= 2.4.0**. + +.. _spherical geospatial index: http://docs.mongodb.org/manual/release-notes/2.4/#new-geospatial-indexes-with-geojson-and-improved-spherical-geometry +""" + +HASHED = "hashed" +"""Index specifier for a `hashed index`_. + +.. versionadded:: 2.5 + +.. note:: hashed indexing requires server version **>= 2.4.0**. + +.. 
_hashed index: http://docs.mongodb.org/manual/release-notes/2.4/#new-hashed-index-and-sharding-with-a-hashed-shard-key +""" + +OFF = 0 +"""No database profiling.""" +SLOW_ONLY = 1 +"""Only profile slow operations.""" +ALL = 2 +"""Profile all operations.""" + +version_tuple = (2, 6, 3) + +def get_version_string(): + if isinstance(version_tuple[-1], str): + return '.'.join(map(str, version_tuple[:-1])) + version_tuple[-1] + return '.'.join(map(str, version_tuple)) + +version = get_version_string() +"""Current version of PyMongo.""" + +from asyncio_mongo._pymongo.connection import Connection +from asyncio_mongo._pymongo.mongo_client import MongoClient +from asyncio_mongo._pymongo.mongo_replica_set_client import MongoReplicaSetClient +from asyncio_mongo._pymongo.replica_set_connection import ReplicaSetConnection +from asyncio_mongo._pymongo.read_preferences import ReadPreference + +def has_c(): + """Is the C extension installed? + + .. versionadded:: 1.5 + """ + try: + from asyncio_mongo._pymongo import _cmessage + return True + except ImportError: + return False diff --git a/asyncio_mongo/_pymongo/_cmessage.so b/asyncio_mongo/_pymongo/_cmessage.so new file mode 100755 index 0000000..88c35b3 Binary files /dev/null and b/asyncio_mongo/_pymongo/_cmessage.so differ diff --git a/asyncio_mongo/_pymongo/auth.py b/asyncio_mongo/_pymongo/auth.py new file mode 100644 index 0000000..bb98c77 --- /dev/null +++ b/asyncio_mongo/_pymongo/auth.py @@ -0,0 +1,215 @@ +# Copyright 2013 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Authentication helpers.""" + +try: + import hashlib + _MD5 = hashlib.md5 +except ImportError: # for Python < 2.5 + import md5 + _MD5 = md5.new + +HAVE_KERBEROS = True +try: + import kerberos +except ImportError: + HAVE_KERBEROS = False + +from asyncio_mongo._bson.binary import Binary +from asyncio_mongo._bson.son import SON +from asyncio_mongo._pymongo.errors import ConfigurationError, OperationFailure + + +MECHANISMS = ('GSSAPI', 'MONGODB-CR', 'MONGODB-X509', 'PLAIN') +"""The authentication mechanisms supported by PyMongo.""" + + +def _build_credentials_tuple(mech, source, user, passwd, extra): + """Build and return a mechanism specific credentials tuple. + """ + if mech == 'GSSAPI': + gsn = extra.get('gssapiservicename', 'mongodb') + # No password, source is always $external. + return (mech, '$external', user, gsn) + elif mech == 'MONGODB-X509': + return (mech, '$external', user) + return (mech, source, user, passwd) + + +def _password_digest(username, password): + """Get a password digest to use for authentication. 
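+
+    Equivalently, as a sketch (``_MD5`` is this module's hash constructor)::
+
+        _MD5(("%s:mongo:%s" % (username, password)).encode("utf-8")).hexdigest()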
+ """ + if not isinstance(password, str): + raise TypeError("password must be an instance " + "of %s" % (str.__name__,)) + if len(password) == 0: + raise TypeError("password can't be empty") + if not isinstance(username, str): + raise TypeError("username must be an instance " + "of %s" % (str.__name__,)) + + md5hash = _MD5() + data = "%s:mongo:%s" % (username, password) + md5hash.update(data.encode('utf-8')) + return str(md5hash.hexdigest()) + + +def _auth_key(nonce, username, password): + """Get an auth key to use for authentication. + """ + digest = _password_digest(username, password) + md5hash = _MD5() + data = "%s%s%s" % (nonce, str(username), digest) + md5hash.update(data.encode('utf-8')) + return str(md5hash.hexdigest()) + + +def _authenticate_gssapi(credentials, sock_info, cmd_func): + """Authenticate using GSSAPI. + """ + try: + dummy, username, gsn = credentials + # Starting here and continuing through the while loop below - establish + # the security context. See RFC 4752, Section 3.1, first paragraph. + result, ctx = kerberos.authGSSClientInit(gsn + '@' + sock_info.host, + kerberos.GSS_C_MUTUAL_FLAG) + if result != kerberos.AUTH_GSS_COMPLETE: + raise OperationFailure('Kerberos context failed to initialize.') + + try: + # pykerberos uses a weird mix of exceptions and return values + # to indicate errors. + # 0 == continue, 1 == complete, -1 == error + # Only authGSSClientStep can return 0. + if kerberos.authGSSClientStep(ctx, '') != 0: + raise OperationFailure('Unknown kerberos ' + 'failure in step function.') + + # Start a SASL conversation with mongod/s + # Note: pykerberos deals with base64 encoded byte strings. + # Since mongo accepts base64 strings as the payload we don't + # have to use bson.binary.Binary. + payload = kerberos.authGSSClientResponse(ctx) + cmd = SON([('saslStart', 1), + ('mechanism', 'GSSAPI'), + ('payload', payload), + ('autoAuthorize', 1)]) + response, _ = cmd_func(sock_info, '$external', cmd) + + # Limit how many times we loop to catch protocol / library issues + for _ in range(10): + result = kerberos.authGSSClientStep(ctx, + str(response['payload'])) + if result == -1: + raise OperationFailure('Unknown kerberos ' + 'failure in step function.') + + payload = kerberos.authGSSClientResponse(ctx) or '' + + cmd = SON([('saslContinue', 1), + ('conversationId', response['conversationId']), + ('payload', payload)]) + response, _ = cmd_func(sock_info, '$external', cmd) + + if result == kerberos.AUTH_GSS_COMPLETE: + break + else: + raise OperationFailure('Kerberos ' + 'authentication failed to complete.') + + # Once the security context is established actually authenticate. + # See RFC 4752, Section 3.1, last two paragraphs. 
+ if kerberos.authGSSClientUnwrap(ctx, + str(response['payload'])) != 1: + raise OperationFailure('Unknown kerberos ' + 'failure during GSS_Unwrap step.') + + if kerberos.authGSSClientWrap(ctx, + kerberos.authGSSClientResponse(ctx), + username) != 1: + raise OperationFailure('Unknown kerberos ' + 'failure during GSS_Wrap step.') + + payload = kerberos.authGSSClientResponse(ctx) + cmd = SON([('saslContinue', 1), + ('conversationId', response['conversationId']), + ('payload', payload)]) + response, _ = cmd_func(sock_info, '$external', cmd) + + finally: + kerberos.authGSSClientClean(ctx) + + except kerberos.KrbError as exc: + raise OperationFailure(str(exc)) + + +def _authenticate_plain(credentials, sock_info, cmd_func): + """Authenticate using SASL PLAIN (RFC 4616) + """ + source, username, password = credentials + payload = ('\x00%s\x00%s' % (username, password)).encode('utf-8') + cmd = SON([('saslStart', 1), + ('mechanism', 'PLAIN'), + ('payload', Binary(payload)), + ('autoAuthorize', 1)]) + cmd_func(sock_info, source, cmd) + + +def _authenticate_x509(credentials, sock_info, cmd_func): + """Authenticate using MONGODB-X509. + """ + dummy, username = credentials + query = SON([('authenticate', 1), + ('mechanism', 'MONGODB-X509'), + ('user', username)]) + cmd_func(sock_info, '$external', query) + + +def _authenticate_mongo_cr(credentials, sock_info, cmd_func): + """Authenticate using MONGODB-CR. + """ + source, username, password = credentials + # Get a nonce + response, _ = cmd_func(sock_info, source, {'getnonce': 1}) + nonce = response['nonce'] + key = _auth_key(nonce, username, password) + + # Actually authenticate + query = SON([('authenticate', 1), + ('user', username), + ('nonce', nonce), + ('key', key)]) + cmd_func(sock_info, source, query) + + +_AUTH_MAP = { + 'GSSAPI': _authenticate_gssapi, + 'MONGODB-CR': _authenticate_mongo_cr, + 'MONGODB-X509': _authenticate_x509, + 'PLAIN': _authenticate_plain, +} + + +def authenticate(credentials, sock_info, cmd_func): + """Authenticate sock_info. + """ + mechanism = credentials[0] + if mechanism == 'GSSAPI': + if not HAVE_KERBEROS: + raise ConfigurationError('The "kerberos" module must be ' + 'installed to use GSSAPI authentication.') + auth_func = _AUTH_MAP.get(mechanism) + auth_func(credentials[1:], sock_info, cmd_func) + diff --git a/asyncio_mongo/_pymongo/collection.py b/asyncio_mongo/_pymongo/collection.py new file mode 100644 index 0000000..54a1f43 --- /dev/null +++ b/asyncio_mongo/_pymongo/collection.py @@ -0,0 +1,1489 @@ +# Copyright 2009-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
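+# A typical round trip through this embedded (synchronous) PyMongo
+# collection API, as an illustrative sketch (assumes a connected
+# Database object `db`):
+#
+#     people = db["people"]                  # Database.__getitem__
+#     _id = people.insert({"name": "Ada"})   # returns the new ObjectId
+#     doc = people.find_one({"_id": _id})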
+ +"""Collection level utilities for Mongo.""" + +import warnings + +from asyncio_mongo._bson.binary import ALL_UUID_SUBTYPES, OLD_UUID_SUBTYPE +from asyncio_mongo._bson.code import Code +from asyncio_mongo._bson.son import SON +from asyncio_mongo._pymongo import (common, + helpers, + message) +from asyncio_mongo._pymongo.cursor import Cursor +from asyncio_mongo._pymongo.errors import ConfigurationError, InvalidName + + +try: + from collections import OrderedDict + ordered_types = (SON, OrderedDict) +except ImportError: + ordered_types = SON + + +def _gen_index_name(keys): + """Generate an index name from the set of fields it is over. + """ + return "_".join(["%s_%s" % item for item in keys]) + + +class Collection(common.BaseObject): + """A Mongo collection. + """ + + def __init__(self, database, name, create=False, **kwargs): + """Get / create a Mongo collection. + + Raises :class:`TypeError` if `name` is not an instance of + :class:`basestring` (:class:`str` in python 3). Raises + :class:`~pymongo.errors.InvalidName` if `name` is not a valid + collection name. Any additional keyword arguments will be used + as options passed to the create command. See + :meth:`~pymongo.database.Database.create_collection` for valid + options. + + If `create` is ``True`` or additional keyword arguments are + present a create command will be sent. Otherwise, a create + command will not be sent and the collection will be created + implicitly on first use. + + :Parameters: + - `database`: the database to get a collection from + - `name`: the name of the collection to get + - `create` (optional): if ``True``, force collection + creation even without options being set + - `**kwargs` (optional): additional keyword arguments will + be passed as options for the create collection command + + .. versionchanged:: 2.2 + Removed deprecated argument: options + + .. versionadded:: 2.1 + uuid_subtype attribute + + .. versionchanged:: 1.5 + deprecating `options` in favor of kwargs + + .. versionadded:: 1.5 + the `create` parameter + + .. mongodoc:: collections + """ + super(Collection, self).__init__( + slave_okay=database.slave_okay, + read_preference=database.read_preference, + tag_sets=database.tag_sets, + secondary_acceptable_latency_ms=( + database.secondary_acceptable_latency_ms), + safe=database.safe, + **database.write_concern) + + if not isinstance(name, str): + raise TypeError("name must be an instance " + "of %s" % (str.__name__,)) + + if not name or ".." in name: + raise InvalidName("collection names cannot be empty") + if "$" in name and not (name.startswith("oplog.$main") or + name.startswith("$cmd")): + raise InvalidName("collection names must not " + "contain '$': %r" % name) + if name[0] == "." or name[-1] == ".": + raise InvalidName("collection names must not start " + "or end with '.': %r" % name) + if "\x00" in name: + raise InvalidName("collection names must not contain the " + "null character") + + self.__database = database + self.__name = str(name) + self.__uuid_subtype = OLD_UUID_SUBTYPE + self.__full_name = "%s.%s" % (self.__database.name, self.__name) + if create or kwargs: + self.__create(kwargs) + + def __create(self, options): + """Sends a create command with the given options. + """ + + if options: + if "size" in options: + options["size"] = float(options["size"]) + self.__database.command("create", self.__name, **options) + else: + self.__database.command("create", self.__name) + + def __getattr__(self, name): + """Get a sub-collection of this collection by name. 
+ + Raises InvalidName if an invalid collection name is used. + + :Parameters: + - `name`: the name of the collection to get + """ + return Collection(self.__database, "%s.%s" % (self.__name, name)) + + def __getitem__(self, name): + return self.__getattr__(name) + + def __repr__(self): + return "Collection(%r, %r)" % (self.__database, self.__name) + + def __eq__(self, other): + if isinstance(other, Collection): + us = (self.__database, self.__name) + them = (other.__database, other.__name) + return us == them + return NotImplemented + + def __ne__(self, other): + return not self == other + + @property + def full_name(self): + """The full name of this :class:`Collection`. + + The full name is of the form `database_name.collection_name`. + + .. versionchanged:: 1.3 + ``full_name`` is now a property rather than a method. + """ + return self.__full_name + + @property + def name(self): + """The name of this :class:`Collection`. + + .. versionchanged:: 1.3 + ``name`` is now a property rather than a method. + """ + return self.__name + + @property + def database(self): + """The :class:`~pymongo.database.Database` that this + :class:`Collection` is a part of. + + .. versionchanged:: 1.3 + ``database`` is now a property rather than a method. + """ + return self.__database + + def __get_uuid_subtype(self): + return self.__uuid_subtype + + def __set_uuid_subtype(self, subtype): + if subtype not in ALL_UUID_SUBTYPES: + raise ConfigurationError("Not a valid setting for uuid_subtype.") + self.__uuid_subtype = subtype + + uuid_subtype = property(__get_uuid_subtype, __set_uuid_subtype, + doc="""This attribute specifies which BSON Binary + subtype is used when storing UUIDs. Historically + UUIDs have been stored as BSON Binary subtype 3. + This attribute is used to switch to the newer BSON + binary subtype 4. It can also be used to force + legacy byte order and subtype compatibility with + the Java and C# drivers. See the + :mod:`bson.binary` module for all options.""") + + def save(self, to_save, manipulate=True, + safe=None, check_keys=True, **kwargs): + """Save a document in this collection. + + If `to_save` already has an ``"_id"`` then an :meth:`update` + (upsert) operation is performed and any existing document with + that ``"_id"`` is overwritten. Otherwise an :meth:`insert` + operation is performed. In this case if `manipulate` is ``True`` + an ``"_id"`` will be added to `to_save` and this method returns + the ``"_id"`` of the saved document. If `manipulate` is ``False`` + the ``"_id"`` will be added by the server but this method will + return ``None``. + + Raises :class:`TypeError` if `to_save` is not an instance of + :class:`dict`. + + Write concern options can be passed as keyword arguments, overriding + any global defaults. Valid options include w=, + wtimeout=, j=, or fsync=. See the parameter list below + for a detailed explanation of these options. + + By default an acknowledgment is requested from the server that the + save was successful, raising :class:`~pymongo.errors.OperationFailure` + if an error occurred. **Passing ``w=0`` disables write acknowledgement + and all other write concern options.** + + :Parameters: + - `to_save`: the document to be saved + - `manipulate` (optional): manipulate the document before + saving it? + - `safe` (optional): **DEPRECATED** - Use `w` instead. + - `check_keys` (optional): check if keys start with '$' or + contain '.', raising :class:`~pymongo.errors.InvalidName` + in either case. 
+ - `w` (optional): (integer or string) If this is a replica set, write + operations will block until they have been replicated to the + specified number or tagged set of servers. `w=` always includes + the replica set primary (e.g. w=3 means write to the primary and wait + until replicated to **two** secondaries). **Passing w=0 disables + write acknowledgement and all other write concern options.** + - `wtimeout` (optional): (integer) Used in conjunction with `w`. + Specify a value in milliseconds to control how long to wait for + write propagation to complete. If replication does not complete in + the given timeframe, a timeout exception is raised. + - `j` (optional): If ``True`` block until write operations have been + committed to the journal. Ignored if the server is running without + journaling. + - `fsync` (optional): If ``True`` force the database to fsync all + files before returning. When used with `j` the server awaits the + next group commit before returning. + :Returns: + - The ``'_id'`` value of `to_save` or ``[None]`` if `manipulate` is + ``False`` and `to_save` has no '_id' field. + + .. versionadded:: 1.8 + Support for passing `getLastError` options as keyword + arguments. + + .. mongodoc:: insert + """ + if not isinstance(to_save, dict): + raise TypeError("cannot save object of type %s" % type(to_save)) + + if "_id" not in to_save: + return self.insert(to_save, manipulate, safe, check_keys, **kwargs) + else: + self.update({"_id": to_save["_id"]}, to_save, True, + manipulate, safe, check_keys=check_keys, **kwargs) + return to_save.get("_id", None) + + def insert(self, doc_or_docs, manipulate=True, + safe=None, check_keys=True, continue_on_error=False, **kwargs): + """Insert a document(s) into this collection. + + If `manipulate` is ``True``, the document(s) are manipulated using + any :class:`~pymongo.son_manipulator.SONManipulator` instances + that have been added to this :class:`~pymongo.database.Database`. + In this case an ``"_id"`` will be added if the document(s) does + not already contain one and the ``"id"`` (or list of ``"_id"`` + values for more than one document) will be returned. + If `manipulate` is ``False`` and the document(s) does not include + an ``"_id"`` one will be added by the server. The server + does not return the ``"_id"`` it created so ``None`` is returned. + + Write concern options can be passed as keyword arguments, overriding + any global defaults. Valid options include w=, + wtimeout=, j=, or fsync=. See the parameter list below + for a detailed explanation of these options. + + By default an acknowledgment is requested from the server that the + insert was successful, raising :class:`~pymongo.errors.OperationFailure` + if an error occurred. **Passing ``w=0`` disables write acknowledgement + and all other write concern options.** + + :Parameters: + - `doc_or_docs`: a document or list of documents to be + inserted + - `manipulate` (optional): If ``True`` manipulate the documents + before inserting. + - `safe` (optional): **DEPRECATED** - Use `w` instead. + - `check_keys` (optional): If ``True`` check if keys start with '$' + or contain '.', raising :class:`~pymongo.errors.InvalidName` in + either case. + - `continue_on_error` (optional): If ``True``, the database will not + stop processing a bulk insert if one fails (e.g. due to duplicate + IDs). This makes bulk insert behave similarly to a series of single + inserts, except lastError will be set if any insert fails, not just + the last one. 
If multiple errors occur, only the most recent will + be reported by :meth:`~pymongo.database.Database.error`. + - `w` (optional): (integer or string) If this is a replica set, write + operations will block until they have been replicated to the + specified number or tagged set of servers. `w=` always includes + the replica set primary (e.g. w=3 means write to the primary and wait + until replicated to **two** secondaries). **Passing w=0 disables + write acknowledgement and all other write concern options.** + - `wtimeout` (optional): (integer) Used in conjunction with `w`. + Specify a value in milliseconds to control how long to wait for + write propagation to complete. If replication does not complete in + the given timeframe, a timeout exception is raised. + - `j` (optional): If ``True`` block until write operations have been + committed to the journal. Ignored if the server is running without + journaling. + - `fsync` (optional): If ``True`` force the database to fsync all + files before returning. When used with `j` the server awaits the + next group commit before returning. + :Returns: + - The ``'_id'`` value (or list of '_id' values) of `doc_or_docs` or + ``[None]`` if manipulate is ``False`` and the documents passed + as `doc_or_docs` do not include an '_id' field. + + .. note:: `continue_on_error` requires server version **>= 1.9.1** + + .. versionadded:: 2.1 + Support for continue_on_error. + .. versionadded:: 1.8 + Support for passing `getLastError` options as keyword + arguments. + .. versionchanged:: 1.1 + Bulk insert works with any iterable + + .. mongodoc:: insert + """ + # Batch inserts require us to know the connected master's + # max_bson_size and max_message_size. We have to be connected + # to a master to know that. + self.database.connection._ensure_connected(True) + + docs = doc_or_docs + return_one = False + if isinstance(docs, dict): + return_one = True + docs = [docs] + + if manipulate: + docs = [self.__database._fix_incoming(doc, self) for doc in docs] + + safe, options = self._get_write_mode(safe, **kwargs) + message._do_batched_insert(self.__full_name, docs, + check_keys, safe, options, + continue_on_error, self.__uuid_subtype, + self.database.connection) + + ids = [doc.get("_id", None) for doc in docs] + if return_one: + return ids[0] + else: + return ids + + def update(self, spec, document, upsert=False, manipulate=False, + safe=None, multi=False, check_keys=True, **kwargs): + """Update a document(s) in this collection. + + Raises :class:`TypeError` if either `spec` or `document` is + not an instance of ``dict`` or `upsert` is not an instance of + ``bool``. + + Write concern options can be passed as keyword arguments, overriding + any global defaults. Valid options include w=, + wtimeout=, j=, or fsync=. See the parameter list below + for a detailed explanation of these options. + + By default an acknowledgment is requested from the server that the + update was successful, raising :class:`~pymongo.errors.OperationFailure` + if an error occurred. **Passing ``w=0`` disables write acknowledgement + and all other write concern options.** + + There are many useful `update modifiers`_ which can be used + when performing updates. For example, here we use the + ``"$set"`` modifier to modify some fields in a matching + document: + + .. 
doctest:: + + >>> db.test.insert({"x": "y", "a": "b"}) + ObjectId('...') + >>> list(db.test.find()) + [{u'a': u'b', u'x': u'y', u'_id': ObjectId('...')}] + >>> db.test.update({"x": "y"}, {"$set": {"a": "c"}}) + {...} + >>> list(db.test.find()) + [{u'a': u'c', u'x': u'y', u'_id': ObjectId('...')}] + + :Parameters: + - `spec`: a ``dict`` or :class:`~bson.son.SON` instance + specifying elements which must be present for a document + to be updated + - `document`: a ``dict`` or :class:`~bson.son.SON` + instance specifying the document to be used for the update + or (in the case of an upsert) insert - see docs on MongoDB + `update modifiers`_ + - `upsert` (optional): perform an upsert if ``True`` + - `manipulate` (optional): manipulate the document before + updating? If ``True`` all instances of + :mod:`~pymongo.son_manipulator.SONManipulator` added to + this :class:`~pymongo.database.Database` will be applied + to the document before performing the update. + - `check_keys` (optional): check if keys in `document` start + with '$' or contain '.', raising + :class:`~pymongo.errors.InvalidName`. Only applies to + document replacement, not modification through $ + operators. + - `safe` (optional): **DEPRECATED** - Use `w` instead. + - `multi` (optional): update all documents that match + `spec`, rather than just the first matching document. The + default value for `multi` is currently ``False``, but this + might eventually change to ``True``. It is recommended + that you specify this argument explicitly for all update + operations in order to prepare your code for that change. + - `w` (optional): (integer or string) If this is a replica set, write + operations will block until they have been replicated to the + specified number or tagged set of servers. `w=` always includes + the replica set primary (e.g. w=3 means write to the primary and wait + until replicated to **two** secondaries). **Passing w=0 disables + write acknowledgement and all other write concern options.** + - `wtimeout` (optional): (integer) Used in conjunction with `w`. + Specify a value in milliseconds to control how long to wait for + write propagation to complete. If replication does not complete in + the given timeframe, a timeout exception is raised. + - `j` (optional): If ``True`` block until write operations have been + committed to the journal. Ignored if the server is running without + journaling. + - `fsync` (optional): If ``True`` force the database to fsync all + files before returning. When used with `j` the server awaits the + next group commit before returning. + :Returns: + - A document (dict) describing the effect of the update or ``None`` + if write acknowledgement is disabled. + + .. versionadded:: 1.8 + Support for passing `getLastError` options as keyword + arguments. + .. versionchanged:: 1.4 + Return the response to *lastError* if `safe` is ``True``. + .. versionadded:: 1.1.1 + The `multi` parameter. + + .. _update modifiers: http://www.mongodb.org/display/DOCS/Updating + + .. 
mongodoc:: update + """ + if not isinstance(spec, dict): + raise TypeError("spec must be an instance of dict") + if not isinstance(document, dict): + raise TypeError("document must be an instance of dict") + if not isinstance(upsert, bool): + raise TypeError("upsert must be an instance of bool") + + if manipulate: + document = self.__database._fix_incoming(document, self) + + safe, options = self._get_write_mode(safe, **kwargs) + + if document: + # If a top level key begins with '$' this is a modify operation + # and we should skip key validation. It doesn't matter which key + # we check here. Passing a document with a mix of top level keys + # starting with and without a '$' is invalid and the server will + # raise an appropriate exception. + first = next((iter(document.keys()))) + if first.startswith('$'): + check_keys = False + + return self.__database.connection._send_message( + message.update(self.__full_name, upsert, multi, + spec, document, safe, options, + check_keys, self.__uuid_subtype), safe) + + def drop(self): + """Alias for :meth:`~pymongo.database.Database.drop_collection`. + + The following two calls are equivalent: + + >>> db.foo.drop() + >>> db.drop_collection("foo") + + .. versionadded:: 1.8 + """ + self.__database.drop_collection(self.__name) + + def remove(self, spec_or_id=None, safe=None, **kwargs): + """Remove a document(s) from this collection. + + .. warning:: Calls to :meth:`remove` should be performed with + care, as removed data cannot be restored. + + If `spec_or_id` is ``None``, all documents in this collection + will be removed. This is not equivalent to calling + :meth:`~pymongo.database.Database.drop_collection`, however, + as indexes will not be removed. + + Write concern options can be passed as keyword arguments, overriding + any global defaults. Valid options include w=, + wtimeout=, j=, or fsync=. See the parameter list below + for a detailed explanation of these options. + + By default an acknowledgment is requested from the server that the + remove was successful, raising :class:`~pymongo.errors.OperationFailure` + if an error occurred. **Passing ``w=0`` disables write acknowledgement + and all other write concern options.** + + :Parameters: + - `spec_or_id` (optional): a dictionary specifying the + documents to be removed OR any other type specifying the + value of ``"_id"`` for the document to be removed + - `safe` (optional): **DEPRECATED** - Use `w` instead. + - `w` (optional): (integer or string) If this is a replica set, write + operations will block until they have been replicated to the + specified number or tagged set of servers. `w=` always includes + the replica set primary (e.g. w=3 means write to the primary and wait + until replicated to **two** secondaries). **Passing w=0 disables + write acknowledgement and all other write concern options.** + - `wtimeout` (optional): (integer) Used in conjunction with `w`. + Specify a value in milliseconds to control how long to wait for + write propagation to complete. If replication does not complete in + the given timeframe, a timeout exception is raised. + - `j` (optional): If ``True`` block until write operations have been + committed to the journal. Ignored if the server is running without + journaling. + - `fsync` (optional): If ``True`` force the database to fsync all + files before returning. When used with `j` the server awaits the + next group commit before returning. + :Returns: + - A document (dict) describing the effect of the remove or ``None`` + if write acknowledgement is disabled. 
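+
+        Example (illustrative)::
+
+            db.test.remove({"x": 1})        # remove matching documents
+            db.test.remove(some_object_id)  # remove a document by _id
+            db.test.remove()                # remove *all* documents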
+ + .. versionadded:: 1.8 + Support for passing `getLastError` options as keyword arguments. + .. versionchanged:: 1.7 Accept any type other than a ``dict`` + instance for removal by ``"_id"``, not just + :class:`~bson.objectid.ObjectId` instances. + .. versionchanged:: 1.4 + Return the response to *lastError* if `safe` is ``True``. + .. versionchanged:: 1.2 + The `spec_or_id` parameter is now optional. If it is + not specified *all* documents in the collection will be + removed. + .. versionadded:: 1.1 + The `safe` parameter. + + .. mongodoc:: remove + """ + if spec_or_id is None: + spec_or_id = {} + if not isinstance(spec_or_id, dict): + spec_or_id = {"_id": spec_or_id} + + safe, options = self._get_write_mode(safe, **kwargs) + return self.__database.connection._send_message( + message.delete(self.__full_name, spec_or_id, safe, + options, self.__uuid_subtype), safe) + + def find_one(self, spec_or_id=None, *args, **kwargs): + """Get a single document from the database. + + All arguments to :meth:`find` are also valid arguments for + :meth:`find_one`, although any `limit` argument will be + ignored. Returns a single document, or ``None`` if no matching + document is found. + + :Parameters: + + - `spec_or_id` (optional): a dictionary specifying + the query to be performed OR any other type to be used as + the value for a query for ``"_id"``. + + - `*args` (optional): any additional positional arguments + are the same as the arguments to :meth:`find`. + + - `**kwargs` (optional): any additional keyword arguments + are the same as the arguments to :meth:`find`. + + .. versionchanged:: 1.7 + Allow passing any of the arguments that are valid for + :meth:`find`. + + .. versionchanged:: 1.7 Accept any type other than a ``dict`` + instance as an ``"_id"`` query, not just + :class:`~bson.objectid.ObjectId` instances. + """ + if spec_or_id is not None and not isinstance(spec_or_id, dict): + spec_or_id = {"_id": spec_or_id} + + for result in self.find(spec_or_id, *args, **kwargs).limit(-1): + return result + return None + + def find(self, *args, **kwargs): + """Query the database. + + The `spec` argument is a prototype document that all results + must match. For example: + + >>> db.test.find({"hello": "world"}) + + only matches documents that have a key "hello" with value + "world". Matches can have other keys *in addition* to + "hello". The `fields` argument is used to specify a subset of + fields that should be included in the result documents. By + limiting results to a certain subset of fields you can cut + down on network traffic and decoding time. + + Raises :class:`TypeError` if any of the arguments are of + improper type. Returns an instance of + :class:`~pymongo.cursor.Cursor` corresponding to this query. + + :Parameters: + - `spec` (optional): a SON object specifying elements which + must be present for a document to be included in the + result set + - `fields` (optional): a list of field names that should be + returned in the result set or a dict specifying the fields + to include or exclude. If `fields` is a list "_id" will + always be returned. Use a dict to exclude fields from + the result (e.g. fields={'_id': False}). + - `skip` (optional): the number of documents to omit (from + the start of the result set) when returning the results + - `limit` (optional): the maximum number of results to + return + - `timeout` (optional): if True (the default), any returned + cursor is closed by the server after 10 minutes of + inactivity. 
If set to False, the returned cursor will never + time out on the server. Care should be taken to ensure that + cursors with timeout turned off are properly closed. + - `snapshot` (optional): if True, snapshot mode will be used + for this query. Snapshot mode assures no duplicates are + returned, or objects missed, which were present at both + the start and end of the query's execution. For details, + see the `snapshot documentation + `_. + - `tailable` (optional): the result of this find call will + be a tailable cursor - tailable cursors aren't closed when + the last data is retrieved but are kept open and the + cursors location marks the final document's position. if + more data is received iteration of the cursor will + continue from the last document received. For details, see + the `tailable cursor documentation + `_. + - `sort` (optional): a list of (key, direction) pairs + specifying the sort order for this query. See + :meth:`~pymongo.cursor.Cursor.sort` for details. + - `max_scan` (optional): limit the number of documents + examined when performing the query + - `as_class` (optional): class to use for documents in the + query result (default is + :attr:`~pymongo.mongo_client.MongoClient.document_class`) + - `slave_okay` (optional): if True, allows this query to + be run against a replica secondary. + - `await_data` (optional): if True, the server will block for + some extra time before returning, waiting for more data to + return. Ignored if `tailable` is False. + - `partial` (optional): if True, mongos will return partial + results if some shards are down instead of returning an error. + - `manipulate`: (optional): If True (the default), apply any + outgoing SON manipulators before returning. + - `network_timeout` (optional): specify a timeout to use for + this query, which will override the + :class:`~pymongo.mongo_client.MongoClient`-level default + - `read_preference` (optional): The read preference for + this query. + - `tag_sets` (optional): The tag sets for this query. + - `secondary_acceptable_latency_ms` (optional): Any replica-set + member whose ping time is within secondary_acceptable_latency_ms of + the nearest member may accept reads. Default 15 milliseconds. + **Ignored by mongos** and must be configured on the command line. + See the localThreshold_ option for more information. + - `exhaust` (optional): If ``True`` create an "exhaust" cursor. + MongoDB will stream batched results to the client without waiting + for the client to request each batch, reducing latency. + + .. note:: There are a number of caveats to using the `exhaust` + parameter: + + 1. The `exhaust` and `limit` options are incompatible and can + not be used together. + + 2. The `exhaust` option is not supported by mongos and can not be + used with a sharded cluster. + + 3. A :class:`~pymongo.cursor.Cursor` instance created with the + `exhaust` option requires an exclusive :class:`~socket.socket` + connection to MongoDB. If the :class:`~pymongo.cursor.Cursor` is + discarded without being completely iterated the underlying + :class:`~socket.socket` connection will be closed and discarded + without being returned to the connection pool. + + 4. A :class:`~pymongo.cursor.Cursor` instance created with the + `exhaust` option in a :doc:`request ` **must** + be completely iterated before executing any other operation. + + 5. The `network_timeout` option is ignored when using the + `exhaust` option. + + .. note:: The `manipulate` parameter may default to False in + a future release. + + .. 
note:: The `max_scan` parameter requires server + version **>= 1.5.1** + + .. versionadded:: 2.3 + The `tag_sets` and `secondary_acceptable_latency_ms` parameters. + + .. versionadded:: 1.11+ + The `await_data`, `partial`, and `manipulate` parameters. + + .. versionadded:: 1.8 + The `network_timeout` parameter. + + .. versionadded:: 1.7 + The `sort`, `max_scan` and `as_class` parameters. + + .. versionchanged:: 1.7 + The `fields` parameter can now be a dict or any iterable in + addition to a list. + + .. versionadded:: 1.1 + The `tailable` parameter. + + .. mongodoc:: find + .. _localThreshold: http://docs.mongodb.org/manual/reference/mongos/#cmdoption-mongos--localThreshold + """ + if not 'slave_okay' in kwargs: + kwargs['slave_okay'] = self.slave_okay + if not 'read_preference' in kwargs: + kwargs['read_preference'] = self.read_preference + if not 'tag_sets' in kwargs: + kwargs['tag_sets'] = self.tag_sets + if not 'secondary_acceptable_latency_ms' in kwargs: + kwargs['secondary_acceptable_latency_ms'] = ( + self.secondary_acceptable_latency_ms) + return Cursor(self, *args, **kwargs) + + def count(self): + """Get the number of documents in this collection. + + To get the number of documents matching a specific query use + :meth:`pymongo.cursor.Cursor.count`. + """ + return self.find().count() + + def create_index(self, key_or_list, cache_for=300, **kwargs): + """Creates an index on this collection. + + Takes either a single key or a list of (key, direction) pairs. + The key(s) must be an instance of :class:`basestring` + (:class:`str` in python 3), and the directions must be one of + (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`, + :data:`~pymongo.GEO2D`). Returns the name of the created index. + + To create a single key index on the key ``'mike'`` we just use + a string argument: + + >>> my_collection.create_index("mike") + + For a compound index on ``'mike'`` descending and ``'eliot'`` + ascending we need to use a list of tuples: + + >>> my_collection.create_index([("mike", pymongo.DESCENDING), + ... ("eliot", pymongo.ASCENDING)]) + + All optional index creation parameters should be passed as + keyword arguments to this method. Valid options include: + + - `name`: custom name to use for this index - if none is + given, a name will be generated + - `unique`: should this index guarantee uniqueness? + - `dropDups` or `drop_dups`: should we drop duplicates + - `background`: if this index should be created in the + background + - `sparse`: if True, omit from the index any documents that lack + the indexed field + - `bucketSize` or `bucket_size`: for use with geoHaystack indexes. + Number of documents to group together within a certain proximity + to a given longitude and latitude. + - `min`: minimum value for keys in a :data:`~pymongo.GEO2D` + index + - `max`: maximum value for keys in a :data:`~pymongo.GEO2D` + index + - `expireAfterSeconds`: Used to create an expiring (TTL) + collection. MongoDB will automatically delete documents from + this collection after seconds. The indexed field must + be a UTC datetime or the data will not expire. + + .. 
note:: `expireAfterSeconds` requires server version **>= 2.1.2** + + :Parameters: + - `key_or_list`: a single key or a list of (key, direction) + pairs specifying the index to create + - `cache_for` (optional): time window (in seconds) during which + this index will be recognized by subsequent calls to + :meth:`ensure_index` - see documentation for + :meth:`ensure_index` for details + - `**kwargs` (optional): any additional index creation + options (see the above list) should be passed as keyword + arguments + - `ttl` (deprecated): Use `cache_for` instead. + + .. versionchanged:: 2.3 + The `ttl` parameter has been deprecated to avoid confusion with + TTL collections. Use `cache_for` instead. + + .. versionchanged:: 2.2 + Removed deprecated argument: deprecated_unique + + .. versionchanged:: 1.5.1 + Accept kwargs to support all index creation options. + + .. versionadded:: 1.5 + The `name` parameter. + + .. seealso:: :meth:`ensure_index` + + .. mongodoc:: indexes + """ + + if 'ttl' in kwargs: + cache_for = kwargs.pop('ttl') + warnings.warn("ttl is deprecated. Please use cache_for instead.", + DeprecationWarning, stacklevel=2) + + # The types supported by datetime.timedelta. 2to3 removes long. + if not isinstance(cache_for, (int, float)): + raise TypeError("cache_for must be an integer or float.") + + keys = helpers._index_list(key_or_list) + index_doc = helpers._index_document(keys) + + index = {"key": index_doc, "ns": self.__full_name} + + name = "name" in kwargs and kwargs["name"] or _gen_index_name(keys) + index["name"] = name + + if "drop_dups" in kwargs: + kwargs["dropDups"] = kwargs.pop("drop_dups") + + if "bucket_size" in kwargs: + kwargs["bucketSize"] = kwargs.pop("bucket_size") + + index.update(kwargs) + + self.__database.system.indexes.insert(index, manipulate=False, + check_keys=False, + **self._get_wc_override()) + + self.__database.connection._cache_index(self.__database.name, + self.__name, name, cache_for) + + return name + + def ensure_index(self, key_or_list, cache_for=300, **kwargs): + """Ensures that an index exists on this collection. + + Takes either a single key or a list of (key, direction) pairs. + The key(s) must be an instance of :class:`basestring` + (:class:`str` in python 3), and the direction(s) must be one of + (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`, + :data:`~pymongo.GEO2D`). See :meth:`create_index` for a detailed + example. + + Unlike :meth:`create_index`, which attempts to create an index + unconditionally, :meth:`ensure_index` takes advantage of some + caching within the driver such that it only attempts to create + indexes that might not already exist. When an index is created + (or ensured) by PyMongo it is "remembered" for `cache_for` + seconds. Repeated calls to :meth:`ensure_index` within that + time limit will be lightweight - they will not attempt to + actually create the index. + + Care must be taken when the database is being accessed through + multiple clients at once. If an index is created using + this client and deleted using another, any call to + :meth:`ensure_index` within the cache window will fail to + re-create the missing index. + + Returns the name of the created index if an index is actually + created. Returns ``None`` if the index already exists. + + All optional index creation parameters should be passed as + keyword arguments to this method. Valid options include: + + - `name`: custom name to use for this index - if none is + given, a name will be generated + - `unique`: should this index guarantee uniqueness? 
+ - `dropDups` or `drop_dups`: should we drop duplicates + during index creation when creating a unique index? + - `background`: if this index should be created in the + background + - `sparse`: if True, omit from the index any documents that lack + the indexed field + - `bucketSize` or `bucket_size`: for use with geoHaystack indexes. + Number of documents to group together within a certain proximity + to a given longitude and latitude. + - `min`: minimum value for keys in a :data:`~pymongo.GEO2D` + index + - `max`: maximum value for keys in a :data:`~pymongo.GEO2D` + index + - `expireAfterSeconds`: Used to create an expiring (TTL) + collection. MongoDB will automatically delete documents from + this collection after seconds. The indexed field must + be a UTC datetime or the data will not expire. + + .. note:: `expireAfterSeconds` requires server version **>= 2.1.2** + + :Parameters: + - `key_or_list`: a single key or a list of (key, direction) + pairs specifying the index to create + - `cache_for` (optional): time window (in seconds) during which + this index will be recognized by subsequent calls to + :meth:`ensure_index` + - `**kwargs` (optional): any additional index creation + options (see the above list) should be passed as keyword + arguments + - `ttl` (deprecated): Use `cache_for` instead. + + .. versionchanged:: 2.3 + The `ttl` parameter has been deprecated to avoid confusion with + TTL collections. Use `cache_for` instead. + + .. versionchanged:: 2.2 + Removed deprecated argument: deprecated_unique + + .. versionchanged:: 1.5.1 + Accept kwargs to support all index creation options. + + .. versionadded:: 1.5 + The `name` parameter. + + .. seealso:: :meth:`create_index` + """ + if "name" in kwargs: + name = kwargs["name"] + else: + keys = helpers._index_list(key_or_list) + name = kwargs["name"] = _gen_index_name(keys) + + if not self.__database.connection._cached(self.__database.name, + self.__name, name): + return self.create_index(key_or_list, cache_for, **kwargs) + return None + + def drop_indexes(self): + """Drops all indexes on this collection. + + Can be used on non-existant collections or collections with no indexes. + Raises OperationFailure on an error. + """ + self.__database.connection._purge_index(self.__database.name, + self.__name) + self.drop_index("*") + + def drop_index(self, index_or_name): + """Drops the specified index on this collection. + + Can be used on non-existant collections or collections with no + indexes. Raises OperationFailure on an error. `index_or_name` + can be either an index name (as returned by `create_index`), + or an index specifier (as passed to `create_index`). An index + specifier should be a list of (key, direction) pairs. Raises + TypeError if index is not an instance of (str, unicode, list). + + .. warning:: + + if a custom name was used on index creation (by + passing the `name` parameter to :meth:`create_index` or + :meth:`ensure_index`) the index **must** be dropped by name. + + :Parameters: + - `index_or_name`: index (or name of index) to drop + """ + name = index_or_name + if isinstance(index_or_name, list): + name = _gen_index_name(index_or_name) + + if not isinstance(name, str): + raise TypeError("index_or_name must be an index name or list") + + self.__database.connection._purge_index(self.__database.name, + self.__name, name) + self.__database.command("dropIndexes", self.__name, index=name, + allowable_errors=["ns not found"]) + + def reindex(self): + """Rebuilds all indexes on this collection. + + .. 
warning:: reindex blocks all other operations (indexes + are built in the foreground) and will be slow for large + collections. + + .. versionadded:: 1.11+ + """ + return self.__database.command("reIndex", self.__name) + + def index_information(self): + """Get information on this collection's indexes. + + Returns a dictionary where the keys are index names (as + returned by create_index()) and the values are dictionaries + containing information about each index. The dictionary is + guaranteed to contain at least a single key, ``"key"`` which + is a list of (key, direction) pairs specifying the index (as + passed to create_index()). It will also contain any other + information in `system.indexes`, except for the ``"ns"`` and + ``"name"`` keys, which are cleaned. Example output might look + like this: + + >>> db.test.ensure_index("x", unique=True) + u'x_1' + >>> db.test.index_information() + {u'_id_': {u'key': [(u'_id', 1)]}, + u'x_1': {u'unique': True, u'key': [(u'x', 1)]}} + + + .. versionchanged:: 1.7 + The values in the resultant dictionary are now dictionaries + themselves, whose ``"key"`` item contains the list that was + the value in previous versions of PyMongo. + """ + raw = self.__database.system.indexes.find({"ns": self.__full_name}, + {"ns": 0}, as_class=SON) + info = {} + for index in raw: + index["key"] = list(index["key"].items()) + index = dict(index) + info[index.pop("name")] = index + return info + + def options(self): + """Get the options set on this collection. + + Returns a dictionary of options and their values - see + :meth:`~pymongo.database.Database.create_collection` for more + information on the possible options. Returns an empty + dictionary if the collection has not been created yet. + """ + result = self.__database.system.namespaces.find_one( + {"name": self.__full_name}) + + if not result: + return {} + + options = result.get("options", {}) + if "create" in options: + del options["create"] + + return options + + def aggregate(self, pipeline, **kwargs): + """Perform an aggregation using the aggregation framework on this + collection. + + With :class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient` + or :class:`~pymongo.master_slave_connection.MasterSlaveConnection`, + if the `read_preference` attribute of this instance is not set to + :attr:`pymongo.read_preferences.ReadPreference.PRIMARY` or the + (deprecated) `slave_okay` attribute of this instance is set to `True` + the `aggregate command`_ will be sent to a secondary or slave. + + :Parameters: + - `pipeline`: a single command or list of aggregation commands + - `**kwargs`: send arbitrary parameters to the aggregate command + + .. note:: Requires server version **>= 2.1.0**. + + With server version **>= 2.5.1**, pass + ``cursor={}`` to retrieve unlimited aggregation results + with a :class:`~pymongo.cursor.Cursor`:: + + pipeline = [{'$project': {'name': {'$toUpper': '$name'}}}] + cursor = collection.aggregate(pipeline, cursor={}) + for doc in cursor: + print doc + + .. versionchanged:: 2.6 + Added cursor support. + .. versionadded:: 2.3 + + .. 
_aggregate command: + http://docs.mongodb.org/manual/applications/aggregation + """ + if not isinstance(pipeline, (dict, list, tuple)): + raise TypeError("pipeline must be a dict, list or tuple") + + if isinstance(pipeline, dict): + pipeline = [pipeline] + + use_master = not self.slave_okay and not self.read_preference + + command_kwargs = { + 'pipeline': pipeline, + 'read_preference': self.read_preference, + 'tag_sets': self.tag_sets, + 'secondary_acceptable_latency_ms': ( + self.secondary_acceptable_latency_ms), + 'slave_okay': self.slave_okay, + '_use_master': use_master} + + command_kwargs.update(kwargs) + command_response = self.__database.command( + "aggregate", self.__name, **command_kwargs) + + if 'cursor' in command_response: + cursor_info = command_response['cursor'] + return Cursor( + self, + _first_batch=cursor_info['firstBatch'], + _cursor_id=cursor_info['id']) + else: + return command_response + + # TODO key and condition ought to be optional, but deprecation + # could be painful as argument order would have to change. + def group(self, key, condition, initial, reduce, finalize=None): + """Perform a query similar to an SQL *group by* operation. + + Returns an array of grouped items. + + The `key` parameter can be: + + - ``None`` to use the entire document as a key. + - A :class:`list` of keys (each a :class:`basestring` + (:class:`str` in python 3)) to group by. + - A :class:`basestring` (:class:`str` in python 3), or + :class:`~bson.code.Code` instance containing a JavaScript + function to be applied to each document, returning the key + to group by. + + With :class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient` + or :class:`~pymongo.master_slave_connection.MasterSlaveConnection`, + if the `read_preference` attribute of this instance is not set to + :attr:`pymongo.read_preferences.ReadPreference.PRIMARY` or + :attr:`pymongo.read_preferences.ReadPreference.PRIMARY_PREFERRED`, or + the (deprecated) `slave_okay` attribute of this instance is set to + `True`, the group command will be sent to a secondary or slave. + + :Parameters: + - `key`: fields to group by (see above description) + - `condition`: specification of rows to be + considered (as a :meth:`find` query specification) + - `initial`: initial value of the aggregation counter object + - `reduce`: aggregation function as a JavaScript string + - `finalize`: function to be called on each object in output list. + + .. versionchanged:: 2.2 + Removed deprecated argument: command + + .. versionchanged:: 1.4 + The `key` argument can now be ``None`` or a JavaScript function, + in addition to a :class:`list` of keys. + + .. versionchanged:: 1.3 + The `command` argument now defaults to ``True`` and is deprecated. + """ + + group = {} + if isinstance(key, str): + group["$keyf"] = Code(key) + elif key is not None: + group = {"key": helpers._fields_list_to_dict(key)} + group["ns"] = self.__name + group["$reduce"] = Code(reduce) + group["cond"] = condition + group["initial"] = initial + if finalize is not None: + group["finalize"] = Code(finalize) + + use_master = not self.slave_okay and not self.read_preference + + return self.__database.command("group", group, + uuid_subtype=self.__uuid_subtype, + read_preference=self.read_preference, + tag_sets=self.tag_sets, + secondary_acceptable_latency_ms=( + self.secondary_acceptable_latency_ms), + slave_okay=self.slave_okay, + _use_master=use_master)["retval"] + + def rename(self, new_name, **kwargs): + """Rename this collection. 
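+
+        A short usage sketch (the ``things`` collection name and the new
+        name are illustrative)::
+
+            >>> db.things.rename("stuff")   # db.things is now db.stuff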
+
+        If operating in auth mode, client must be authorized as an
+        admin to perform this operation.  Raises :class:`TypeError` if
+        `new_name` is not an instance of :class:`basestring`
+        (:class:`str` in python 3).  Raises :class:`~pymongo.errors.InvalidName`
+        if `new_name` is not a valid collection name.
+
+        :Parameters:
+          - `new_name`: new name for this collection
+          - `**kwargs` (optional): any additional rename options
+            should be passed as keyword arguments
+            (i.e. ``dropTarget=True``)
+
+        .. versionadded:: 1.7
+           support for accepting keyword arguments for rename options
+        """
+        if not isinstance(new_name, str):
+            raise TypeError("new_name must be an instance "
+                            "of %s" % (str.__name__,))
+
+        if not new_name or ".." in new_name:
+            raise InvalidName("collection names cannot be empty")
+        if new_name[0] == "." or new_name[-1] == ".":
+            raise InvalidName("collection names must not start or end with '.'")
+        if "$" in new_name and not new_name.startswith("oplog.$main"):
+            raise InvalidName("collection names must not contain '$'")
+
+        new_name = "%s.%s" % (self.__database.name, new_name)
+        self.__database.connection.admin.command("renameCollection",
+                                                 self.__full_name,
+                                                 to=new_name, **kwargs)
+
+    def distinct(self, key):
+        """Get a list of distinct values for `key` among all documents
+        in this collection.
+
+        Raises :class:`TypeError` if `key` is not an instance of
+        :class:`basestring` (:class:`str` in python 3).
+
+        To get the distinct values for a key in the result set of a
+        query use :meth:`~pymongo.cursor.Cursor.distinct`.
+
+        :Parameters:
+          - `key`: name of key for which we want to get the distinct values
+
+        .. note:: Requires server version **>= 1.1.0**
+
+        .. versionadded:: 1.1.1
+        """
+        return self.find().distinct(key)
+
+    def map_reduce(self, map, reduce, out, full_response=False, **kwargs):
+        """Perform a map/reduce operation on this collection.
+
+        If `full_response` is ``False`` (default) returns a
+        :class:`~pymongo.collection.Collection` instance containing
+        the results of the operation. Otherwise, returns the full
+        response from the server to the `map reduce command`_.
+
+        :Parameters:
+          - `map`: map function (as a JavaScript string)
+          - `reduce`: reduce function (as a JavaScript string)
+          - `out`: output collection name or `out object` (dict). See
+            the `map reduce command`_ documentation for available options.
+            Note: `out` options are order sensitive. :class:`~bson.son.SON`
+            can be used to specify multiple options.
+            e.g. SON([('replace', <collection name>), ('db', <database name>)])
+          - `full_response` (optional): if ``True``, return full response to
+            this command - otherwise just return the result collection
+          - `**kwargs` (optional): additional arguments to the
+            `map reduce command`_ may be passed as keyword arguments to this
+            helper method, e.g.::
+
+            >>> db.test.map_reduce(map, reduce, "myresults", limit=2)
+
+        .. note:: Requires server version **>= 1.1.1**
+
+        .. seealso:: :doc:`/examples/aggregation`
+
+        .. versionchanged:: 2.2
+           Removed deprecated arguments: merge_output and reduce_output
+
+        .. versionchanged:: 1.11+
+           DEPRECATED The merge_output and reduce_output parameters.
+
+        .. versionadded:: 1.2
+
+        .. _map reduce command: http://www.mongodb.org/display/DOCS/MapReduce
+
+        ..
mongodoc:: mapreduce + """ + if not isinstance(out, (str, dict)): + raise TypeError("'out' must be an instance of " + "%s or dict" % (str.__name__,)) + + if isinstance(out, dict) and out.get('inline'): + must_use_master = False + else: + must_use_master = True + + response = self.__database.command("mapreduce", self.__name, + uuid_subtype=self.__uuid_subtype, + map=map, reduce=reduce, + read_preference=self.read_preference, + tag_sets=self.tag_sets, + secondary_acceptable_latency_ms=( + self.secondary_acceptable_latency_ms), + out=out, _use_master=must_use_master, + **kwargs) + + if full_response or not response.get('result'): + return response + elif isinstance(response['result'], dict): + dbase = response['result']['db'] + coll = response['result']['collection'] + return self.__database.connection[dbase][coll] + else: + return self.__database[response["result"]] + + def inline_map_reduce(self, map, reduce, full_response=False, **kwargs): + """Perform an inline map/reduce operation on this collection. + + Perform the map/reduce operation on the server in RAM. A result + collection is not created. The result set is returned as a list + of documents. + + If `full_response` is ``False`` (default) returns the + result documents in a list. Otherwise, returns the full + response from the server to the `map reduce command`_. + + With :class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient` + or :class:`~pymongo.master_slave_connection.MasterSlaveConnection`, + if the `read_preference` attribute of this instance is not set to + :attr:`pymongo.read_preferences.ReadPreference.PRIMARY` or + :attr:`pymongo.read_preferences.ReadPreference.PRIMARY_PREFERRED`, or + the (deprecated) `slave_okay` attribute of this instance is set to + `True`, the inline map reduce will be run on a secondary or slave. + + :Parameters: + - `map`: map function (as a JavaScript string) + - `reduce`: reduce function (as a JavaScript string) + - `full_response` (optional): if ``True``, return full response to + this command - otherwise just return the result collection + - `**kwargs` (optional): additional arguments to the + `map reduce command`_ may be passed as keyword arguments to this + helper method, e.g.:: + + >>> db.test.inline_map_reduce(map, reduce, limit=2) + + .. note:: Requires server version **>= 1.7.4** + + .. versionadded:: 1.10 + """ + + use_master = not self.slave_okay and not self.read_preference + + res = self.__database.command("mapreduce", self.__name, + uuid_subtype=self.__uuid_subtype, + read_preference=self.read_preference, + tag_sets=self.tag_sets, + secondary_acceptable_latency_ms=( + self.secondary_acceptable_latency_ms), + slave_okay=self.slave_okay, + _use_master=use_master, + map=map, reduce=reduce, + out={"inline": 1}, **kwargs) + + if full_response: + return res + else: + return res.get("results") + + def find_and_modify(self, query={}, update=None, + upsert=False, sort=None, full_response=False, **kwargs): + """Update and return an object. + + This is a thin wrapper around the findAndModify_ command. The + positional arguments are designed to match the first three arguments + to :meth:`update` however most options should be passed as named + parameters. Either `update` or `remove` arguments are required, all + others are optional. + + Returns either the object before or after modification based on `new` + parameter. If no objects match the `query` and `upsert` is false, + returns ``None``. If upserting and `new` is false, returns ``{}``. 
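+
+        For example, a minimal counter-increment sketch (the collection
+        and field names are illustrative)::
+
+            >>> db.counters.find_and_modify(
+            ...     query={'_id': 'pages'},
+            ...     update={'$inc': {'n': 1}},
+            ...     upsert=True, new=True)
+            {u'_id': u'pages', u'n': 1}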
+ + If the full_response parameter is ``True``, the return value will be + the entire response object from the server, including the 'ok' and + 'lastErrorObject' fields, rather than just the modified object. + This is useful mainly because the 'lastErrorObject' document holds + information about the command's execution. + + :Parameters: + - `query`: filter for the update (default ``{}``) + - `update`: see second argument to :meth:`update` (no default) + - `upsert`: insert if object doesn't exist (default ``False``) + - `sort`: a list of (key, direction) pairs specifying the sort + order for this query. See :meth:`~pymongo.cursor.Cursor.sort` + for details. + - `full_response`: return the entire response object from the + server (default ``False``) + - `remove`: remove rather than updating (default ``False``) + - `new`: return updated rather than original object + (default ``False``) + - `fields`: see second argument to :meth:`find` (default all) + - `**kwargs`: any other options the findAndModify_ command + supports can be passed here. + + + .. mongodoc:: findAndModify + + .. _findAndModify: http://dochub.mongodb.org/core/findAndModify + + .. note:: Requires server version **>= 1.3.0** + + .. versionchanged:: 2.5 + Added the optional full_response parameter + + .. versionchanged:: 2.4 + Deprecated the use of mapping types for the sort parameter + + .. versionadded:: 1.10 + """ + if (not update and not kwargs.get('remove', None)): + raise ValueError("Must either update or remove") + + if (update and kwargs.get('remove', None)): + raise ValueError("Can't do both update and remove") + + # No need to include empty args + if query: + kwargs['query'] = query + if update: + kwargs['update'] = update + if upsert: + kwargs['upsert'] = upsert + if sort: + # Accept a list of tuples to match Cursor's sort parameter. + if isinstance(sort, list): + kwargs['sort'] = helpers._index_document(sort) + # Accept OrderedDict, SON, and dict with len == 1 so we + # don't break existing code already using find_and_modify. + elif (isinstance(sort, ordered_types) or + isinstance(sort, dict) and len(sort) == 1): + warnings.warn("Passing mapping types for `sort` is deprecated," + " use a list of (key, direction) pairs instead", + DeprecationWarning, stacklevel=2) + kwargs['sort'] = sort + else: + raise TypeError("sort must be a list of (key, direction) " + "pairs, a dict of len 1, or an instance of " + "SON or OrderedDict") + + no_obj_error = "No matching object found" + + out = self.__database.command("findAndModify", self.__name, + allowable_errors=[no_obj_error], + uuid_subtype=self.__uuid_subtype, + **kwargs) + + if not out['ok']: + if out["errmsg"] == no_obj_error: + return None + else: + # Should never get here b/c of allowable_errors + raise ValueError("Unexpected Error: %s" % (out,)) + + if full_response: + return out + else: + return out.get('value') + + def __iter__(self): + return self + + def __next__(self): + raise TypeError("'Collection' object is not iterable") + + def __call__(self, *args, **kwargs): + """This is only here so that some API misusages are easier to debug. + """ + if "." not in self.__name: + raise TypeError("'Collection' object is not callable. If you " + "meant to call the '%s' method on a 'Database' " + "object it is failing because no such method " + "exists." % + self.__name) + raise TypeError("'Collection' object is not callable. If you meant to " + "call the '%s' method on a 'Collection' object it is " + "failing because no such method exists." 
%
+                        self.__name.split(".")[-1])

diff --git a/asyncio_mongo/_pymongo/common.py b/asyncio_mongo/_pymongo/common.py
new file mode 100644
index 0000000..81b5a70
--- /dev/null
+++ b/asyncio_mongo/_pymongo/common.py
@@ -0,0 +1,646 @@
+# Copyright 2011-2012 10gen, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you
+# may not use this file except in compliance with the License.  You
+# may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.  See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+"""Functions and classes common to multiple pymongo modules."""
+import sys
+import warnings
+from asyncio_mongo._pymongo import read_preferences
+
+from asyncio_mongo._pymongo.auth import MECHANISMS
+from asyncio_mongo._pymongo.read_preferences import ReadPreference
+from asyncio_mongo._pymongo.errors import ConfigurationError
+
+HAS_SSL = True
+try:
+    import ssl
+except ImportError:
+    HAS_SSL = False
+
+
+# Jython 2.7 includes an incomplete ssl module. See PYTHON-498.
+if sys.platform.startswith('java'):
+    HAS_SSL = False
+
+
+def raise_config_error(key, dummy):
+    """Raise ConfigurationError with the given key name."""
+    raise ConfigurationError("Unknown option %s" % (key,))
+
+
+def validate_boolean(option, value):
+    """Validates that 'value' is ``True``, ``False``, 'true', or 'false'.
+    """
+    if isinstance(value, bool):
+        return value
+    elif isinstance(value, str):
+        if value not in ('true', 'false'):
+            raise ConfigurationError("The value of %s must be "
+                                     "'true' or 'false'" % (option,))
+        return value == 'true'
+    raise TypeError("Wrong type for %s, value must be a boolean" % (option,))
+
+
+def validate_integer(option, value):
+    """Validates that 'value' is an integer (or basestring representation).
+    """
+    if isinstance(value, int):
+        return value
+    elif isinstance(value, str):
+        if not value.isdigit():
+            raise ConfigurationError("The value of %s must be "
+                                     "an integer" % (option,))
+        return int(value)
+    raise TypeError("Wrong type for %s, value must be an integer" % (option,))
+
+
+def validate_positive_integer(option, value):
+    """Validate that 'value' is a positive integer.
+    """
+    val = validate_integer(option, value)
+    if val < 0:
+        raise ConfigurationError("The value of %s must be "
+                                 "a positive integer" % (option,))
+    return val
+
+
+def validate_readable(option, value):
+    """Validates that 'value' is file-like and readable.
+    """
+    # First make sure it's a string: in py3.3 open(True, 'r') succeeds.
+    # Used in ssl cert checking due to poor ssl module error reporting
+    value = validate_basestring(option, value)
+    open(value, 'r').close()
+    return value
+
+
+def validate_cert_reqs(option, value):
+    """Validate the cert reqs are valid. It must be None or one of the three
+    values ``ssl.CERT_NONE``, ``ssl.CERT_OPTIONAL`` or ``ssl.CERT_REQUIRED``"""
+    if value is None:
+        return value
+    if HAS_SSL:
+        if value in (ssl.CERT_NONE, ssl.CERT_OPTIONAL, ssl.CERT_REQUIRED):
+            return value
+        raise ConfigurationError("The value of %s must be one of: "
+                                 "`ssl.CERT_NONE`, `ssl.CERT_OPTIONAL` or "
+                                 "`ssl.CERT_REQUIRED`" % (option,))
+    else:
+        raise ConfigurationError("The value of %s is set but can't be "
+                                 "validated.
The ssl module is not available" + % (option,)) + + +def validate_positive_integer_or_none(option, value): + """Validate that 'value' is a positive integer or None. + """ + if value is None: + return value + return validate_positive_integer(option, value) + + +def validate_basestring(option, value): + """Validates that 'value' is an instance of `basestring`. + """ + if isinstance(value, str): + return value + raise TypeError("Wrong type for %s, value must be an " + "instance of %s" % (option, str.__name__)) + + +def validate_int_or_basestring(option, value): + """Validates that 'value' is an integer or string. + """ + if isinstance(value, int): + return value + elif isinstance(value, str): + if value.isdigit(): + return int(value) + return value + raise TypeError("Wrong type for %s, value must be an " + "integer or a string" % (option,)) + + +def validate_positive_float(option, value): + """Validates that 'value' is a float, or can be converted to one, and is + positive. + """ + err = ConfigurationError("%s must be a positive int or float" % (option,)) + try: + value = float(value) + except (ValueError, TypeError): + raise err + + # float('inf') doesn't work in 2.4 or 2.5 on Windows, so just cap floats at + # one billion - this is a reasonable approximation for infinity + if not 0 < value < 1e9: + raise err + + return value + + +def validate_timeout_or_none(option, value): + """Validates a timeout specified in milliseconds returning + a value in floating point seconds. + """ + if value is None: + return value + return validate_positive_float(option, value) / 1000.0 + + +def validate_read_preference(dummy, value): + """Validate read preference for a ReplicaSetConnection. + """ + if value in read_preferences.modes: + return value + + # Also allow string form of enum for uri_parser + try: + return read_preferences.mongos_enum(value) + except ValueError: + raise ConfigurationError("Not a valid read preference") + + +def validate_tag_sets(dummy, value): + """Validate tag sets for a ReplicaSetConnection. + """ + if value is None: + return [{}] + + if not isinstance(value, list): + raise ConfigurationError(( + "Tag sets %s invalid, must be a list" ) % repr(value)) + if len(value) == 0: + raise ConfigurationError(( + "Tag sets %s invalid, must be None or contain at least one set of" + " tags") % repr(value)) + + for tags in value: + if not isinstance(tags, dict): + raise ConfigurationError( + "Tag set %s invalid, must be a dict" % repr(tags)) + + return value + + +def validate_auth_mechanism(option, value): + """Validate the authMechanism URI option. 
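+
+    A sketch of how an option flows through validation, assuming
+    'MONGODB-CR' is among the registered MECHANISMS (the option value
+    here is illustrative)::
+
+        lower, value = validate('authMechanism', 'MONGODB-CR')
+        # lower == 'authmechanism', value == 'MONGODB-CR'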
+    """
+    if value not in MECHANISMS:
+        raise ConfigurationError("%s must be in "
+                                 "%s" % (option, MECHANISMS))
+    return value
+
+
+# journal is an alias for j,
+# wtimeoutms is an alias for wtimeout
+VALIDATORS = {
+    'replicaset': validate_basestring,
+    'slaveok': validate_boolean,
+    'slave_okay': validate_boolean,
+    'safe': validate_boolean,
+    'w': validate_int_or_basestring,
+    'wtimeout': validate_integer,
+    'wtimeoutms': validate_integer,
+    'fsync': validate_boolean,
+    'j': validate_boolean,
+    'journal': validate_boolean,
+    'connecttimeoutms': validate_timeout_or_none,
+    'sockettimeoutms': validate_timeout_or_none,
+    'waitqueuetimeoutms': validate_timeout_or_none,
+    'waitqueuemultiple': validate_positive_integer_or_none,
+    'ssl': validate_boolean,
+    'ssl_keyfile': validate_readable,
+    'ssl_certfile': validate_readable,
+    'ssl_cert_reqs': validate_cert_reqs,
+    'ssl_ca_certs': validate_readable,
+    'readpreference': validate_read_preference,
+    'read_preference': validate_read_preference,
+    'tag_sets': validate_tag_sets,
+    'secondaryacceptablelatencyms': validate_positive_float,
+    'secondary_acceptable_latency_ms': validate_positive_float,
+    'auto_start_request': validate_boolean,
+    'use_greenlets': validate_boolean,
+    'authmechanism': validate_auth_mechanism,
+    'authsource': validate_basestring,
+    'gssapiservicename': validate_basestring,
+}
+
+
+_AUTH_OPTIONS = frozenset(['gssapiservicename'])
+
+
+def validate_auth_option(option, value):
+    """Validate optional authentication parameters.
+    """
+    lower, value = validate(option, value)
+    if lower not in _AUTH_OPTIONS:
+        raise ConfigurationError('Unknown '
+                                 'authentication option: %s' % (option,))
+    return lower, value
+
+
+def validate(option, value):
+    """Generic validation function.
+    """
+    lower = option.lower()
+    validator = VALIDATORS.get(lower, raise_config_error)
+    value = validator(option, value)
+    return lower, value
+
+
+SAFE_OPTIONS = frozenset([
+    'w',
+    'wtimeout',
+    'wtimeoutms',
+    'fsync',
+    'j',
+    'journal'
+])
+
+
+class WriteConcern(dict):
+
+    def __init__(self, *args, **kwargs):
+        """A subclass of dict that overrides __setitem__ to
+        validate write concern options.
+        """
+        super(WriteConcern, self).__init__(*args, **kwargs)
+
+    def __setitem__(self, key, value):
+        if key not in SAFE_OPTIONS:
+            raise ConfigurationError("%s is not a valid write "
+                                     "concern option." % (key,))
+        key, value = validate(key, value)
+        super(WriteConcern, self).__setitem__(key, value)
+
+
+class BaseObject(object):
+    """A base class that provides attributes and methods common
+    to multiple pymongo classes.
+
+    SHOULD NOT BE USED BY DEVELOPERS EXTERNAL TO 10GEN
+    """
+
+    def __init__(self, **options):
+
+        self.__slave_okay = False
+        self.__read_pref = ReadPreference.PRIMARY
+        self.__tag_sets = [{}]
+        self.__secondary_acceptable_latency_ms = 15
+        self.__safe = None
+        self.__write_concern = WriteConcern()
+        self.__set_options(options)
+        if (self.__read_pref == ReadPreference.PRIMARY
+                and self.__tag_sets != [{}]):
+            raise ConfigurationError(
+                "ReadPreference PRIMARY cannot be combined with tags")
+
+        # If safe hasn't been implicitly set by write concerns then set it.
+        if self.__safe is None:
+            if options.get("w") == 0:
+                self.__safe = False
+            else:
+                self.__safe = validate_boolean('safe',
+                                               options.get("safe", True))
+        # Note: 'safe' is always passed by Connection and ReplicaSetConnection
+        # Always do the most "safe" thing, but warn about conflicts.
+        if self.__safe and options.get('w') == 0:
+            warnings.warn("Conflicting write concerns.
'w' set to 0 " + "but other options have enabled write concern. " + "Please set 'w' to a value other than 0.", + UserWarning) + + def __set_safe_option(self, option, value): + """Validates and sets getlasterror options for this + object (Connection, Database, Collection, etc.) + """ + if value is None: + self.__write_concern.pop(option, None) + else: + self.__write_concern[option] = value + if option != "w" or value != 0: + self.__safe = True + + def __set_options(self, options): + """Validates and sets all options passed to this object.""" + for option, value in options.items(): + if option in ('slave_okay', 'slaveok'): + self.__slave_okay = validate_boolean(option, value) + elif option in ('read_preference', "readpreference"): + self.__read_pref = validate_read_preference(option, value) + elif option == 'tag_sets': + self.__tag_sets = validate_tag_sets(option, value) + elif option in ( + 'secondaryacceptablelatencyms', + 'secondary_acceptable_latency_ms' + ): + self.__secondary_acceptable_latency_ms = \ + validate_positive_float(option, value) + elif option in SAFE_OPTIONS: + if option == 'journal': + self.__set_safe_option('j', value) + elif option == 'wtimeoutms': + self.__set_safe_option('wtimeout', value) + else: + self.__set_safe_option(option, value) + + def __set_write_concern(self, value): + """Property setter for write_concern.""" + if not isinstance(value, dict): + raise ConfigurationError("write_concern must be an " + "instance of dict or a subclass.") + # Make a copy here to avoid users accidentally setting the + # same dict on multiple instances. + wc = WriteConcern() + for k, v in value.items(): + # Make sure we validate each option. + wc[k] = v + self.__write_concern = wc + + def __get_write_concern(self): + """The default write concern for this instance. + + Supports dict style access for getting/setting write concern + options. Valid options include: + + - `w`: (integer or string) If this is a replica set, write operations + will block until they have been replicated to the specified number + or tagged set of servers. `w=` always includes the replica set + primary (e.g. w=3 means write to the primary and wait until + replicated to **two** secondaries). **Setting w=0 disables write + acknowledgement and all other write concern options.** + - `wtimeout`: (integer) Used in conjunction with `w`. Specify a value + in milliseconds to control how long to wait for write propagation + to complete. If replication does not complete in the given + timeframe, a timeout exception is raised. + - `j`: If ``True`` block until write operations have been committed + to the journal. Ignored if the server is running without journaling. + - `fsync`: If ``True`` force the database to fsync all files before + returning. When used with `j` the server awaits the next group + commit before returning. + + >>> m = pymongo.MongoClient() + >>> m.write_concern + {} + >>> m.write_concern = {'w': 2, 'wtimeout': 1000} + >>> m.write_concern + {'wtimeout': 1000, 'w': 2} + >>> m.write_concern['j'] = True + >>> m.write_concern + {'wtimeout': 1000, 'j': True, 'w': 2} + >>> m.write_concern = {'j': True} + >>> m.write_concern + {'j': True} + >>> # Disable write acknowledgement and write concern + ... + >>> m.write_concern['w'] = 0 + + + .. note:: Accessing :attr:`write_concern` returns its value + (a subclass of :class:`dict`), not a copy. + + .. 
warning:: If you are using :class:`~pymongo.connection.Connection` + or :class:`~pymongo.replica_set_connection.ReplicaSetConnection` + make sure you explicitly set ``w`` to 1 (or a greater value) or + :attr:`safe` to ``True``. Unlike calling + :meth:`set_lasterror_options`, setting an option in + :attr:`write_concern` does not implicitly set :attr:`safe` + to ``True``. + """ + # To support dict style access we have to return the actual + # WriteConcern here, not a copy. + return self.__write_concern + + write_concern = property(__get_write_concern, __set_write_concern) + + def __get_slave_okay(self): + """DEPRECATED. Use :attr:`read_preference` instead. + + .. versionchanged:: 2.1 + Deprecated slave_okay. + .. versionadded:: 2.0 + """ + return self.__slave_okay + + def __set_slave_okay(self, value): + """Property setter for slave_okay""" + warnings.warn("slave_okay is deprecated. Please use " + "read_preference instead.", DeprecationWarning, + stacklevel=2) + self.__slave_okay = validate_boolean('slave_okay', value) + + slave_okay = property(__get_slave_okay, __set_slave_okay) + + def __get_read_pref(self): + """The read preference mode for this instance. + + See :class:`~pymongo.read_preferences.ReadPreference` for available options. + + .. versionadded:: 2.1 + """ + return self.__read_pref + + def __set_read_pref(self, value): + """Property setter for read_preference""" + self.__read_pref = validate_read_preference('read_preference', value) + + read_preference = property(__get_read_pref, __set_read_pref) + + def __get_acceptable_latency(self): + """Any replica-set member whose ping time is within + secondary_acceptable_latency_ms of the nearest member may accept + reads. Defaults to 15 milliseconds. + + See :class:`~pymongo.read_preferences.ReadPreference`. + + .. versionadded:: 2.3 + + .. note:: ``secondary_acceptable_latency_ms`` is ignored when talking to a + replica set *through* a mongos. The equivalent is the localThreshold_ command + line option. + + .. _localThreshold: http://docs.mongodb.org/manual/reference/mongos/#cmdoption-mongos--localThreshold + """ + return self.__secondary_acceptable_latency_ms + + def __set_acceptable_latency(self, value): + """Property setter for secondary_acceptable_latency_ms""" + self.__secondary_acceptable_latency_ms = (validate_positive_float( + 'secondary_acceptable_latency_ms', value)) + + secondary_acceptable_latency_ms = property( + __get_acceptable_latency, __set_acceptable_latency) + + def __get_tag_sets(self): + """Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to + read only from members whose ``dc`` tag has the value ``"ny"``. + To specify a priority-order for tag sets, provide a list of + tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag + set, ``{}``, means "read from any member that matches the mode, + ignoring tags." ReplicaSetConnection tries each set of tags in turn + until it finds a set of tags with at least one matching member. + + .. seealso:: `Data-Center Awareness + `_ + + .. versionadded:: 2.3 + """ + return self.__tag_sets + + def __set_tag_sets(self, value): + """Property setter for tag_sets""" + self.__tag_sets = validate_tag_sets('tag_sets', value) + + tag_sets = property(__get_tag_sets, __set_tag_sets) + + def __get_safe(self): + """**DEPRECATED:** Use the 'w' :attr:`write_concern` option instead. + + Use getlasterror with every write operation? + + .. 
versionadded:: 2.0
+        """
+        return self.__safe
+
+    def __set_safe(self, value):
+        """Property setter for safe"""
+        warnings.warn("safe is deprecated. Please use the"
+                      " 'w' write_concern option instead.",
+                      DeprecationWarning, stacklevel=2)
+        self.__safe = validate_boolean('safe', value)
+
+    safe = property(__get_safe, __set_safe)
+
+    def get_lasterror_options(self):
+        """DEPRECATED: Use :attr:`write_concern` instead.
+
+        Returns a dict of the getlasterror options set on this instance.
+
+        .. versionchanged:: 2.4
+           Deprecated get_lasterror_options.
+        .. versionadded:: 2.0
+        """
+        warnings.warn("get_lasterror_options is deprecated. Please use "
+                      "write_concern instead.", DeprecationWarning,
+                      stacklevel=2)
+        return self.__write_concern.copy()
+
+    def set_lasterror_options(self, **kwargs):
+        """DEPRECATED: Use :attr:`write_concern` instead.
+
+        Set getlasterror options for this instance.
+
+        Valid options include j=<bool>, w=<int/string>, wtimeout=<int>,
+        and fsync=<bool>. Implies safe=True.
+
+        :Parameters:
+          - `**kwargs`: Options should be passed as keyword
+            arguments (e.g. w=2, fsync=True)
+
+        .. versionchanged:: 2.4
+           Deprecated set_lasterror_options.
+        .. versionadded:: 2.0
+        """
+        warnings.warn("set_lasterror_options is deprecated. Please use "
+                      "write_concern instead.", DeprecationWarning,
+                      stacklevel=2)
+        for key, value in kwargs.items():
+            self.__set_safe_option(key, value)
+
+    def unset_lasterror_options(self, *options):
+        """DEPRECATED: Use :attr:`write_concern` instead.
+
+        Unset getlasterror options for this instance.
+
+        If no options are passed unsets all getlasterror options.
+        This does not set `safe` to False.
+
+        :Parameters:
+          - `*options`: The list of options to unset.
+
+        .. versionchanged:: 2.4
+           Deprecated unset_lasterror_options.
+        .. versionadded:: 2.0
+        """
+        warnings.warn("unset_lasterror_options is deprecated. Please use "
+                      "write_concern instead.", DeprecationWarning,
+                      stacklevel=2)
+        if len(options):
+            for option in options:
+                self.__write_concern.pop(option, None)
+        else:
+            self.__write_concern = WriteConcern()
+
+    def _get_wc_override(self):
+        """Get write concern override.
+
+        Used in internal methods that **must** do acknowledged write ops.
+        We don't want to override user write concern options if write concern
+        is already enabled.
+        """
+        if self.safe and self.__write_concern.get('w') != 0:
+            return {}
+        return {'w': 1}
+
+    def _get_write_mode(self, safe=None, **options):
+        """Get the current write mode.
+
+        Determines if the current write is safe or not based on the
+        passed in or inherited safe value, write_concern values, or
+        passed options.
+
+        :Parameters:
+          - `safe`: check that the operation succeeded?
+          - `**options`: overriding write concern options.
+
+        .. versionadded:: 2.3
+        """
+        # Don't ever send w=1 to the server.
+        def pop1(dct):
+            if dct.get('w') == 1:
+                dct.pop('w')
+            return dct
+
+        if safe is not None:
+            warnings.warn("The safe parameter is deprecated. Please use "
+                          "write concern options instead.", DeprecationWarning,
+                          stacklevel=3)
+            validate_boolean('safe', safe)
+
+        # Passed options override collection level defaults.
+        if safe is not None or options:
+            if safe or options:
+                if not options:
+                    options = self.__write_concern.copy()
+                    # Backwards compatibility edge case. Call getLastError
+                    # with no options if safe=True was passed but collection
+                    # level defaults have been disabled with w=0.
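+                    # (That is, safe=True re-enables acknowledgement for
+                    # this single call even though w=0 disabled it.)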
+ # These should be equivalent: + # Connection(w=0).foo.bar.insert({}, safe=True) + # MongoClient(w=0).foo.bar.insert({}, w=1) + if options.get('w') == 0: + return True, {} + # Passing w=0 overrides passing safe=True. + return options.get('w') != 0, pop1(options) + return False, {} + + # Fall back to collection level defaults. + # w=0 takes precedence over self.safe = True + if self.__write_concern.get('w') == 0: + return False, {} + elif self.safe or self.__write_concern.get('w', 0) != 0: + return True, pop1(self.__write_concern.copy()) + + return False, {} diff --git a/asyncio_mongo/_pymongo/connection.py b/asyncio_mongo/_pymongo/connection.py new file mode 100644 index 0000000..cfd230b --- /dev/null +++ b/asyncio_mongo/_pymongo/connection.py @@ -0,0 +1,231 @@ +# Copyright 2009-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Tools for connecting to MongoDB. + +.. warning:: + **DEPRECATED:** Please use :mod:`~pymongo.mongo_client` instead. + +.. seealso:: Module :mod:`~pymongo.master_slave_connection` for + connecting to master-slave clusters, and + :doc:`/examples/high_availability` for an example of how to connect + to a replica set, or specify a list of mongos instances for automatic + failover. + +To get a :class:`~pymongo.database.Database` instance from a +:class:`Connection` use either dictionary-style or attribute-style +access: + +.. doctest:: + + >>> from asyncio_mongo._pymongo import Connection + >>> c = Connection() + >>> c.test_database + Database(Connection('localhost', 27017), u'test_database') + >>> c['test-database'] + Database(Connection('localhost', 27017), u'test-database') +""" +from asyncio_mongo._pymongo.mongo_client import MongoClient +from asyncio_mongo._pymongo.errors import ConfigurationError + + +class Connection(MongoClient): + """Connection to MongoDB. + """ + + def __init__(self, host=None, port=None, max_pool_size=None, + network_timeout=None, document_class=dict, + tz_aware=False, _connect=True, **kwargs): + """Create a new connection to a single MongoDB instance at *host:port*. + + .. warning:: + **DEPRECATED:** :class:`Connection` is deprecated. Please + use :class:`~pymongo.mongo_client.MongoClient` instead. + + The resultant connection object has connection-pooling built + in. It also performs auto-reconnection when necessary. If an + operation fails because of a connection error, + :class:`~pymongo.errors.ConnectionFailure` is raised. If + auto-reconnection will be performed, + :class:`~pymongo.errors.AutoReconnect` will be + raised. Application code should handle this exception + (recognizing that the operation failed) and then continue to + execute. + + Raises :class:`TypeError` if port is not an instance of + ``int``. Raises :class:`~pymongo.errors.ConnectionFailure` if + the connection cannot be made. + + The `host` parameter can be a full `mongodb URI + `_, in addition to + a simple hostname. It can also be a list of hostnames or + URIs. Any port specified in the host string(s) will override + the `port` parameter. 
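+        For example (host names and ports are illustrative)::
+
+            Connection()                                  # localhost:27017
+            Connection('example.com', 27017)
+            Connection('mongodb://example.com:27017/')
+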
If multiple mongodb URIs containing + database or auth information are passed, the last database, + username, and password present will be used. For username and + passwords reserved characters like ':', '/', '+' and '@' must be + escaped following RFC 2396. + + :Parameters: + - `host` (optional): hostname or IP address of the + instance to connect to, or a mongodb URI, or a list of + hostnames / mongodb URIs. If `host` is an IPv6 literal + it must be enclosed in '[' and ']' characters following + the RFC2732 URL syntax (e.g. '[::1]' for localhost) + - `port` (optional): port number on which to connect + - `max_pool_size` (optional): The maximum number of connections + that the pool will open simultaneously. If this is set, operations + will block if there are `max_pool_size` outstanding connections + from the pool. By default the pool size is unlimited. + - `network_timeout` (optional): timeout (in seconds) to use + for socket operations - default is no timeout + - `document_class` (optional): default class to use for + documents returned from queries on this connection + - `tz_aware` (optional): if ``True``, + :class:`~datetime.datetime` instances returned as values + in a document by this :class:`Connection` will be timezone + aware (otherwise they will be naive) + + | **Other optional parameters can be passed as keyword arguments:** + + - `socketTimeoutMS`: (integer) How long (in milliseconds) a send or + receive on a socket can take before timing out. + - `connectTimeoutMS`: (integer) How long (in milliseconds) a + connection can take to be opened before timing out. + - `auto_start_request`: If ``True`` (the default), each thread that + accesses this Connection has a socket allocated to it for the + thread's lifetime. This ensures consistent reads, even if you read + after an unsafe write. + - `use_greenlets`: if ``True``, :meth:`start_request()` will ensure + that the current greenlet uses the same socket for all operations + until :meth:`end_request()` + + | **Write Concern options:** + + - `safe`: :class:`Connection` **disables** acknowledgement of write + operations. Use ``safe=True`` to enable write acknowledgement. + - `w`: (integer or string) If this is a replica set, write operations + will block until they have been replicated to the specified number + or tagged set of servers. `w=` always includes the replica set + primary (e.g. w=3 means write to the primary and wait until + replicated to **two** secondaries). Implies safe=True. + - `wtimeout`: (integer) Used in conjunction with `w`. Specify a value + in milliseconds to control how long to wait for write propagation + to complete. If replication does not complete in the given + timeframe, a timeout exception is raised. Implies safe=True. + - `j`: If ``True`` block until write operations have been committed + to the journal. Ignored if the server is running without journaling. + Implies safe=True. + - `fsync`: If ``True`` force the database to fsync all files before + returning. When used with `j` the server awaits the next group + commit before returning. Implies safe=True. + + | **Replica-set keyword arguments for connecting with a replica-set + - either directly or via a mongos:** + | (ignored by standalone mongod instances) + + - `slave_okay` or `slaveOk` (deprecated): Use `read_preference` + instead. + - `replicaSet`: (string) The name of the replica-set to connect to. + The driver will verify that the replica-set it connects to matches + this name. 
Implies that the hosts specified are a seed list and the
+        driver should attempt to find all members of the set. *Ignored by
+        mongos*.
+      - `read_preference`: The read preference for this client. If
+        connecting to a secondary then a read preference mode *other* than
+        PRIMARY is required - otherwise all queries will throw a
+        :class:`~pymongo.errors.AutoReconnect` "not master" error.
+        See :class:`~pymongo.read_preferences.ReadPreference` for all
+        available read preference options.
+      - `tag_sets`: Ignored unless connecting to a replica-set via mongos.
+        To specify a priority-order for tag sets, provide a list of
+        tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag
+        set, ``{}``, means "read from any member that matches the mode,
+        ignoring tags."
+
+      | **SSL configuration:**
+
+      - `ssl`: If ``True``, create the connection to the server using SSL.
+      - `ssl_keyfile`: The private keyfile used to identify the local
+        connection against mongod.  If included with the ``certfile`` then
+        only the ``ssl_certfile`` is needed.  Implies ``ssl=True``.
+      - `ssl_certfile`: The certificate file used to identify the local
+        connection against mongod. Implies ``ssl=True``.
+      - `ssl_cert_reqs`: The parameter cert_reqs specifies whether a
+        certificate is required from the other side of the connection,
+        and whether it will be validated if provided. It must be one of the
+        three values ``ssl.CERT_NONE`` (certificates ignored),
+        ``ssl.CERT_OPTIONAL`` (not required, but validated if provided), or
+        ``ssl.CERT_REQUIRED`` (required and validated). If the value of
+        this parameter is not ``ssl.CERT_NONE``, then the ``ssl_ca_certs``
+        parameter must point to a file of CA certificates.
+        Implies ``ssl=True``.
+      - `ssl_ca_certs`: The ca_certs file contains a set of concatenated
+        "certification authority" certificates, which are used to validate
+        certificates passed from the other end of the connection.
+        Implies ``ssl=True``.
+
+        .. seealso:: :meth:`end_request`
+        .. versionchanged:: 2.5
+           Added additional ssl options
+        .. versionchanged:: 2.3
+           Added support for failover between mongos seed list members.
+        .. versionchanged:: 2.2
+           Added `auto_start_request` option back. Added `use_greenlets`
+           option.
+        .. versionchanged:: 2.1
+           Support `w` = integer or string.
+           Added `ssl` option.
+           DEPRECATED slave_okay/slaveOk.
+        .. versionchanged:: 2.0
+           `slave_okay` is a pure keyword argument. Added support for safe,
+           and getlasterror options as keyword arguments.
+        .. versionchanged:: 1.11
+           Added `max_pool_size`. Completely removed previously deprecated
+           `pool_size`, `auto_start_request` and `timeout` parameters.
+        .. versionchanged:: 1.8
+           The `host` parameter can now be a full `mongodb URI
+           `_, in addition
+           to a simple hostname. It can also be a list of hostnames or
+           URIs.
+        .. versionadded:: 1.8
+           The `tz_aware` parameter.
+        .. versionadded:: 1.7
+           The `document_class` parameter.
+        .. versionadded:: 1.1
+           The `network_timeout` parameter.
+
+        ..
mongodoc:: connections + """ + if network_timeout is not None: + if (not isinstance(network_timeout, (int, float)) or + network_timeout <= 0): + raise ConfigurationError("network_timeout must " + "be a positive integer") + kwargs['socketTimeoutMS'] = network_timeout * 1000 + + kwargs['auto_start_request'] = kwargs.get('auto_start_request', True) + kwargs['safe'] = kwargs.get('safe', False) + + super(Connection, self).__init__(host, port, + max_pool_size, document_class, tz_aware, _connect, **kwargs) + + def __repr__(self): + if len(self.nodes) == 1: + return "Connection(%r, %r)" % (self.host, self.port) + else: + return "Connection(%r)" % ["%s:%d" % n for n in self.nodes] + + def __next__(self): + raise TypeError("'Connection' object is not iterable") diff --git a/asyncio_mongo/_pymongo/cursor.py b/asyncio_mongo/_pymongo/cursor.py new file mode 100644 index 0000000..73d3c9a --- /dev/null +++ b/asyncio_mongo/_pymongo/cursor.py @@ -0,0 +1,963 @@ +# Copyright 2009-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cursor class to iterate over Mongo query results.""" +import copy +from collections import deque + +from asyncio_mongo._bson import RE_TYPE +from asyncio_mongo._bson.code import Code +from asyncio_mongo._bson.son import SON +from asyncio_mongo._pymongo import helpers, message, read_preferences +from asyncio_mongo._pymongo.read_preferences import ReadPreference, secondary_ok_commands +from asyncio_mongo._pymongo.errors import (InvalidOperation, + AutoReconnect) + +_QUERY_OPTIONS = { + "tailable_cursor": 2, + "slave_okay": 4, + "oplog_replay": 8, + "no_timeout": 16, + "await_data": 32, + "exhaust": 64, + "partial": 128} + + +# This has to be an old style class due to +# http://bugs.jython.org/issue1057 +class _SocketManager: + """Used with exhaust cursors to ensure the socket is returned. + """ + def __init__(self, sock, pool): + self.sock = sock + self.pool = pool + self.__closed = False + + def __del__(self): + self.close() + + def close(self): + """Return this instance's socket to the connection pool. + """ + if not self.__closed: + self.__closed = True + self.pool.maybe_return_socket(self.sock) + self.sock, self.pool = None, None + + +# TODO might be cool to be able to do find().include("foo") or +# find().exclude(["bar", "baz"]) or find().slice("a", 1, 2) as an +# alternative to the fields specifier. +class Cursor(object): + """A cursor / iterator over Mongo query results. + """ + + def __init__(self, collection, spec=None, fields=None, skip=0, limit=0, + timeout=True, snapshot=False, tailable=False, sort=None, + max_scan=None, as_class=None, slave_okay=False, + await_data=False, partial=False, manipulate=True, + read_preference=ReadPreference.PRIMARY, tag_sets=[{}], + secondary_acceptable_latency_ms=None, exhaust=False, + _must_use_master=False, _uuid_subtype=None, + _first_batch=None, _cursor_id=None, + **kwargs): + """Create a new cursor. + + Should not be called directly by application developers - see + :meth:`~pymongo.collection.Collection.find` instead. 
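+
+        A typical way to end up with a cursor, for illustration (the
+        collection and field names are made up)::
+
+            cursor = db.test.find({'x': {'$gt': 1}}).sort('x').limit(10)
+            for doc in cursor:
+                print(doc)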
+ + .. mongodoc:: cursors + """ + self.__id = _cursor_id + self.__is_command_cursor = _cursor_id is not None + + if spec is None: + spec = {} + + if not isinstance(spec, dict): + raise TypeError("spec must be an instance of dict") + if not isinstance(skip, int): + raise TypeError("skip must be an instance of int") + if not isinstance(limit, int): + raise TypeError("limit must be an instance of int") + if not isinstance(timeout, bool): + raise TypeError("timeout must be an instance of bool") + if not isinstance(snapshot, bool): + raise TypeError("snapshot must be an instance of bool") + if not isinstance(tailable, bool): + raise TypeError("tailable must be an instance of bool") + if not isinstance(slave_okay, bool): + raise TypeError("slave_okay must be an instance of bool") + if not isinstance(await_data, bool): + raise TypeError("await_data must be an instance of bool") + if not isinstance(partial, bool): + raise TypeError("partial must be an instance of bool") + if not isinstance(exhaust, bool): + raise TypeError("exhaust must be an instance of bool") + + if fields is not None: + if not fields: + fields = {"_id": 1} + if not isinstance(fields, dict): + fields = helpers._fields_list_to_dict(fields) + + if as_class is None: + as_class = collection.database.connection.document_class + + self.__collection = collection + self.__spec = spec + self.__fields = fields + self.__skip = skip + self.__limit = limit + self.__batch_size = 0 + + # Exhaust cursor support + if self.__collection.database.connection.is_mongos and exhaust: + raise InvalidOperation('Exhaust cursors are ' + 'not supported by mongos') + if limit and exhaust: + raise InvalidOperation("Can't use limit and exhaust together.") + self.__exhaust = exhaust + self.__exhaust_mgr = None + + # This is ugly. People want to be able to do cursor[5:5] and + # get an empty result set (old behavior was an + # exception). It's hard to do that right, though, because the + # server uses limit(0) to mean 'no limit'. So we set __empty + # in that case and check for it when iterating. We also unset + # it anytime we change __limit. + self.__empty = False + + self.__snapshot = snapshot + self.__ordering = sort and helpers._index_document(sort) or None + self.__max_scan = max_scan + self.__explain = False + self.__hint = None + self.__as_class = as_class + self.__slave_okay = slave_okay + self.__manipulate = manipulate + self.__read_preference = read_preference + self.__tag_sets = tag_sets + self.__secondary_acceptable_latency_ms = secondary_acceptable_latency_ms + self.__tz_aware = collection.database.connection.tz_aware + self.__must_use_master = _must_use_master + self.__uuid_subtype = _uuid_subtype or collection.uuid_subtype + + self.__data = deque(_first_batch or []) + self.__connection_id = None + self.__retrieved = 0 + self.__killed = False + + self.__query_flags = 0 + if tailable: + self.__query_flags |= _QUERY_OPTIONS["tailable_cursor"] + if not timeout: + self.__query_flags |= _QUERY_OPTIONS["no_timeout"] + if tailable and await_data: + self.__query_flags |= _QUERY_OPTIONS["await_data"] + if exhaust: + self.__query_flags |= _QUERY_OPTIONS["exhaust"] + if partial: + self.__query_flags |= _QUERY_OPTIONS["partial"] + + # this is for passing network_timeout through if it's specified + # need to use kwargs as None is a legit value for network_timeout + self.__kwargs = kwargs + + @property + def collection(self): + """The :class:`~pymongo.collection.Collection` that this + :class:`Cursor` is iterating. + + .. 
versionadded:: 1.1 + """ + return self.__collection + + def __del__(self): + if self.__id and not self.__killed: + self.__die() + + def rewind(self): + """Rewind this cursor to its unevaluated state. + + Reset this cursor if it has been partially or completely evaluated. + Any options that are present on the cursor will remain in effect. + Future iterating performed on this cursor will cause new queries to + be sent to the server, even if the resultant data has already been + retrieved by this cursor. + """ + self.__check_not_command_cursor('rewind') + self.__data = deque() + self.__id = None + self.__connection_id = None + self.__retrieved = 0 + self.__killed = False + + return self + + def clone(self): + """Get a clone of this cursor. + + Returns a new Cursor instance with options matching those that have + been set on the current instance. The clone will be completely + unevaluated, even if the current instance has been partially or + completely evaluated. + """ + return self.__clone(True) + + def __clone(self, deepcopy=True): + self.__check_not_command_cursor('clone') + clone = Cursor(self.__collection) + values_to_clone = ("spec", "fields", "skip", "limit", + "snapshot", "ordering", "explain", "hint", + "batch_size", "max_scan", "as_class", "slave_okay", + "manipulate", "read_preference", "tag_sets", + "secondary_acceptable_latency_ms", + "must_use_master", "uuid_subtype", "query_flags", + "kwargs") + data = dict((k, v) for k, v in self.__dict__.items() + if k.startswith('_Cursor__') and k[9:] in values_to_clone) + if deepcopy: + data = self.__deepcopy(data) + clone.__dict__.update(data) + return clone + + def __die(self): + """Closes this cursor. + """ + if self.__id and not self.__killed: + if self.__exhaust and self.__exhaust_mgr: + # If this is an exhaust cursor and we haven't completely + # exhausted the result set we *must* close the socket + # to stop the server from sending more data. + self.__exhaust_mgr.sock.close() + else: + connection = self.__collection.database.connection + if self.__connection_id is not None: + connection.close_cursor(self.__id, self.__connection_id) + else: + connection.close_cursor(self.__id) + if self.__exhaust and self.__exhaust_mgr: + self.__exhaust_mgr.close() + self.__killed = True + + def close(self): + """Explicitly close / kill this cursor. Required for PyPy, Jython and + other Python implementations that don't use reference counting + garbage collection. + """ + self.__die() + + def __query_spec(self): + """Get the spec to use for a query. + """ + operators = {} + if self.__ordering: + operators["$orderby"] = self.__ordering + if self.__explain: + operators["$explain"] = True + if self.__hint: + operators["$hint"] = self.__hint + if self.__snapshot: + operators["$snapshot"] = True + if self.__max_scan: + operators["$maxScan"] = self.__max_scan + # Only set $readPreference if it's something other than + # PRIMARY to avoid problems with mongos versions that + # don't support read preferences. + if (self.__collection.database.connection.is_mongos and + self.__read_preference != ReadPreference.PRIMARY): + + has_tags = self.__tag_sets and self.__tag_sets != [{}] + + # For maximum backwards compatibility, don't set $readPreference + # for SECONDARY_PREFERRED unless tags are in use. Just rely on + # the slaveOkay bit (set automatically if read preference is not + # PRIMARY), which has the same behavior. 
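+            # For example, a SECONDARY read sent through mongos becomes
+            # {'$readPreference': {'mode': 'secondary'}} in the query spec,
+            # while the wire-protocol slaveOkay bit is set separately in
+            # __query_options().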
+            if (self.__read_preference != ReadPreference.SECONDARY_PREFERRED or
+                has_tags):
+
+                read_pref = {
+                    'mode': read_preferences.mongos_mode(self.__read_preference)
+                }
+                if has_tags:
+                    read_pref['tags'] = self.__tag_sets
+
+                operators['$readPreference'] = read_pref
+
+        if operators:
+            # Make a shallow copy so we can cleanly rewind or clone.
+            spec = self.__spec.copy()
+
+            # Only commands that can be run on secondaries should have any
+            # operators added to the spec. Command queries can be issued
+            # by db.command or calling find_one on $cmd directly
+            if self.collection.name == "$cmd":
+                # Don't change commands that can't be sent to secondaries
+                command_name = spec and list(spec.keys())[0].lower() or ""
+                if command_name not in secondary_ok_commands:
+                    return spec
+                elif command_name == 'mapreduce':
+                    # mapreduce shouldn't be changed if it's not inline
+                    out = spec.get('out')
+                    if not isinstance(out, dict) or not out.get('inline'):
+                        return spec
+
+                # White-listed commands must be wrapped in $query.
+                if "$query" not in spec:
+                    # $query has to come first
+                    spec = SON([("$query", spec)])
+
+            if not isinstance(spec, SON):
+                # Ensure the spec is SON. As order is important this will
+                # ensure it's set before merging in any extra operators.
+                spec = SON(spec)
+
+            spec.update(operators)
+            return spec
+        # Have to wrap with $query if "query" is the first key.
+        # We can't just use $query anytime "query" is a key as
+        # that breaks commands like count and find_and_modify.
+        # Checking spec.keys()[0] covers the case that the spec
+        # was passed as an instance of SON or OrderedDict.
+        elif ("query" in self.__spec and
+              (len(self.__spec) == 1 or
+               list(self.__spec.keys())[0] == "query")):
+            return SON({"$query": self.__spec})
+
+        return self.__spec
+
+    def __query_options(self):
+        """Get the query options string to use for this query.
+        """
+        options = self.__query_flags
+        if (self.__slave_okay
+            or self.__read_preference != ReadPreference.PRIMARY
+        ):
+            options |= _QUERY_OPTIONS["slave_okay"]
+        return options
+
+    def __check_okay_to_chain(self):
+        """Check if it is okay to chain more options onto this cursor.
+        """
+        if self.__retrieved or self.__id is not None:
+            raise InvalidOperation("cannot set options after executing query")
+
+    def __check_not_command_cursor(self, method_name):
+        """Check if calling a method on this cursor is valid.
+        """
+        if self.__is_command_cursor:
+            raise InvalidOperation(
+                "cannot call %s on a command cursor" % method_name)
+
+    def add_option(self, mask):
+        """Set arbitrary query flags using a bitmask.
+
+        To set the tailable flag:
+        cursor.add_option(2)
+        """
+        if not isinstance(mask, int):
+            raise TypeError("mask must be an int")
+        self.__check_okay_to_chain()
+
+        if mask & _QUERY_OPTIONS["slave_okay"]:
+            self.__slave_okay = True
+        if mask & _QUERY_OPTIONS["exhaust"]:
+            if self.__limit:
+                raise InvalidOperation("Can't use limit and exhaust together.")
+            if self.__collection.database.connection.is_mongos:
+                raise InvalidOperation('Exhaust cursors are '
+                                       'not supported by mongos')
+            self.__exhaust = True
+
+        self.__query_flags |= mask
+        return self
+
+    def remove_option(self, mask):
+        """Unset arbitrary query flags using a bitmask.
+
+        To unset the tailable flag:
+        cursor.remove_option(2)
+        """
+        if not isinstance(mask, int):
+            raise TypeError("mask must be an int")
+        self.__check_okay_to_chain()
+
+        if mask & _QUERY_OPTIONS["slave_okay"]:
+            self.__slave_okay = False
+        if mask & _QUERY_OPTIONS["exhaust"]:
+            self.__exhaust = False
+
+        self.__query_flags &= ~mask
+        return self
+
+    def limit(self, limit):
+        """Limits the number of results to be returned by this cursor.
+
+        Raises TypeError if limit is not an instance of int. Raises
+        InvalidOperation if this cursor has already been used. The
+        last `limit` applied to this cursor takes precedence. A limit
+        of ``0`` is equivalent to no limit.
+
+        :Parameters:
+          - `limit`: the number of results to return
+
+        .. mongodoc:: limit
+        """
+        if not isinstance(limit, int):
+            raise TypeError("limit must be an int")
+        if self.__exhaust:
+            raise InvalidOperation("Can't use limit and exhaust together.")
+        self.__check_okay_to_chain()
+
+        self.__empty = False
+        self.__limit = limit
+        return self
+
+    def batch_size(self, batch_size):
+        """Limits the number of documents returned in one batch. Each batch
+        requires a round trip to the server. It can be adjusted to optimize
+        performance and limit data transfer.
+
+        .. note:: batch_size cannot override MongoDB's internal limits on the
+           amount of data it will return to the client in a single batch (i.e.
+           if you set batch size to 1,000,000,000, MongoDB will currently only
+           return 4-16MB of results per batch).
+
+        Raises :class:`TypeError` if `batch_size` is not an instance
+        of :class:`int`. Raises :class:`ValueError` if `batch_size` is
+        less than ``0``. Raises
+        :class:`~pymongo.errors.InvalidOperation` if this
+        :class:`Cursor` has already been used. The last `batch_size`
+        applied to this cursor takes precedence.
+
+        :Parameters:
+          - `batch_size`: The size of each batch of results requested.
+
+        .. versionadded:: 1.9
+        """
+        if not isinstance(batch_size, int):
+            raise TypeError("batch_size must be an int")
+        if batch_size < 0:
+            raise ValueError("batch_size must be >= 0")
+        self.__check_okay_to_chain()
+
+        # A batch size of 1 is treated as 2, since a wire-protocol value of
+        # 1 would tell the server to close the cursor after a single batch.
+        self.__batch_size = batch_size == 1 and 2 or batch_size
+        return self
+
+    def skip(self, skip):
+        """Skips the first `skip` results of this cursor.
+
+        Raises TypeError if skip is not an instance of int. Raises
+        InvalidOperation if this cursor has already been used. The last `skip`
+        applied to this cursor takes precedence.
+
+        :Parameters:
+          - `skip`: the number of results to skip
+        """
+        if not isinstance(skip, int):
+            raise TypeError("skip must be an int")
+        self.__check_okay_to_chain()
+
+        self.__skip = skip
+        return self
+
+    def __getitem__(self, index):
+        """Get a single document or a slice of documents from this cursor.
+
+        Raises :class:`~pymongo.errors.InvalidOperation` if this
+        cursor has already been used.
+
+        To get a single document use an integral index, e.g.::
+
+          >>> db.test.find()[50]
+
+        An :class:`IndexError` will be raised if the index is negative
+        or greater than the number of documents in this cursor. Any
+        limit previously applied to this cursor will be ignored.
+
+        To get a slice of documents use a slice index, e.g.::
+
+          >>> db.test.find()[20:25]
+
+        This will return this cursor with a limit of ``5`` and skip of
+        ``20`` applied. Using a slice index will override any prior
+        limits or skips applied to this cursor (including those
+        applied through previous calls to this method). Raises
+        :class:`IndexError` when the slice has a step, a negative
+        start value, or a stop value less than or equal to the start
+        value.
+
+        :Parameters:
+          - `index`: An integer or slice index to be applied to this cursor
+        """
+        self.__check_okay_to_chain()
+        self.__empty = False
+        if isinstance(index, slice):
+            if index.step is not None:
+                raise IndexError("Cursor instances do not support slice steps")
+
+            skip = 0
+            if index.start is not None:
+                if index.start < 0:
+                    raise IndexError("Cursor instances do not support "
+                                     "negative indices")
+                skip = index.start
+
+            if index.stop is not None:
+                limit = index.stop - skip
+                if limit < 0:
+                    raise IndexError("stop index must be greater than start "
+                                     "index for slice %r" % index)
+                if limit == 0:
+                    self.__empty = True
+            else:
+                limit = 0
+
+            self.__skip = skip
+            self.__limit = limit
+            return self
+
+        if isinstance(index, int):
+            if index < 0:
+                raise IndexError("Cursor instances do not support "
+                                 "negative indices")
+            clone = self.clone()
+            clone.skip(index + self.__skip)
+            clone.limit(-1)  # use a hard limit
+            for doc in clone:
+                return doc
+            raise IndexError("no such item for Cursor instance")
+        raise TypeError("index %r cannot be applied to Cursor "
+                        "instances" % index)
+
+    def max_scan(self, max_scan):
+        """Limit the number of documents to scan when performing the query.
+
+        Raises :class:`~pymongo.errors.InvalidOperation` if this
+        cursor has already been used. Only the last :meth:`max_scan`
+        applied to this cursor has any effect.
+
+        :Parameters:
+          - `max_scan`: the maximum number of documents to scan
+
+        .. note:: Requires server version **>= 1.5.1**
+
+        .. versionadded:: 1.7
+        """
+        self.__check_okay_to_chain()
+        self.__max_scan = max_scan
+        return self
+
+    def sort(self, key_or_list, direction=None):
+        """Sorts this cursor's results.
+
+        Takes either a single key and a direction, or a list of (key,
+        direction) pairs. The key(s) must be an instance of ``(str,
+        unicode)``, and the direction(s) must be one of
+        (:data:`~pymongo.ASCENDING`,
+        :data:`~pymongo.DESCENDING`). Raises
+        :class:`~pymongo.errors.InvalidOperation` if this cursor has
+        already been used. Only the last :meth:`sort` applied to this
+        cursor has any effect.
+
+        :Parameters:
+          - `key_or_list`: a single key or a list of (key, direction)
+            pairs specifying the keys to sort on
+          - `direction` (optional): only used if `key_or_list` is a single
+            key, if not given :data:`~pymongo.ASCENDING` is assumed
+        """
+        self.__check_okay_to_chain()
+        keys = helpers._index_list(key_or_list, direction)
+        self.__ordering = helpers._index_document(keys)
+        return self
+
+    def count(self, with_limit_and_skip=False):
+        """Get the size of the results set for this query.
+
+        Returns the number of documents in the results set for this query. Does
+        not take :meth:`limit` and :meth:`skip` into account by default - set
+        `with_limit_and_skip` to ``True`` if that is the desired behavior.
+        Raises :class:`~pymongo.errors.OperationFailure` on a database error.
+
+        With :class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient`
+        or :class:`~pymongo.master_slave_connection.MasterSlaveConnection`,
+        if `read_preference` is not
+        :attr:`pymongo.read_preferences.ReadPreference.PRIMARY` or
+        :attr:`pymongo.read_preferences.ReadPreference.PRIMARY_PREFERRED`, or
+        (deprecated) `slave_okay` is `True`, the count command will be sent to
+        a secondary or slave.
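+
+        For example (an illustrative sketch; ``db`` is assumed to be a
+        :class:`Database` instance)::
+
+          total = db.test.find({"x": {"$gt": 5}}).count()
+          paged = db.test.find({"x": {"$gt": 5}}).skip(20).limit(10)
+          on_page = paged.count(with_limit_and_skip=True)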
+ + :Parameters: + - `with_limit_and_skip` (optional): take any :meth:`limit` or + :meth:`skip` that has been applied to this cursor into account when + getting the count + + .. note:: The `with_limit_and_skip` parameter requires server + version **>= 1.1.4-** + + .. note:: ``count`` ignores ``network_timeout``. For example, the + timeout is ignored in the following code:: + + collection.find({}, network_timeout=1).count() + + .. versionadded:: 1.1.1 + The `with_limit_and_skip` parameter. + :meth:`~pymongo.cursor.Cursor.__len__` was deprecated in favor of + calling :meth:`count` with `with_limit_and_skip` set to ``True``. + """ + self.__check_not_command_cursor('count') + command = {"query": self.__spec, "fields": self.__fields} + + command['read_preference'] = self.__read_preference + command['tag_sets'] = self.__tag_sets + command['secondary_acceptable_latency_ms'] = ( + self.__secondary_acceptable_latency_ms) + command['slave_okay'] = self.__slave_okay + use_master = not self.__slave_okay and not self.__read_preference + command['_use_master'] = use_master + + if with_limit_and_skip: + if self.__limit: + command["limit"] = self.__limit + if self.__skip: + command["skip"] = self.__skip + + database = self.__collection.database + r = database.command("count", self.__collection.name, + allowable_errors=["ns missing"], + uuid_subtype=self.__uuid_subtype, + **command) + if r.get("errmsg", "") == "ns missing": + return 0 + return int(r["n"]) + + def distinct(self, key): + """Get a list of distinct values for `key` among all documents + in the result set of this query. + + Raises :class:`TypeError` if `key` is not an instance of + :class:`basestring` (:class:`str` in python 3). + + With :class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient` + or :class:`~pymongo.master_slave_connection.MasterSlaveConnection`, + if `read_preference` is + not :attr:`pymongo.read_preferences.ReadPreference.PRIMARY` or + (deprecated) `slave_okay` is `True` the distinct command will be sent + to a secondary or slave. + + :Parameters: + - `key`: name of key for which we want to get the distinct values + + .. note:: Requires server version **>= 1.1.3+** + + .. seealso:: :meth:`pymongo.collection.Collection.distinct` + + .. versionadded:: 1.2 + """ + self.__check_not_command_cursor('distinct') + if not isinstance(key, str): + raise TypeError("key must be an instance " + "of %s" % (str.__name__,)) + + options = {"key": key} + if self.__spec: + options["query"] = self.__spec + + options['read_preference'] = self.__read_preference + options['tag_sets'] = self.__tag_sets + options['secondary_acceptable_latency_ms'] = ( + self.__secondary_acceptable_latency_ms) + options['slave_okay'] = self.__slave_okay + use_master = not self.__slave_okay and not self.__read_preference + options['_use_master'] = use_master + + database = self.__collection.database + return database.command("distinct", + self.__collection.name, + uuid_subtype=self.__uuid_subtype, + **options)["values"] + + def explain(self): + """Returns an explain plan record for this cursor. + + .. mongodoc:: explain + """ + self.__check_not_command_cursor('explain') + c = self.clone() + c.__explain = True + + # always use a hard limit for explains + if c.__limit: + c.__limit = -abs(c.__limit) + return next(c) + + def hint(self, index): + """Adds a 'hint', telling Mongo the proper index to use for the query. + + Judicious use of hints can greatly improve query + performance. 
When doing a query on multiple fields (at least + one of which is indexed) pass the indexed field as a hint to + the query. Hinting will not do anything if the corresponding + index does not exist. Raises + :class:`~pymongo.errors.InvalidOperation` if this cursor has + already been used. + + `index` should be an index as passed to + :meth:`~pymongo.collection.Collection.create_index` + (e.g. ``[('field', ASCENDING)]``). If `index` + is ``None`` any existing hints for this query are cleared. The + last hint applied to this cursor takes precedence over all + others. + + :Parameters: + - `index`: index to hint on (as an index specifier) + """ + self.__check_okay_to_chain() + if index is None: + self.__hint = None + return self + + self.__hint = helpers._index_document(index) + return self + + def where(self, code): + """Adds a $where clause to this query. + + The `code` argument must be an instance of :class:`basestring` + (:class:`str` in python 3) or :class:`~bson.code.Code` + containing a JavaScript expression. This expression will be + evaluated for each document scanned. Only those documents + for which the expression evaluates to *true* will be returned + as results. The keyword *this* refers to the object currently + being scanned. + + Raises :class:`TypeError` if `code` is not an instance of + :class:`basestring` (:class:`str` in python 3). Raises + :class:`~pymongo.errors.InvalidOperation` if this + :class:`Cursor` has already been used. Only the last call to + :meth:`where` applied to a :class:`Cursor` has any effect. + + :Parameters: + - `code`: JavaScript expression to use as a filter + """ + self.__check_okay_to_chain() + if not isinstance(code, Code): + code = Code(code) + + self.__spec["$where"] = code + return self + + def __send_message(self, message): + """Send a query or getmore message and handles the response. + + If message is ``None`` this is an exhaust cursor, which reads + the next result batch off the exhaust socket instead of + sending getMore messages to the server. + """ + client = self.__collection.database.connection + + if message: + kwargs = {"_must_use_master": self.__must_use_master} + kwargs["read_preference"] = self.__read_preference + kwargs["tag_sets"] = self.__tag_sets + kwargs["secondary_acceptable_latency_ms"] = ( + self.__secondary_acceptable_latency_ms) + kwargs['exhaust'] = self.__exhaust + if self.__connection_id is not None: + kwargs["_connection_to_use"] = self.__connection_id + kwargs.update(self.__kwargs) + + try: + res = client._send_message_with_response(message, **kwargs) + self.__connection_id, (response, sock, pool) = res + if self.__exhaust: + self.__exhaust_mgr = _SocketManager(sock, pool) + except AutoReconnect: + # Don't try to send kill cursors on another socket + # or to another server. It can cause a _pinValue + # assertion on some server releases if we get here + # due to a socket timeout. + self.__killed = True + raise + else: # exhaust cursor - no getMore message + response = client._exhaust_next(self.__exhaust_mgr.sock) + + try: + response = helpers._unpack_response(response, self.__id, + self.__as_class, + self.__tz_aware, + self.__uuid_subtype) + except AutoReconnect: + # Don't send kill cursors to another server after a "not master" + # error. It's completely pointless. 
+ self.__killed = True + client.disconnect() + raise + self.__id = response["cursor_id"] + + # starting from doesn't get set on getmore's for tailable cursors + if not (self.__query_flags & _QUERY_OPTIONS["tailable_cursor"]): + assert response["starting_from"] == self.__retrieved, ( + "Result batch started from %s, expected %s" % ( + response['starting_from'], self.__retrieved)) + + self.__retrieved += response["number_returned"] + self.__data = deque(response["data"]) + + if self.__limit and self.__id and self.__limit <= self.__retrieved: + self.__die() + + # Don't wait for garbage collection to call __del__, return the + # socket to the pool now. + if self.__exhaust and self.__id == 0: + self.__exhaust_mgr.close() + + def _refresh(self): + """Refreshes the cursor with more data from Mongo. + + Returns the length of self.__data after refresh. Will exit early if + self.__data is already non-empty. Raises OperationFailure when the + cursor cannot be refreshed due to an error on the query. + """ + if len(self.__data) or self.__killed: + return len(self.__data) + + if self.__id is None: # Query + ntoreturn = self.__batch_size + if self.__limit: + if self.__batch_size: + ntoreturn = min(self.__limit, self.__batch_size) + else: + ntoreturn = self.__limit + self.__send_message( + message.query(self.__query_options(), + self.__collection.full_name, + self.__skip, ntoreturn, + self.__query_spec(), self.__fields, + self.__uuid_subtype)) + if not self.__id: + self.__killed = True + elif self.__id: # Get More + if self.__limit: + limit = self.__limit - self.__retrieved + if self.__batch_size: + limit = min(limit, self.__batch_size) + else: + limit = self.__batch_size + + # Exhaust cursors don't send getMore messages. + if self.__exhaust: + self.__send_message(None) + else: + self.__send_message( + message.get_more(self.__collection.full_name, + limit, self.__id)) + + else: # Cursor id is zero nothing else to return + self.__killed = True + + return len(self.__data) + + @property + def alive(self): + """Does this cursor have the potential to return more data? + + This is mostly useful with `tailable cursors + `_ + since they will stop iterating even though they *may* return more + results in the future. + + .. versionadded:: 1.5 + """ + return bool(len(self.__data) or (not self.__killed)) + + @property + def cursor_id(self): + """Returns the id of the cursor + + Useful if you need to manage cursor ids and want to handle killing + cursors manually using + :meth:`~pymongo.mongo_client.MongoClient.kill_cursors` + + .. versionadded:: 2.2 + """ + return self.__id + + def __iter__(self): + return self + + def __next__(self): + if self.__empty: + raise StopIteration + db = self.__collection.database + if len(self.__data) or self._refresh(): + if self.__manipulate: + return db._fix_outgoing(self.__data.popleft(), + self.__collection) + else: + return self.__data.popleft() + else: + raise StopIteration + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.__die() + + def __copy__(self): + """Support function for `copy.copy()`. + + .. versionadded:: 2.4 + """ + return self.__clone(deepcopy=False) + + def __deepcopy__(self, memo): + """Support function for `copy.deepcopy()`. + + .. versionadded:: 2.4 + """ + return self.__clone(deepcopy=True) + + def __deepcopy(self, x, memo=None): + """Deepcopy helper for the data dictionary or list. + + Regular expressions cannot be deep copied but as they are immutable we + don't have to copy them when cloning. 
+ """ + if not hasattr(x, 'items'): + y, is_list, iterator = [], True, enumerate(x) + else: + y, is_list, iterator = {}, False, iter(x.items()) + + if memo is None: + memo = {} + val_id = id(x) + if val_id in memo: + return memo.get(val_id) + memo[val_id] = y + + for key, value in iterator: + if isinstance(value, (dict, list)) and not isinstance(value, SON): + value = self.__deepcopy(value, memo) + elif not isinstance(value, RE_TYPE): + value = copy.deepcopy(value, memo) + + if is_list: + y.append(value) + else: + if not isinstance(key, RE_TYPE): + key = copy.deepcopy(key, memo) + y[key] = value + return y diff --git a/asyncio_mongo/_pymongo/cursor_manager.py b/asyncio_mongo/_pymongo/cursor_manager.py new file mode 100644 index 0000000..39432ce --- /dev/null +++ b/asyncio_mongo/_pymongo/cursor_manager.py @@ -0,0 +1,93 @@ +# Copyright 2009-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DEPRECATED - Different managers to handle when cursors are killed after +they are closed. + +New cursor managers should be defined as subclasses of CursorManager and can be +installed on a connection by calling +`pymongo.connection.Connection.set_cursor_manager`. + +.. versionchanged:: 2.1+ + Deprecated. +""" + +import weakref + + +class CursorManager(object): + """The default cursor manager. + + This manager will kill cursors one at a time as they are closed. + """ + + def __init__(self, connection): + """Instantiate the manager. + + :Parameters: + - `connection`: a Mongo Connection + """ + self.__connection = weakref.ref(connection) + + def close(self, cursor_id): + """Close a cursor by killing it immediately. + + Raises TypeError if cursor_id is not an instance of (int, long). + + :Parameters: + - `cursor_id`: cursor id to close + """ + if not isinstance(cursor_id, int): + raise TypeError("cursor_id must be an instance of (int, long)") + + self.__connection().kill_cursors([cursor_id]) + + +class BatchCursorManager(CursorManager): + """A cursor manager that kills cursors in batches. + """ + + def __init__(self, connection): + """Instantiate the manager. + + :Parameters: + - `connection`: a Mongo Connection + """ + self.__dying_cursors = [] + self.__max_dying_cursors = 20 + self.__connection = weakref.ref(connection) + + CursorManager.__init__(self, connection) + + def __del__(self): + """Cleanup - be sure to kill any outstanding cursors. + """ + self.__connection().kill_cursors(self.__dying_cursors) + + def close(self, cursor_id): + """Close a cursor by killing it in a batch. + + Raises TypeError if cursor_id is not an instance of (int, long). 
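+
+        Illustrative usage (assumes an existing ``connection`` exposing
+        ``set_cursor_manager``, as referenced in the module docstring)::
+
+          connection.set_cursor_manager(BatchCursorManager)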
+ + :Parameters: + - `cursor_id`: cursor id to close + """ + if not isinstance(cursor_id, int): + raise TypeError("cursor_id must be an instance of (int, long)") + + self.__dying_cursors.append(cursor_id) + + if len(self.__dying_cursors) > self.__max_dying_cursors: + self.__connection().kill_cursors(self.__dying_cursors) + self.__dying_cursors = [] diff --git a/asyncio_mongo/_pymongo/database.py b/asyncio_mongo/_pymongo/database.py new file mode 100644 index 0000000..db9f4ff --- /dev/null +++ b/asyncio_mongo/_pymongo/database.py @@ -0,0 +1,875 @@ +# Copyright 2009-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Database level operations.""" + +from asyncio_mongo._bson.binary import OLD_UUID_SUBTYPE +from asyncio_mongo._bson.code import Code +from asyncio_mongo._bson.dbref import DBRef +from asyncio_mongo._bson.son import SON +from asyncio_mongo._pymongo import auth, common, helpers +from asyncio_mongo._pymongo.collection import Collection +from asyncio_mongo._pymongo.errors import (CollectionInvalid, + InvalidName, + OperationFailure) +from asyncio_mongo._pymongo.son_manipulator import ObjectIdInjector +from asyncio_mongo._pymongo import read_preferences as rp + + +def _check_name(name): + """Check if a database name is valid. + """ + if not name: + raise InvalidName("database name cannot be the empty string") + + for invalid_char in [" ", ".", "$", "/", "\\", "\x00"]: + if invalid_char in name: + raise InvalidName("database names cannot contain the " + "character %r" % invalid_char) + + +class Database(common.BaseObject): + """A Mongo database. + """ + + def __init__(self, connection, name): + """Get a database by connection and name. + + Raises :class:`TypeError` if `name` is not an instance of + :class:`basestring` (:class:`str` in python 3). Raises + :class:`~pymongo.errors.InvalidName` if `name` is not a valid + database name. + + :Parameters: + - `connection`: a client instance + - `name`: database name + + .. mongodoc:: databases + """ + super(Database, + self).__init__(slave_okay=connection.slave_okay, + read_preference=connection.read_preference, + tag_sets=connection.tag_sets, + secondary_acceptable_latency_ms=( + connection.secondary_acceptable_latency_ms), + safe=connection.safe, + **connection.write_concern) + + if not isinstance(name, str): + raise TypeError("name must be an instance " + "of %s" % (str.__name__,)) + + if name != '$external': + _check_name(name) + + self.__name = str(name) + self.__connection = connection + + self.__incoming_manipulators = [] + self.__incoming_copying_manipulators = [] + self.__outgoing_manipulators = [] + self.__outgoing_copying_manipulators = [] + self.add_son_manipulator(ObjectIdInjector()) + + def add_son_manipulator(self, manipulator): + """Add a new son manipulator to this database. + + Newly added manipulators will be applied before existing ones. 
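+
+        For example (illustrative; assumes ``NamespaceInjector`` is
+        available from ``son_manipulator`` alongside the
+        ``ObjectIdInjector`` imported above)::
+
+          from asyncio_mongo._pymongo.son_manipulator import NamespaceInjector
+          db.add_son_manipulator(NamespaceInjector())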
+ + :Parameters: + - `manipulator`: the manipulator to add + """ + def method_overwritten(instance, method): + return getattr(instance, method) != \ + getattr(super(instance.__class__, instance), method) + + if manipulator.will_copy(): + if method_overwritten(manipulator, "transform_incoming"): + self.__incoming_copying_manipulators.insert(0, manipulator) + if method_overwritten(manipulator, "transform_outgoing"): + self.__outgoing_copying_manipulators.insert(0, manipulator) + else: + if method_overwritten(manipulator, "transform_incoming"): + self.__incoming_manipulators.insert(0, manipulator) + if method_overwritten(manipulator, "transform_outgoing"): + self.__outgoing_manipulators.insert(0, manipulator) + + @property + def system_js(self): + """A :class:`SystemJS` helper for this :class:`Database`. + + See the documentation for :class:`SystemJS` for more details. + + .. versionadded:: 1.5 + """ + return SystemJS(self) + + @property + def connection(self): + """The client instance for this :class:`Database`. + + .. versionchanged:: 1.3 + ``connection`` is now a property rather than a method. + """ + return self.__connection + + @property + def name(self): + """The name of this :class:`Database`. + + .. versionchanged:: 1.3 + ``name`` is now a property rather than a method. + """ + return self.__name + + @property + def incoming_manipulators(self): + """List all incoming SON manipulators + installed on this instance. + + .. versionadded:: 2.0 + """ + return [manipulator.__class__.__name__ + for manipulator in self.__incoming_manipulators] + + @property + def incoming_copying_manipulators(self): + """List all incoming SON copying manipulators + installed on this instance. + + .. versionadded:: 2.0 + """ + return [manipulator.__class__.__name__ + for manipulator in self.__incoming_copying_manipulators] + + @property + def outgoing_manipulators(self): + """List all outgoing SON manipulators + installed on this instance. + + .. versionadded:: 2.0 + """ + return [manipulator.__class__.__name__ + for manipulator in self.__outgoing_manipulators] + + @property + def outgoing_copying_manipulators(self): + """List all outgoing SON copying manipulators + installed on this instance. + + .. versionadded:: 2.0 + """ + return [manipulator.__class__.__name__ + for manipulator in self.__outgoing_copying_manipulators] + + def __eq__(self, other): + if isinstance(other, Database): + us = (self.__connection, self.__name) + them = (other.__connection, other.__name) + return us == them + return NotImplemented + + def __ne__(self, other): + return not self == other + + def __repr__(self): + return "Database(%r, %r)" % (self.__connection, self.__name) + + def __getattr__(self, name): + """Get a collection of this database by name. + + Raises InvalidName if an invalid collection name is used. + + :Parameters: + - `name`: the name of the collection to get + """ + return Collection(self, name) + + def __getitem__(self, name): + """Get a collection of this database by name. + + Raises InvalidName if an invalid collection name is used. + + :Parameters: + - `name`: the name of the collection to get + """ + return self.__getattr__(name) + + def create_collection(self, name, **kwargs): + """Create a new :class:`~pymongo.collection.Collection` in this + database. + + Normally collection creation is automatic. This method should + only be used to specify options on + creation. :class:`~pymongo.errors.CollectionInvalid` will be + raised if the collection already exists. 
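+
+        For example, creating a capped collection with the options
+        described below (illustrative)::
+
+          db.create_collection("log", capped=True, size=100000)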
+ + Options should be passed as keyword arguments to this + method. Any of the following options are valid: + + - "size": desired initial size for the collection (in + bytes). For capped collections this size is the max + size of the collection. + - "capped": if True, this is a capped collection + - "max": maximum number of objects if capped (optional) + + :Parameters: + - `name`: the name of the collection to create + - `**kwargs` (optional): additional keyword arguments will + be passed as options for the create collection command + + .. versionchanged:: 2.2 + Removed deprecated argument: options + + .. versionchanged:: 1.5 + deprecating `options` in favor of kwargs + """ + opts = {"create": True} + opts.update(kwargs) + + if name in self.collection_names(): + raise CollectionInvalid("collection %s already exists" % name) + + return Collection(self, name, **opts) + + def _fix_incoming(self, son, collection): + """Apply manipulators to an incoming SON object before it gets stored. + + :Parameters: + - `son`: the son object going into the database + - `collection`: the collection the son object is being saved in + """ + for manipulator in self.__incoming_manipulators: + son = manipulator.transform_incoming(son, collection) + for manipulator in self.__incoming_copying_manipulators: + son = manipulator.transform_incoming(son, collection) + return son + + def _fix_outgoing(self, son, collection): + """Apply manipulators to a SON object as it comes out of the database. + + :Parameters: + - `son`: the son object coming out of the database + - `collection`: the collection the son object was saved in + """ + for manipulator in reversed(self.__outgoing_manipulators): + son = manipulator.transform_outgoing(son, collection) + for manipulator in reversed(self.__outgoing_copying_manipulators): + son = manipulator.transform_outgoing(son, collection) + return son + + def command(self, command, value=1, + check=True, allowable_errors=[], + uuid_subtype=OLD_UUID_SUBTYPE, **kwargs): + """Issue a MongoDB command. + + Send command `command` to the database and return the + response. If `command` is an instance of :class:`basestring` + (:class:`str` in python 3) then the command {`command`: `value`} + will be sent. Otherwise, `command` must be an instance of + :class:`dict` and will be sent as is. + + Any additional keyword arguments will be added to the final + command document before it is sent. + + For example, a command like ``{buildinfo: 1}`` can be sent + using: + + >>> db.command("buildinfo") + + For a command where the value matters, like ``{collstats: + collection_name}`` we can do: + + >>> db.command("collstats", collection_name) + + For commands that take additional arguments we can use + kwargs. So ``{filemd5: object_id, root: file_root}`` becomes: + + >>> db.command("filemd5", object_id, root=file_root) + + :Parameters: + - `command`: document representing the command to be issued, + or the name of the command (for simple commands only). + + .. note:: the order of keys in the `command` document is + significant (the "verb" must come first), so commands + which require multiple keys (e.g. `findandmodify`) + should use an instance of :class:`~bson.son.SON` or + a string and kwargs instead of a Python `dict`. 
+ + - `value` (optional): value to use for the command verb when + `command` is passed as a string + - `check` (optional): check the response for errors, raising + :class:`~pymongo.errors.OperationFailure` if there are any + - `allowable_errors`: if `check` is ``True``, error messages + in this list will be ignored by error-checking + - `uuid_subtype` (optional): The BSON binary subtype to use + for a UUID used in this command. + - `read_preference`: The read preference for this connection. + See :class:`~pymongo.read_preferences.ReadPreference` for available + options. + - `tag_sets`: Read from replica-set members with these tags. + To specify a priority-order for tag sets, provide a list of + tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag + set, ``{}``, means "read from any member that matches the mode, + ignoring tags." ReplicaSetConnection tries each set of tags in turn + until it finds a set of tags with at least one matching member. + - `secondary_acceptable_latency_ms`: Any replica-set member whose + ping time is within secondary_acceptable_latency_ms of the nearest + member may accept reads. Default 15 milliseconds. + **Ignored by mongos** and must be configured on the command line. + See the localThreshold_ option for more information. + - `**kwargs` (optional): additional keyword arguments will + be added to the command document before it is sent + + .. note:: ``command`` ignores the ``network_timeout`` parameter. + + .. versionchanged:: 2.3 + Added `tag_sets` and `secondary_acceptable_latency_ms` options. + .. versionchanged:: 2.2 + Added support for `as_class` - the class you want to use for + the resulting documents + .. versionchanged:: 1.6 + Added the `value` argument for string commands, and keyword + arguments for additional command options. + .. versionchanged:: 1.5 + `command` can be a string in addition to a full document. + .. versionadded:: 1.4 + + .. mongodoc:: commands + .. 
_localThreshold: http://docs.mongodb.org/manual/reference/mongos/#cmdoption-mongos--localThreshold + """ + + if isinstance(command, str): + command = SON([(command, value)]) + + command_name = list(command.keys())[0].lower() + must_use_master = kwargs.pop('_use_master', False) + if command_name not in rp.secondary_ok_commands: + must_use_master = True + + # Special-case: mapreduce can go to secondaries only if inline + if command_name == 'mapreduce': + out = command.get('out') or kwargs.get('out') + if not isinstance(out, dict) or not out.get('inline'): + must_use_master = True + + extra_opts = { + 'as_class': kwargs.pop('as_class', None), + 'slave_okay': kwargs.pop('slave_okay', self.slave_okay), + '_must_use_master': must_use_master, + '_uuid_subtype': uuid_subtype + } + + extra_opts['read_preference'] = kwargs.pop( + 'read_preference', + self.read_preference) + extra_opts['tag_sets'] = kwargs.pop( + 'tag_sets', + self.tag_sets) + extra_opts['secondary_acceptable_latency_ms'] = kwargs.pop( + 'secondary_acceptable_latency_ms', + self.secondary_acceptable_latency_ms) + + fields = kwargs.get('fields') + if fields is not None and not isinstance(fields, dict): + kwargs['fields'] = helpers._fields_list_to_dict(fields) + + command.update(kwargs) + + result = self["$cmd"].find_one(command, **extra_opts) + + if check: + msg = "command %s failed: %%s" % repr(command).replace("%", "%%") + helpers._check_command_response(result, self.connection.disconnect, + msg, allowable_errors) + + return result + + def collection_names(self, include_system_collections=True): + """Get a list of all the collection names in this database. + + :Parameters: + - `include_system_collections` (optional): if ``False`` list + will not include system collections (e.g ``system.indexes``) + """ + results = self["system.namespaces"].find(_must_use_master=True) + names = [r["name"] for r in results] + names = [n[len(self.__name) + 1:] for n in names + if n.startswith(self.__name + ".") and "$" not in n] + if not include_system_collections: + names = [n for n in names if not n.startswith("system.")] + return names + + def drop_collection(self, name_or_collection): + """Drop a collection. + + :Parameters: + - `name_or_collection`: the name of a collection to drop or the + collection object itself + """ + name = name_or_collection + if isinstance(name, Collection): + name = name.name + + if not isinstance(name, str): + raise TypeError("name_or_collection must be an instance of " + "%s or Collection" % (str.__name__,)) + + self.__connection._purge_index(self.__name, name) + + self.command("drop", str(name), allowable_errors=["ns not found"]) + + def validate_collection(self, name_or_collection, + scandata=False, full=False): + """Validate a collection. + + Returns a dict of validation info. Raises CollectionInvalid if + validation fails. + + With MongoDB < 1.9 the result dict will include a `result` key + with a string value that represents the validation results. With + MongoDB >= 1.9 the `result` key no longer exists and the results + are split into individual fields in the result dict. + + :Parameters: + - `name_or_collection`: A Collection object or the name of a + collection to validate. + - `scandata`: Do extra checks beyond checking the overall + structure of the collection. + - `full`: Have the server do a more thorough scan of the + collection. Use with `scandata` for a thorough scan + of the structure of the collection and the individual + documents. Ignored in MongoDB versions before 1.9. + + .. 
versionchanged:: 1.11 + validate_collection previously returned a string. + .. versionadded:: 1.11 + Added `scandata` and `full` options. + """ + name = name_or_collection + if isinstance(name, Collection): + name = name.name + + if not isinstance(name, str): + raise TypeError("name_or_collection must be an instance of " + "%s or Collection" % (str.__name__,)) + + result = self.command("validate", str(name), + scandata=scandata, full=full) + + valid = True + # Pre 1.9 results + if "result" in result: + info = result["result"] + if info.find("exception") != -1 or info.find("corrupt") != -1: + raise CollectionInvalid("%s invalid: %s" % (name, info)) + # Sharded results + elif "raw" in result: + for _, res in result["raw"].items(): + if "result" in res: + info = res["result"] + if (info.find("exception") != -1 or + info.find("corrupt") != -1): + raise CollectionInvalid("%s invalid: " + "%s" % (name, info)) + elif not res.get("valid", False): + valid = False + break + # Post 1.9 non-sharded results. + elif not result.get("valid", False): + valid = False + + if not valid: + raise CollectionInvalid("%s invalid: %r" % (name, result)) + + return result + + def current_op(self, include_all=False): + """Get information on operations currently running. + + :Parameters: + - `include_all` (optional): if ``True`` also list currently + idle operations in the result + """ + if include_all: + return self['$cmd.sys.inprog'].find_one({"$all": True}) + else: + return self['$cmd.sys.inprog'].find_one() + + def profiling_level(self): + """Get the database's current profiling level. + + Returns one of (:data:`~pymongo.OFF`, + :data:`~pymongo.SLOW_ONLY`, :data:`~pymongo.ALL`). + + .. mongodoc:: profiling + """ + result = self.command("profile", -1) + + assert result["was"] >= 0 and result["was"] <= 2 + return result["was"] + + def set_profiling_level(self, level, slow_ms=None): + """Set the database's profiling level. + + :Parameters: + - `level`: Specifies a profiling level, see list of possible values + below. + - `slow_ms`: Optionally modify the threshold for the profile to + consider a query or operation. Even if the profiler is off queries + slower than the `slow_ms` level will get written to the logs. + + Possible `level` values: + + +----------------------------+------------------------------------+ + | Level | Setting | + +============================+====================================+ + | :data:`~pymongo.OFF` | Off. No profiling. | + +----------------------------+------------------------------------+ + | :data:`~pymongo.SLOW_ONLY` | On. Only includes slow operations. | + +----------------------------+------------------------------------+ + | :data:`~pymongo.ALL` | On. Includes all operations. | + +----------------------------+------------------------------------+ + + Raises :class:`ValueError` if level is not one of + (:data:`~pymongo.OFF`, :data:`~pymongo.SLOW_ONLY`, + :data:`~pymongo.ALL`). + + .. mongodoc:: profiling + """ + if not isinstance(level, int) or level < 0 or level > 2: + raise ValueError("level must be one of (OFF, SLOW_ONLY, ALL)") + + if slow_ms is not None and not isinstance(slow_ms, int): + raise TypeError("slow_ms must be an integer") + + if slow_ms is not None: + self.command("profile", level, slowms=slow_ms) + else: + self.command("profile", level) + + def profiling_info(self): + """Returns a list containing current profiling information. + + .. 
mongodoc:: profiling
+        """
+        return list(self["system.profile"].find())
+
+    def error(self):
+        """Get a database error if one occurred on the last operation.
+
+        Return None if the last operation was error-free. Otherwise return the
+        error that occurred.
+        """
+        error = self.command("getlasterror")
+        error_msg = error.get("err", "")
+        if error_msg is None:
+            return None
+        if error_msg.startswith("not master"):
+            self.__connection.disconnect()
+        return error
+
+    def last_status(self):
+        """Get status information from the last operation.
+
+        Returns a SON object with status information.
+        """
+        return self.command("getlasterror")
+
+    def previous_error(self):
+        """Get the most recent error to have occurred on this database.
+
+        Only returns errors that have occurred since the last call to
+        `Database.reset_error_history`. Returns None if no such errors have
+        occurred.
+        """
+        error = self.command("getpreverror")
+        if error.get("err", 0) is None:
+            return None
+        return error
+
+    def reset_error_history(self):
+        """Reset the error history of this database.
+
+        Calls to `Database.previous_error` will only return errors that have
+        occurred since the most recent call to this method.
+        """
+        self.command("reseterror")
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        raise TypeError("'Database' object is not iterable")
+
+    def add_user(self, name, password=None, read_only=None, **kwargs):
+        """Create user `name` with password `password`.
+
+        Add a new user with permissions for this :class:`Database`.
+
+        .. note:: Will change the password if user `name` already exists.
+
+        :Parameters:
+          - `name`: the name of the user to create
+          - `password` (optional): the password of the user to create. Cannot
+            be used with the ``userSource`` argument.
+          - `read_only` (optional): if ``True`` the user will be read only
+          - `**kwargs` (optional): optional fields for the user document
+            (e.g. ``userSource``, ``otherDBRoles``, or ``roles``). See the
+            MongoDB documentation on privilege documents for more information.
+
+        .. note:: The use of optional keyword arguments like ``userSource``,
+           ``otherDBRoles``, or ``roles`` requires MongoDB >= 2.4.0
+
+        .. versionchanged:: 2.5
+           Added kwargs support for optional fields introduced in MongoDB 2.4
+
+        .. versionchanged:: 2.2
+           Added support for read only users
+
+        .. versionadded:: 1.4
+        """
+
+        user = self.system.users.find_one({"user": name}) or {"user": name}
+        if password is not None:
+            user["pwd"] = auth._password_digest(name, password)
+        if read_only is not None:
+            user["readOnly"] = common.validate_boolean('read_only', read_only)
+        user.update(kwargs)
+
+        try:
+            self.system.users.save(user, **self._get_wc_override())
+        except OperationFailure as e:
+            # First admin user add fails gle in MongoDB >= 2.1.2
+            # See SERVER-4225 for more information.
+            if 'login' in str(e):
+                pass
+            else:
+                raise
+
+    def remove_user(self, name):
+        """Remove user `name` from this :class:`Database`.
+
+        User `name` will no longer have permissions to access this
+        :class:`Database`.
+
+        :Parameters:
+          - `name`: the name of the user to remove
+
+        .. versionadded:: 1.4
+        """
+        self.system.users.remove({"user": name}, **self._get_wc_override())
+
+    def authenticate(self, name, password=None,
+                     source=None, mechanism='MONGODB-CR', **kwargs):
+        """Authenticate to use this database.
+
+        Authentication lasts for the life of the underlying client
+        instance, or until :meth:`logout` is called.
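+
+        For example (illustrative)::
+
+          db.authenticate("user", "password")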
+ + Raises :class:`TypeError` if (required) `name`, (optional) `password`, + or (optional) `source` is not an instance of :class:`basestring` + (:class:`str` in python 3). + + .. note:: + - This method authenticates the current connection, and + will also cause all new :class:`~socket.socket` connections + in the underlying client instance to be authenticated automatically. + + - Authenticating more than once on the same database with different + credentials is not supported. You must call :meth:`logout` before + authenticating with new credentials. + + - When sharing a client instance between multiple threads, all + threads will share the authentication. If you need different + authentication profiles for different purposes you must use + distinct client instances. + + - To get authentication to apply immediately to all + existing sockets you may need to reset this client instance's + sockets using :meth:`~pymongo.mongo_client.MongoClient.disconnect`. + + :Parameters: + - `name`: the name of the user to authenticate. + - `password` (optional): the password of the user to authenticate. + Not used with GSSAPI or MONGODB-X509 authentication. + - `source` (optional): the database to authenticate on. If not + specified the current database is used. + - `mechanism` (optional): See + :data:`~pymongo.auth.MECHANISMS` for options. + Defaults to MONGODB-CR (MongoDB Challenge Response protocol) + - `gssapiServiceName` (optional): Used with the GSSAPI mechanism + to specify the service name portion of the service principal name. + Defaults to 'mongodb'. + + .. versionchanged:: 2.5 + Added the `source` and `mechanism` parameters. :meth:`authenticate` + now raises a subclass of :class:`~pymongo.errors.PyMongoError` if + authentication fails due to invalid credentials or configuration + issues. + + .. mongodoc:: authenticate + """ + if not isinstance(name, str): + raise TypeError("name must be an instance " + "of %s" % (str.__name__,)) + if password is not None and not isinstance(password, str): + raise TypeError("password must be an instance " + "of %s" % (str.__name__,)) + if source is not None and not isinstance(source, str): + raise TypeError("source must be an instance " + "of %s" % (str.__name__,)) + common.validate_auth_mechanism('mechanism', mechanism) + + validated_options = {} + for option, value in kwargs.items(): + normalized, val = common.validate_auth_option(option, value) + validated_options[normalized] = val + + credentials = auth._build_credentials_tuple(mechanism, + source or self.name, str(name), + password and str(password) or None, + validated_options) + self.connection._cache_credentials(self.name, credentials) + return True + + def logout(self): + """Deauthorize use of this database for this client instance. + + .. note:: Other databases may still be authenticated, and other + existing :class:`~socket.socket` connections may remain + authenticated for this database unless you reset all sockets + with :meth:`~pymongo.mongo_client.MongoClient.disconnect`. + """ + # Sockets will be deauthenticated as they are used. + self.connection._purge_credentials(self.name) + + def dereference(self, dbref): + """Dereference a :class:`~bson.dbref.DBRef`, getting the + document it points to. + + Raises :class:`TypeError` if `dbref` is not an instance of + :class:`~bson.dbref.DBRef`. Returns a document, or ``None`` if + the reference does not point to a valid document. Raises + :class:`ValueError` if `dbref` has a database specified that + is different from the current database. 
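+
+        For example (an illustrative sketch; assumes the stored document
+        has a DBRef under the hypothetical key ``"other"``)::
+
+          doc = db.test.find_one()
+          other_doc = db.dereference(doc["other"])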
+ + :Parameters: + - `dbref`: the reference + """ + if not isinstance(dbref, DBRef): + raise TypeError("cannot dereference a %s" % type(dbref)) + if dbref.database is not None and dbref.database != self.__name: + raise ValueError("trying to dereference a DBRef that points to " + "another database (%r not %r)" % (dbref.database, + self.__name)) + return self[dbref.collection].find_one({"_id": dbref.id}) + + def eval(self, code, *args): + """Evaluate a JavaScript expression in MongoDB. + + Useful if you need to touch a lot of data lightly; in such a + scenario the network transfer of the data could be a + bottleneck. The `code` argument must be a JavaScript + function. Additional positional arguments will be passed to + that function when it is run on the server. + + Raises :class:`TypeError` if `code` is not an instance of + :class:`basestring` (:class:`str` in python 3) or `Code`. + Raises :class:`~pymongo.errors.OperationFailure` if the eval + fails. Returns the result of the evaluation. + + :Parameters: + - `code`: string representation of JavaScript code to be + evaluated + - `args` (optional): additional positional arguments are + passed to the `code` being evaluated + """ + if not isinstance(code, Code): + code = Code(code) + + result = self.command("$eval", code, args=args) + return result.get("retval", None) + + def __call__(self, *args, **kwargs): + """This is only here so that some API misusages are easier to debug. + """ + raise TypeError("'Database' object is not callable. If you meant to " + "call the '%s' method on a '%s' object it is " + "failing because no such method exists." % ( + self.__name, self.__connection.__class__.__name__)) + + +class SystemJS(object): + """Helper class for dealing with stored JavaScript. + """ + + def __init__(self, database): + """Get a system js helper for the database `database`. + + An instance of :class:`SystemJS` can be created with an instance + of :class:`Database` through :attr:`Database.system_js`, + manual instantiation of this class should not be necessary. + + :class:`SystemJS` instances allow for easy manipulation and + access to server-side JavaScript: + + .. doctest:: + + >>> db.system_js.add1 = "function (x) { return x + 1; }" + >>> db.system.js.find({"_id": "add1"}).count() + 1 + >>> db.system_js.add1(5) + 6.0 + >>> del db.system_js.add1 + >>> db.system.js.find({"_id": "add1"}).count() + 0 + + .. note:: Requires server version **>= 1.1.1** + + .. versionadded:: 1.5 + """ + # can't just assign it since we've overridden __setattr__ + object.__setattr__(self, "_db", database) + + def __setattr__(self, name, code): + self._db.system.js.save({"_id": name, "value": Code(code)}, + **self._db._get_wc_override()) + + def __setitem__(self, name, code): + self.__setattr__(name, code) + + def __delattr__(self, name): + self._db.system.js.remove({"_id": name}, **self._db._get_wc_override()) + + def __delitem__(self, name): + self.__delattr__(name) + + def __getattr__(self, name): + return lambda *args: self._db.eval(Code("function() { " + "return this[name].apply(" + "this, arguments); }", + scope={'name': name}), *args) + + def __getitem__(self, name): + return self.__getattr__(name) + + def list(self): + """Get a list of the names of the functions stored in this database. + + .. 
versionadded:: 1.9 + """ + return [x["_id"] for x in self._db.system.js.find(fields=["_id"])] diff --git a/asyncio_mongo/_pymongo/errors.py b/asyncio_mongo/_pymongo/errors.py new file mode 100644 index 0000000..8413ba3 --- /dev/null +++ b/asyncio_mongo/_pymongo/errors.py @@ -0,0 +1,121 @@ +# Copyright 2009-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Exceptions raised by PyMongo.""" + +from asyncio_mongo._bson.errors import * + +try: + from ssl import CertificateError +except ImportError: + from asyncio_mongo._pymongo.ssl_match_hostname import CertificateError + + +class PyMongoError(Exception): + """Base class for all PyMongo exceptions. + + .. versionadded:: 1.4 + """ + + +class ConnectionFailure(PyMongoError): + """Raised when a connection to the database cannot be made or is lost. + """ + + +class AutoReconnect(ConnectionFailure): + """Raised when a connection to the database is lost and an attempt to + auto-reconnect will be made. + + In order to auto-reconnect you must handle this exception, recognizing that + the operation which caused it has not necessarily succeeded. Future + operations will attempt to open a new connection to the database (and + will continue to raise this exception until the first successful + connection is made). + """ + def __init__(self, message='', errors=None): + self.errors = errors or [] + ConnectionFailure.__init__(self, message) + + +class ConfigurationError(PyMongoError): + """Raised when something is incorrectly configured. + """ + + +class OperationFailure(PyMongoError): + """Raised when a database operation fails. + + .. versionadded:: 1.8 + The :attr:`code` attribute. + """ + + def __init__(self, error, code=None): + self.code = code + PyMongoError.__init__(self, error) + + +class TimeoutError(OperationFailure): + """Raised when a database operation times out. + + .. versionadded:: 1.8 + """ + + +class DuplicateKeyError(OperationFailure): + """Raised when a safe insert or update fails due to a duplicate key error. + + .. note:: Requires server version **>= 1.3.0** + + .. versionadded:: 1.4 + """ + + +class InvalidOperation(PyMongoError): + """Raised when a client attempts to perform an invalid operation. + """ + + +class InvalidName(PyMongoError): + """Raised when an invalid name is used. + """ + + +class CollectionInvalid(PyMongoError): + """Raised when collection validation fails. + """ + + +class InvalidURI(ConfigurationError): + """Raised when trying to parse an invalid mongodb URI. + + .. versionadded:: 1.5 + """ + + +class UnsupportedOption(ConfigurationError): + """Exception for unsupported options. + + .. versionadded:: 2.0 + """ + + +class ExceededMaxWaiters(Exception): + """Raised when a thread tries to get a connection from a pool and + ``max_pool_size * waitQueueMultiple`` threads are already waiting. + + .. 
versionadded:: 2.6
+    """
+    pass
+
diff --git a/asyncio_mongo/_pymongo/helpers.py b/asyncio_mongo/_pymongo/helpers.py
new file mode 100644
index 0000000..a2b2d63
--- /dev/null
+++ b/asyncio_mongo/_pymongo/helpers.py
@@ -0,0 +1,174 @@
+# Copyright 2009-2012 10gen, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Bits and pieces used by the driver that don't really fit elsewhere."""
+
+import random
+import struct
+
+import asyncio_mongo._bson as bson
+# _index_list below refers to ``pymongo.ASCENDING``, so the package import
+# needs an alias (unaliased, the name ``pymongo`` would be undefined here).
+import asyncio_mongo._pymongo as pymongo
+
+from asyncio_mongo._bson.binary import OLD_UUID_SUBTYPE
+from asyncio_mongo._bson.son import SON
+from asyncio_mongo._pymongo.errors import (AutoReconnect,
+                                           DuplicateKeyError,
+                                           OperationFailure,
+                                           TimeoutError)
+
+
+def _index_list(key_or_list, direction=None):
+    """Helper to generate a list of (key, direction) pairs.
+
+    Takes such a list, or a single key, or a single key and direction.
+    """
+    if direction is not None:
+        return [(key_or_list, direction)]
+    else:
+        if isinstance(key_or_list, str):
+            return [(key_or_list, pymongo.ASCENDING)]
+        elif not isinstance(key_or_list, list):
+            raise TypeError("if no direction is specified, "
+                            "key_or_list must be an instance of list")
+        return key_or_list
+
+
+def _index_document(index_list):
+    """Helper to generate an index specifying document.
+
+    Takes a list of (key, direction) pairs.
+    """
+    if isinstance(index_list, dict):
+        raise TypeError("passing a dict to sort/create_index/hint is not "
+                        "allowed - use a list of tuples instead. did you "
+                        "mean %r?" % list(index_list.items()))
+    elif not isinstance(index_list, list):
+        raise TypeError("must use a list of (key, direction) pairs, "
+                        "not: " + repr(index_list))
+    if not len(index_list):
+        raise ValueError("key_or_list must not be the empty list")
+
+    index = SON()
+    for (key, value) in index_list:
+        if not isinstance(key, str):
+            raise TypeError("first item in each key pair must be a string")
+        if not isinstance(value, (str, int)):
+            raise TypeError("second item in each key pair must be 1, -1, "
+                            "'2d', 'geoHaystack', or another valid MongoDB "
+                            "index specifier.")
+        index[key] = value
+    return index
+
+
+def _unpack_response(response, cursor_id=None, as_class=dict,
+                     tz_aware=False, uuid_subtype=OLD_UUID_SUBTYPE):
+    """Unpack a response from the database.
+
+    Check the response for errors and unpack, returning a dictionary
+    containing the response data.
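+
+    The returned dictionary has the shape consumed by the cursor code
+    earlier in this commit (illustrative)::
+
+      {"cursor_id": ..., "starting_from": ..., "number_returned": ...,
+       "data": [...]}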
+ + :Parameters: + - `response`: byte string as returned from the database + - `cursor_id` (optional): cursor_id we sent to get this response - + used for raising an informative exception when we get cursor id not + valid at server response + - `as_class` (optional): class to use for resulting documents + """ + response_flag = struct.unpack("= 1") + + for slave in slaves: + if not isinstance(slave, MongoClient): + raise TypeError("slave %r is not an instance of MongoClient" % + slave) + + super(MasterSlaveConnection, + self).__init__(read_preference=ReadPreference.SECONDARY, + safe=master.safe, + **master.write_concern) + + self.__master = master + self.__slaves = slaves + self.__document_class = document_class + self.__tz_aware = tz_aware + self.__request_counter = thread_util.Counter(master.use_greenlets) + + @property + def master(self): + return self.__master + + @property + def slaves(self): + return self.__slaves + + @property + def is_mongos(self): + """If this MasterSlaveConnection is connected to mongos (always False) + + .. versionadded:: 2.3 + """ + return False + + @property + def use_greenlets(self): + """Whether calling :meth:`start_request` assigns greenlet-local, + rather than thread-local, sockets. + + .. versionadded:: 2.4.2 + """ + return self.master.use_greenlets + + def get_document_class(self): + return self.__document_class + + def set_document_class(self, klass): + self.__document_class = klass + + document_class = property(get_document_class, set_document_class, + doc="""Default class to use for documents + returned on this connection.""") + + @property + def tz_aware(self): + return self.__tz_aware + + @property + def max_bson_size(self): + """Return the maximum size BSON object the connected master + accepts in bytes. Defaults to 4MB in server < 1.7.4. + + .. versionadded:: 2.6 + """ + return self.master.max_bson_size + + @property + def max_message_size(self): + """Return the maximum message size the connected master + accepts in bytes. + + .. versionadded:: 2.6 + """ + return self.master.max_message_size + + + def disconnect(self): + """Disconnect from MongoDB. + + Disconnecting will call disconnect on all master and slave + connections. + + .. seealso:: Module :mod:`~pymongo.mongo_client` + .. versionadded:: 1.10.1 + """ + self.__master.disconnect() + for slave in self.__slaves: + slave.disconnect() + + def set_cursor_manager(self, manager_class): + """Set the cursor manager for this connection. + + Helper to set cursor manager for each individual `MongoClient` instance + that make up this `MasterSlaveConnection`. + """ + self.__master.set_cursor_manager(manager_class) + for slave in self.__slaves: + slave.set_cursor_manager(manager_class) + + def _ensure_connected(self, sync): + """Ensure the master is connected to a mongod/s. + """ + self.__master._ensure_connected(sync) + + # _connection_to_use is a hack that we need to include to make sure + # that killcursor operations can be sent to the same instance on which + # the cursor actually resides... + def _send_message(self, message, + with_last_error=False, _connection_to_use=None): + """Say something to Mongo. + + Sends a message on the Master connection. This is used for inserts, + updates, and deletes. + + Raises ConnectionFailure if the message cannot be sent. Returns the + request id of the sent message. 
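_index_list and _index_document close out helpers.py just above; a short sketch of the shapes they accept and produce. Note that in this vendored copy the single-key shortcut in _index_list references a bare `pymongo` name the module never imports, so explicit (key, direction) pairs are the safe input::

    from asyncio_mongo._bson.son import SON
    from asyncio_mongo._pymongo.helpers import _index_list, _index_document

    pairs = _index_list([("name", 1), ("score", -1)])  # passed through as-is
    index = _index_document(pairs)
    assert index == SON([("name", 1), ("score", -1)])

    # Plain dicts are rejected so compound-index key order can't be scrambled:
    try:
        _index_document({"name": 1, "score": -1})
    except TypeError:
        pass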
+ + :Parameters: + - `operation`: opcode of the message + - `data`: data to send + - `safe`: perform a getLastError after sending the message + """ + if _connection_to_use is None or _connection_to_use == -1: + return self.__master._send_message(message, with_last_error) + return self.__slaves[_connection_to_use]._send_message( + message, with_last_error, check_primary=False) + + # _connection_to_use is a hack that we need to include to make sure + # that getmore operations can be sent to the same instance on which + # the cursor actually resides... + def _send_message_with_response(self, message, _connection_to_use=None, + _must_use_master=False, **kwargs): + """Receive a message from Mongo. + + Sends the given message and returns a (connection_id, response) pair. + + :Parameters: + - `operation`: opcode of the message to send + - `data`: data to send + """ + if _connection_to_use is not None: + if _connection_to_use == -1: + member = self.__master + conn = -1 + else: + member = self.__slaves[_connection_to_use] + conn = _connection_to_use + return (conn, + member._send_message_with_response(message, **kwargs)[1]) + + # _must_use_master is set for commands, which must be sent to the + # master instance. any queries in a request must be sent to the + # master since that is where writes go. + if _must_use_master or self.in_request(): + return (-1, self.__master._send_message_with_response(message, + **kwargs)[1]) + + # Iterate through the slaves randomly until we have success. Raise + # reconnect if they all fail. + for connection_id in helpers.shuffled(range(len(self.__slaves))): + try: + slave = self.__slaves[connection_id] + return (connection_id, + slave._send_message_with_response(message, + **kwargs)[1]) + except AutoReconnect: + pass + + raise AutoReconnect("failed to connect to slaves") + + def start_request(self): + """Start a "request". + + Start a sequence of operations in which order matters. Note + that all operations performed within a request will be sent + using the Master connection. + """ + self.__request_counter.inc() + self.master.start_request() + + def in_request(self): + return bool(self.__request_counter.get()) + + def end_request(self): + """End the current "request". + + See documentation for `MongoClient.end_request`. + """ + self.__request_counter.dec() + self.master.end_request() + + def __eq__(self, other): + if isinstance(other, MasterSlaveConnection): + us = (self.__master, self.slaves) + them = (other.__master, other.__slaves) + return us == them + return NotImplemented + + def __ne__(self, other): + return not self == other + + def __repr__(self): + return "MasterSlaveConnection(%r, %r)" % (self.__master, self.__slaves) + + def __getattr__(self, name): + """Get a database by name. + + Raises InvalidName if an invalid database name is used. + + :Parameters: + - `name`: the name of the database to get + """ + return Database(self, name) + + def __getitem__(self, name): + """Get a database by name. + + Raises InvalidName if an invalid database name is used. + + :Parameters: + - `name`: the name of the database to get + """ + return self.__getattr__(name) + + def close_cursor(self, cursor_id, connection_id): + """Close a single database cursor. + + Raises TypeError if cursor_id is not an instance of (int, long). What + closing the cursor actually means depends on this connection's cursor + manager. 
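Taken together, the two _send_message variants above give MasterSlaveConnection its routing rule: writes and in-request reads go to the master, other reads are shuffled across the slaves. A hypothetical setup; the hosts are placeholders, the module path is assumed from the commit layout, and the collection API comes from the rest of this commit::

    from asyncio_mongo._pymongo.mongo_client import MongoClient
    from asyncio_mongo._pymongo.master_slave_connection import MasterSlaveConnection

    msc = MasterSlaveConnection(
        MongoClient("master.example.com", 27017),
        [MongoClient("slave1.example.com", 27017),
         MongoClient("slave2.example.com", 27017)])

    msc.test.things.insert({"x": 1})          # always the master

    msc.start_request()                       # pin reads to the master...
    try:
        doc = msc.test.things.find_one({"x": 1})  # ...so this sees the insert
    finally:
        msc.end_request()                     # back to shuffled slave reads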
+ + :Parameters: + - `cursor_id`: cursor id to close + - `connection_id`: id of the `MongoClient` instance where the cursor + was opened + """ + if connection_id == -1: + return self.__master.close_cursor(cursor_id) + return self.__slaves[connection_id].close_cursor(cursor_id) + + def database_names(self): + """Get a list of all database names. + """ + return self.__master.database_names() + + def drop_database(self, name_or_database): + """Drop a database. + + :Parameters: + - `name_or_database`: the name of a database to drop or the object + itself + """ + return self.__master.drop_database(name_or_database) + + def __iter__(self): + return self + + def __next__(self): + raise TypeError("'MasterSlaveConnection' object is not iterable") + + def _cached(self, database_name, collection_name, index_name): + return self.__master._cached(database_name, + collection_name, index_name) + + def _cache_index(self, database_name, collection_name, + index_name, cache_for): + return self.__master._cache_index(database_name, collection_name, + index_name, cache_for) + + def _purge_index(self, database_name, + collection_name=None, index_name=None): + return self.__master._purge_index(database_name, + collection_name, + index_name) diff --git a/asyncio_mongo/_pymongo/message.py b/asyncio_mongo/_pymongo/message.py new file mode 100644 index 0000000..ed7bbd3 --- /dev/null +++ b/asyncio_mongo/_pymongo/message.py @@ -0,0 +1,254 @@ +# Copyright 2009-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for creating `messages +`_ to be sent to +MongoDB. + +.. note:: This module is for internal use and is generally not needed by + application developers. + +.. versionadded:: 1.1.2 +""" + +import random +import struct + +import asyncio_mongo._bson as bson +from asyncio_mongo._bson.binary import OLD_UUID_SUBTYPE +from asyncio_mongo._bson.py3compat import b +from asyncio_mongo._bson.son import SON +try: + from asyncio_mongo._pymongo import _cmessage + _use_c = True +except ImportError: + _use_c = False +from asyncio_mongo._pymongo.errors import InvalidDocument, InvalidOperation, OperationFailure + + +__ZERO = b("\x00\x00\x00\x00") + +EMPTY = b("") + +MAX_INT32 = 2147483647 +MIN_INT32 = -2147483648 + + +def __last_error(namespace, args): + """Data to send to do a lastError. + """ + cmd = SON([("getlasterror", 1)]) + cmd.update(args) + splitns = namespace.split('.', 1) + return query(0, splitns[0] + '.$cmd', 0, -1, cmd) + + +def __pack_message(operation, data): + """Takes message data and adds a message header based on the operation. + + Returns the resultant message string. + """ + request_id = random.randint(MIN_INT32, MAX_INT32) + message = struct.pack(" client.max_bson_size: + raise InvalidDocument("BSON document too large (%d bytes)" + " - the connected server supports" + " BSON document sizes up to %d" + " bytes." 
% + (encoded_length, client.max_bson_size)) + message_length += encoded_length + if message_length < client.max_message_size: + data.append(encoded) + continue + + # We have enough data, send this message. + send_safe = safe or not continue_on_error + try: + client._send_message(_insert_message(EMPTY.join(data), + send_safe), send_safe) + # Exception type could be OperationFailure or a subtype + # (e.g. DuplicateKeyError) + except OperationFailure as exc: + # Like it says, continue on error... + if continue_on_error: + # Store exception details to re-raise after the final batch. + last_error = exc + # With unacknowledged writes just return at the first error. + elif not safe: + return + # With acknowledged writes raise immediately. + else: + raise + message_length = len(begin) + encoded_length + data = [begin, encoded] + + client._send_message(_insert_message(EMPTY.join(data), safe), safe) + + # Re-raise any exception stored due to continue_on_error + if last_error is not None: + raise last_error +if _use_c: + _do_batched_insert = _cmessage._do_batched_insert diff --git a/asyncio_mongo/_pymongo/mongo_client.py b/asyncio_mongo/_pymongo/mongo_client.py new file mode 100644 index 0000000..bf00f71 --- /dev/null +++ b/asyncio_mongo/_pymongo/mongo_client.py @@ -0,0 +1,1338 @@ +# Copyright 2009-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Tools for connecting to MongoDB. + +.. seealso:: Module :mod:`~pymongo.master_slave_connection` for + connecting to master-slave clusters, and + :doc:`/examples/high_availability` for an example of how to connect + to a replica set, or specify a list of mongos instances for automatic + failover. + +To get a :class:`~pymongo.database.Database` instance from a +:class:`MongoClient` use either dictionary-style or attribute-style +access: + +.. doctest:: + + >>> from asyncio_mongo._pymongo import MongoClient + >>> c = MongoClient() + >>> c.test_database + Database(MongoClient('localhost', 27017), u'test_database') + >>> c['test-database'] + Database(MongoClient('localhost', 27017), u'test-database') +""" + +import datetime +import random +import socket +import struct +import time +import warnings + +from asyncio_mongo._bson.py3compat import b +from asyncio_mongo._pymongo import (auth, + common, + database, + helpers, + message, + pool, + uri_parser) +from asyncio_mongo._pymongo.common import HAS_SSL +from asyncio_mongo._pymongo.cursor_manager import CursorManager +from asyncio_mongo._pymongo.errors import (AutoReconnect, + ConfigurationError, + ConnectionFailure, + DuplicateKeyError, + InvalidDocument, + InvalidURI, + OperationFailure) + +EMPTY = b("") + + +def _partition_node(node): + """Split a host:port string returned from mongod/s into + a (host, int(port)) pair needed for socket.connect(). 
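message.py ends just above. Its little-endian "<i" struct format strings were lost in this copy of the diff wherever a "<...>" span was stripped, so the sketch below restates the standard 16-byte wire header that __pack_message builds, per the MongoDB wire protocol::

    import random
    import struct

    MIN_INT32, MAX_INT32 = -2147483648, 2147483647

    def pack_message(operation, data):
        # Header: messageLength, requestID, responseTo, opCode, each a
        # little-endian int32, followed by the opaque message body.
        request_id = random.randint(MIN_INT32, MAX_INT32)
        header = struct.pack("<iiii", 16 + len(data), request_id, 0, operation)
        return request_id, header + data

    rid, msg = pack_message(2004, b"...")     # 2004 = OP_QUERY
    assert len(msg) == 16 + 3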
+ """ + host = node + port = 27017 + idx = node.rfind(':') + if idx != -1: + host, port = node[:idx], int(node[idx + 1:]) + if host.startswith('['): + host = host[1:-1] + return host, port + + +class MongoClient(common.BaseObject): + """Connection to MongoDB. + """ + + HOST = "localhost" + PORT = 27017 + + __max_bson_size = 4 * 1024 * 1024 + + def __init__(self, host=None, port=None, max_pool_size=100, + document_class=dict, tz_aware=False, _connect=True, + **kwargs): + """Create a new connection to a single MongoDB instance at *host:port*. + + The resultant client object has connection-pooling built + in. It also performs auto-reconnection when necessary. If an + operation fails because of a connection error, + :class:`~pymongo.errors.ConnectionFailure` is raised. If + auto-reconnection will be performed, + :class:`~pymongo.errors.AutoReconnect` will be + raised. Application code should handle this exception + (recognizing that the operation failed) and then continue to + execute. + + Raises :class:`TypeError` if port is not an instance of + ``int``. Raises :class:`~pymongo.errors.ConnectionFailure` if + the connection cannot be made. + + The `host` parameter can be a full `mongodb URI + `_, in addition to + a simple hostname. It can also be a list of hostnames or + URIs. Any port specified in the host string(s) will override + the `port` parameter. If multiple mongodb URIs containing + database or auth information are passed, the last database, + username, and password present will be used. For username and + passwords reserved characters like ':', '/', '+' and '@' must be + escaped following RFC 2396. + + :Parameters: + - `host` (optional): hostname or IP address of the + instance to connect to, or a mongodb URI, or a list of + hostnames / mongodb URIs. If `host` is an IPv6 literal + it must be enclosed in '[' and ']' characters following + the RFC2732 URL syntax (e.g. '[::1]' for localhost) + - `port` (optional): port number on which to connect + - `max_pool_size` (optional): The maximum number of connections + that the pool will open simultaneously. If this is set, operations + will block if there are `max_pool_size` outstanding connections + from the pool. Defaults to 100. + - `document_class` (optional): default class to use for + documents returned from queries on this client + - `tz_aware` (optional): if ``True``, + :class:`~datetime.datetime` instances returned as values + in a document by this :class:`MongoClient` will be timezone + aware (otherwise they will be naive) + + | **Other optional parameters can be passed as keyword arguments:** + + - `socketTimeoutMS`: (integer) How long (in milliseconds) a send or + receive on a socket can take before timing out. + - `connectTimeoutMS`: (integer) How long (in milliseconds) a + connection can take to be opened before timing out. + - `waitQueueTimeoutMS`: (integer) How long (in milliseconds) a + thread will wait for a socket from the pool if the pool has no + free sockets. + - `waitQueueMultiple`: (integer) Multiplied by max_pool_size to give + the number of threads allowed to wait for a socket at one time. + - `auto_start_request`: If ``True``, each thread that accesses + this :class:`MongoClient` has a socket allocated to it for the + thread's lifetime. This ensures consistent reads, even if you + read after an unacknowledged write. 
Defaults to ``False`` + - `use_greenlets`: If ``True``, :meth:`start_request()` will ensure + that the current greenlet uses the same socket for all + operations until :meth:`end_request()` + + | **Write Concern options:** + + - `w`: (integer or string) If this is a replica set, write operations + will block until they have been replicated to the specified number + or tagged set of servers. `w=` always includes the replica set + primary (e.g. w=3 means write to the primary and wait until + replicated to **two** secondaries). Passing w=0 **disables write + acknowledgement** and all other write concern options. + - `wtimeout`: (integer) Used in conjunction with `w`. Specify a value + in milliseconds to control how long to wait for write propagation + to complete. If replication does not complete in the given + timeframe, a timeout exception is raised. + - `j`: If ``True`` block until write operations have been committed + to the journal. Ignored if the server is running without journaling. + - `fsync`: If ``True`` force the database to fsync all files before + returning. When used with `j` the server awaits the next group + commit before returning. + + | **Replica set keyword arguments for connecting with a replica set + - either directly or via a mongos:** + | (ignored by standalone mongod instances) + + - `replicaSet`: (string) The name of the replica set to connect to. + The driver will verify that the replica set it connects to matches + this name. Implies that the hosts specified are a seed list and the + driver should attempt to find all members of the set. *Ignored by + mongos*. + - `read_preference`: The read preference for this client. If + connecting to a secondary then a read preference mode *other* than + PRIMARY is required - otherwise all queries will throw + :class:`~pymongo.errors.AutoReconnect` "not master". + See :class:`~pymongo.read_preferences.ReadPreference` for all + available read preference options. + - `tag_sets`: Ignored unless connecting to a replica set via mongos. + Specify a priority-order for tag sets, provide a list of + tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag + set, ``{}``, means "read from any member that matches the mode, + ignoring tags. + + | **SSL configuration:** + + - `ssl`: If ``True``, create the connection to the server using SSL. + - `ssl_keyfile`: The private keyfile used to identify the local + connection against mongod. If included with the ``certfile` then + only the ``ssl_certfile`` is needed. Implies ``ssl=True``. + - `ssl_certfile`: The certificate file used to identify the local + connection against mongod. Implies ``ssl=True``. + - `ssl_cert_reqs`: Specifies whether a certificate is required from + the other side of the connection, and whether it will be validated + if provided. It must be one of the three values ``ssl.CERT_NONE`` + (certificates ignored), ``ssl.CERT_OPTIONAL`` + (not required, but validated if provided), or ``ssl.CERT_REQUIRED`` + (required and validated). If the value of this parameter is not + ``ssl.CERT_NONE``, then the ``ssl_ca_certs`` parameter must point + to a file of CA certificates. Implies ``ssl=True``. + - `ssl_ca_certs`: The ca_certs file contains a set of concatenated + "certification authority" certificates, which are used to validate + certificates passed from the other end of the connection. + Implies ``ssl=True``. + + .. seealso:: :meth:`end_request` + + .. mongodoc:: connections + + .. versionchanged:: 2.5 + Added additional ssl options + .. 
versionadded:: 2.4 + """ + if host is None: + host = self.HOST + if isinstance(host, str): + host = [host] + if port is None: + port = self.PORT + if not isinstance(port, int): + raise TypeError("port must be an instance of int") + + seeds = set() + username = None + password = None + self.__default_database_name = None + opts = {} + for entity in host: + if "://" in entity: + if entity.startswith("mongodb://"): + res = uri_parser.parse_uri(entity, port) + seeds.update(res["nodelist"]) + username = res["username"] or username + password = res["password"] or password + self.__default_database_name = ( + res["database"] or self.__default_database_name) + + opts = res["options"] + else: + idx = entity.find("://") + raise InvalidURI("Invalid URI scheme: " + "%s" % (entity[:idx],)) + else: + seeds.update(uri_parser.split_hosts(entity, port)) + if not seeds: + raise ConfigurationError("need to specify at least one host") + + self.__nodes = seeds + self.__host = None + self.__port = None + self.__is_primary = False + self.__is_mongos = False + + # _pool_class option is for deep customization of PyMongo, e.g. Motor. + # SHOULD NOT BE USED BY DEVELOPERS EXTERNAL TO 10GEN. + pool_class = kwargs.pop('_pool_class', pool.Pool) + + options = {} + for option, value in kwargs.items(): + option, value = common.validate(option, value) + options[option] = value + options.update(opts) + + self.__max_pool_size = common.validate_positive_integer_or_none( + 'max_pool_size', max_pool_size) + + self.__cursor_manager = CursorManager(self) + + self.__repl = options.get('replicaset') + if len(seeds) == 1 and not self.__repl: + self.__direct = True + else: + self.__direct = False + self.__nodes = set() + + self.__net_timeout = options.get('sockettimeoutms') + self.__conn_timeout = options.get('connecttimeoutms') + self.__wait_queue_timeout = options.get('waitqueuetimeoutms') + self.__wait_queue_multiple = options.get('waitqueuemultiple') + + self.__use_ssl = options.get('ssl', None) + self.__ssl_keyfile = options.get('ssl_keyfile', None) + self.__ssl_certfile = options.get('ssl_certfile', None) + self.__ssl_cert_reqs = options.get('ssl_cert_reqs', None) + self.__ssl_ca_certs = options.get('ssl_ca_certs', None) + + ssl_kwarg_keys = [k for k in list(kwargs.keys()) if k.startswith('ssl_')] + if self.__use_ssl == False and ssl_kwarg_keys: + raise ConfigurationError("ssl has not been enabled but the " + "following ssl parameters have been set: " + "%s. Please set `ssl=True` or remove." + % ', '.join(ssl_kwarg_keys)) + + if self.__ssl_cert_reqs and not self.__ssl_ca_certs: + raise ConfigurationError("If `ssl_cert_reqs` is not " + "`ssl.CERT_NONE` then you must " + "include `ssl_ca_certs` to be able " + "to validate the server.") + + if ssl_kwarg_keys and self.__use_ssl is None: + # ssl options imply ssl = True + self.__use_ssl = True + + if self.__use_ssl and not HAS_SSL: + raise ConfigurationError("The ssl module is not available. 
If you " + "are using a python version previous to " + "2.6 you must install the ssl package " + "from PyPI.") + + self.__use_greenlets = options.get('use_greenlets', False) + self.__pool = pool_class( + None, + self.__max_pool_size, + self.__net_timeout, + self.__conn_timeout, + self.__use_ssl, + use_greenlets=self.__use_greenlets, + ssl_keyfile=self.__ssl_keyfile, + ssl_certfile=self.__ssl_certfile, + ssl_cert_reqs=self.__ssl_cert_reqs, + ssl_ca_certs=self.__ssl_ca_certs, + wait_queue_timeout=self.__wait_queue_timeout, + wait_queue_multiple=self.__wait_queue_multiple) + + self.__document_class = document_class + self.__tz_aware = common.validate_boolean('tz_aware', tz_aware) + self.__auto_start_request = options.get('auto_start_request', False) + + # cache of existing indexes used by ensure_index ops + self.__index_cache = {} + self.__auth_credentials = {} + + super(MongoClient, self).__init__(**options) + if self.slave_okay: + warnings.warn("slave_okay is deprecated. Please " + "use read_preference instead.", DeprecationWarning, + stacklevel=2) + + if _connect: + try: + self.__find_node(seeds) + except AutoReconnect as e: + # ConnectionFailure makes more sense here than AutoReconnect + raise ConnectionFailure(str(e)) + + if username: + mechanism = options.get('authmechanism', 'MONGODB-CR') + source = ( + options.get('authsource') + or self.__default_database_name + or 'admin') + + credentials = auth._build_credentials_tuple(mechanism, + source, + str(username), + str(password), + options) + try: + self._cache_credentials(source, credentials, _connect) + except OperationFailure as exc: + raise ConfigurationError(str(exc)) + + def _cached(self, dbname, coll, index): + """Test if `index` is cached. + """ + cache = self.__index_cache + now = datetime.datetime.utcnow() + return (dbname in cache and + coll in cache[dbname] and + index in cache[dbname][coll] and + now < cache[dbname][coll][index]) + + def _cache_index(self, database, collection, index, cache_for): + """Add an index to the index cache for ensure_index operations. + """ + now = datetime.datetime.utcnow() + expire = datetime.timedelta(seconds=cache_for) + now + + if database not in self.__index_cache: + self.__index_cache[database] = {} + self.__index_cache[database][collection] = {} + self.__index_cache[database][collection][index] = expire + + elif collection not in self.__index_cache[database]: + self.__index_cache[database][collection] = {} + self.__index_cache[database][collection][index] = expire + + else: + self.__index_cache[database][collection][index] = expire + + def _purge_index(self, database_name, + collection_name=None, index_name=None): + """Purge an index from the index cache. + + If `index_name` is None purge an entire collection. + + If `collection_name` is None purge an entire database. + """ + if not database_name in self.__index_cache: + return + + if collection_name is None: + del self.__index_cache[database_name] + return + + if not collection_name in self.__index_cache[database_name]: + return + + if index_name is None: + del self.__index_cache[database_name][collection_name] + return + + if index_name in self.__index_cache[database_name][collection_name]: + del self.__index_cache[database_name][collection_name][index_name] + + def _cache_credentials(self, source, credentials, connect=True): + """Add credentials to the database authentication cache + for automatic login when a socket is created. If `connect` is True, + verify the credentials on the server first. 
+ """ + if source in self.__auth_credentials: + # Nothing to do if we already have these credentials. + if credentials == self.__auth_credentials[source]: + return + raise OperationFailure('Another user is already authenticated ' + 'to this database. You must logout first.') + + if connect: + sock_info = self.__socket() + try: + # Since __check_auth was called in __socket + # there is no need to call it here. + auth.authenticate(credentials, sock_info, self.__simple_command) + sock_info.authset.add(credentials) + finally: + self.__pool.maybe_return_socket(sock_info) + + self.__auth_credentials[source] = credentials + + def _purge_credentials(self, source): + """Purge credentials from the database authentication cache. + """ + if source in self.__auth_credentials: + del self.__auth_credentials[source] + + def __check_auth(self, sock_info): + """Authenticate using cached database credentials. + """ + if self.__auth_credentials or sock_info.authset: + cached = set(self.__auth_credentials.values()) + + authset = sock_info.authset.copy() + + # Logout any credentials that no longer exist in the cache. + for credentials in authset - cached: + self.__simple_command(sock_info, credentials[1], {'logout': 1}) + sock_info.authset.discard(credentials) + + for credentials in cached - authset: + auth.authenticate(credentials, + sock_info, self.__simple_command) + sock_info.authset.add(credentials) + + @property + def host(self): + """Current connected host. + + .. versionchanged:: 1.3 + ``host`` is now a property rather than a method. + """ + return self.__host + + @property + def port(self): + """Current connected port. + + .. versionchanged:: 1.3 + ``port`` is now a property rather than a method. + """ + return self.__port + + @property + def is_primary(self): + """If this instance is connected to a standalone, a replica set + primary, or the master of a master-slave set. + + .. versionadded:: 2.3 + """ + return self.__is_primary + + @property + def is_mongos(self): + """If this instance is connected to mongos. + + .. versionadded:: 2.3 + """ + return self.__is_mongos + + @property + def max_pool_size(self): + """The maximum number of sockets the pool will open concurrently. + + When the pool has reached `max_pool_size`, operations block waiting for + a socket to be returned to the pool. If ``waitQueueTimeoutMS`` is set, + a blocked operation will raise :exc:`~pymongo.errors.ConnectionFailure` + after a timeout. By default ``waitQueueTimeoutMS`` is not set. + + .. warning:: SIGNIFICANT BEHAVIOR CHANGE in 2.6. Previously, this + parameter would limit only the idle sockets the pool would hold + onto, not the number of open sockets. The default has also changed + to 100. + + .. versionchanged:: 2.6 + .. versionadded:: 1.11 + """ + return self.__max_pool_size + + @property + def use_greenlets(self): + """Whether calling :meth:`start_request` assigns greenlet-local, + rather than thread-local, sockets. + + .. versionadded:: 2.4.2 + """ + return self.__use_greenlets + + @property + def nodes(self): + """List of all known nodes. + + Includes both nodes specified when this instance was created, + as well as nodes discovered through the replica set discovery + mechanism. + + .. versionadded:: 1.8 + """ + return self.__nodes + + @property + def auto_start_request(self): + """Is auto_start_request enabled? 
+ """ + return self.__auto_start_request + + def get_document_class(self): + return self.__document_class + + def set_document_class(self, klass): + self.__document_class = klass + + document_class = property(get_document_class, set_document_class, + doc="""Default class to use for documents + returned from this client. + + .. versionadded:: 1.7 + """) + + @property + def tz_aware(self): + """Does this client return timezone-aware datetimes? + + .. versionadded:: 1.8 + """ + return self.__tz_aware + + @property + def max_bson_size(self): + """Return the maximum size BSON object the connected server + accepts in bytes. Defaults to 4MB in server < 1.7.4. + + .. versionadded:: 1.10 + """ + return self.__max_bson_size + + @property + def max_message_size(self): + """Return the maximum message size the connected server + accepts in bytes. + + .. versionadded:: 2.6 + """ + return self.__max_message_size + + def __simple_command(self, sock_info, dbname, spec): + """Send a command to the server. + """ + rqst_id, msg, _ = message.query(0, dbname + '.$cmd', 0, -1, spec) + start = time.time() + try: + sock_info.sock.sendall(msg) + response = self.__receive_message_on_socket(1, rqst_id, sock_info) + except: + sock_info.close() + raise + + end = time.time() + response = helpers._unpack_response(response)['data'][0] + msg = "command %r failed: %%s" % spec + helpers._check_command_response(response, None, msg) + return response, end - start + + def __try_node(self, node): + """Try to connect to this node and see if it works for our connection + type. Returns ((host, port), ismaster, isdbgrid, res_time). + + :Parameters: + - `node`: The (host, port) pair to try. + """ + self.disconnect() + self.__host, self.__port = node + + # Call 'ismaster' directly so we can get a response time. + sock_info = self.__socket() + try: + response, res_time = self.__simple_command(sock_info, + 'admin', + {'ismaster': 1}) + finally: + self.__pool.maybe_return_socket(sock_info) + + # Are we talking to a mongos? + isdbgrid = response.get('msg', '') == 'isdbgrid' + + if "maxBsonObjectSize" in response: + self.__max_bson_size = response["maxBsonObjectSize"] + if "maxMessageSizeBytes" in response: + self.__max_message_size = response["maxMessageSizeBytes"] + else: + self.__max_message_size = 2 * self.max_bson_size + + # Replica Set? + if not self.__direct: + # Check that this host is part of the given replica set. + if self.__repl: + set_name = response.get('setName') + # The 'setName' field isn't returned by mongod before 1.6.2 + # so we can't assume that if it's missing this host isn't in + # the specified set. + if set_name and set_name != self.__repl: + raise ConfigurationError("%s:%d is not a member of " + "replica set %s" + % (node[0], node[1], self.__repl)) + if "hosts" in response: + self.__nodes = set([_partition_node(h) + for h in response["hosts"]]) + else: + # The user passed a seed list of standalone or + # mongos instances. + self.__nodes.add(node) + if response["ismaster"]: + return node, True, isdbgrid, res_time + elif "primary" in response: + candidate = _partition_node(response["primary"]) + return self.__try_node(candidate) + + # Explain why we aren't using this connection. 
+ raise AutoReconnect('%s:%d is not primary or master' % node) + + # Direct connection + if response.get("arbiterOnly", False) and not self.__direct: + raise ConfigurationError("%s:%d is an arbiter" % node) + return node, response['ismaster'], isdbgrid, res_time + + def __pick_nearest(self, candidates): + """Return the 'nearest' candidate based on response time. + """ + latency = self.secondary_acceptable_latency_ms + # Only used for mongos high availability, res_time is in seconds. + fastest = min([res_time for candidate, res_time in candidates]) + near_candidates = [ + candidate for candidate, res_time in candidates + if res_time - fastest < latency / 1000.0 + ] + + node = random.choice(near_candidates) + # Clear the pool from the last choice. + self.disconnect() + self.__host, self.__port = node + return node + + def __find_node(self, seeds=None): + """Find a host, port pair suitable for our connection type. + + If only one host was supplied to __init__ see if we can connect + to it. Don't check if the host is a master/primary so we can make + a direct connection to read from a secondary or send commands to + an arbiter. + + If more than one host was supplied treat them as a seed list for + connecting to a replica set or to support high availability for + mongos. If connecting to a replica set try to find the primary + and fail if we can't, possibly updating any replSet information + on success. If a mongos seed list was provided find the "nearest" + mongos and return it. + + Otherwise we iterate through the list trying to find a host we can + send write operations to. + + Sets __host and __port so that :attr:`host` and :attr:`port` + will return the address of the connected host. Sets __is_primary to + True if this is a primary or master, else False. Sets __is_mongos + to True if the connection is to a mongos. + """ + errors = [] + mongos_candidates = [] + candidates = seeds or self.__nodes.copy() + for candidate in candidates: + try: + node, ismaster, isdbgrid, res_time = self.__try_node(candidate) + self.__is_primary = ismaster + self.__is_mongos = isdbgrid + # No need to calculate nearest if we only have one mongos. + if isdbgrid and not self.__direct: + mongos_candidates.append((node, res_time)) + continue + elif len(mongos_candidates): + raise ConfigurationError("Seed list cannot contain a mix " + "of mongod and mongos instances.") + return node + except OperationFailure: + # The server is available but something failed, probably auth. + raise + except Exception as why: + errors.append(str(why)) + + # If we have a mongos seed list, pick the "nearest" member. + if len(mongos_candidates): + self.__is_mongos = True + return self.__pick_nearest(mongos_candidates) + + # Otherwise, try any hosts we discovered that were not in the seed list. + for candidate in self.__nodes - candidates: + try: + node, ismaster, isdbgrid, _ = self.__try_node(candidate) + self.__is_primary = ismaster + self.__is_mongos = isdbgrid + return node + except Exception as why: + errors.append(str(why)) + # Couldn't find a suitable host. + self.disconnect() + raise AutoReconnect(', '.join(errors)) + + def __socket(self): + """Get a SocketInfo from the pool. 
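__try_node above is driven entirely by the server's ismaster reply; a fabricated reply shows the fields it consumes and how the module-level _partition_node normalizes the host list::

    from asyncio_mongo._pymongo.mongo_client import _partition_node

    ismaster = {
        "ismaster": True,
        "hosts": ["db1.example.com:27017", "[::1]:27018"],
        "maxBsonObjectSize": 16 * 1024 * 1024,   # replaces the 4MB default
    }

    nodes = set(_partition_node(h) for h in ismaster["hosts"])
    assert nodes == {("db1.example.com", 27017), ("::1", 27018)}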
+ """ + host, port = (self.__host, self.__port) + if host is None or (port is None and '/' not in host): + host, port = self.__find_node() + + try: + if self.auto_start_request and not self.in_request(): + self.start_request() + + sock_info = self.__pool.get_socket((host, port)) + except socket.error as why: + self.disconnect() + + # Check if a unix domain socket + if host.endswith('.sock'): + host_details = "%s:" % host + else: + host_details = "%s:%d:" % (host, port) + raise AutoReconnect("could not connect to " + "%s %s" % (host_details, str(why))) + try: + self.__check_auth(sock_info) + except OperationFailure: + self.__pool.maybe_return_socket(sock_info) + raise + return sock_info + + def _ensure_connected(self, dummy): + """Ensure this client instance is connected to a mongod/s. + """ + host, port = (self.__host, self.__port) + if host is None or (port is None and '/' not in host): + self.__find_node() + + def disconnect(self): + """Disconnect from MongoDB. + + Disconnecting will close all underlying sockets in the connection + pool. If this instance is used again it will be automatically + re-opened. Care should be taken to make sure that :meth:`disconnect` + is not called in the middle of a sequence of operations in which + ordering is important. This could lead to unexpected results. + + .. seealso:: :meth:`end_request` + .. versionadded:: 1.3 + """ + self.__pool.reset() + self.__host = None + self.__port = None + + def close(self): + """Alias for :meth:`disconnect` + + Disconnecting will close all underlying sockets in the connection + pool. If this instance is used again it will be automatically + re-opened. Care should be taken to make sure that :meth:`disconnect` + is not called in the middle of a sequence of operations in which + ordering is important. This could lead to unexpected results. + + .. seealso:: :meth:`end_request` + .. versionadded:: 2.1 + """ + self.disconnect() + + def alive(self): + """Return ``False`` if there has been an error communicating with the + server, else ``True``. + + This method attempts to check the status of the server with minimal I/O. + The current thread / greenlet retrieves a socket from the pool (its + request socket if it's in a request, or a random idle socket if it's not + in a request) and checks whether calling `select`_ on it raises an + error. If there are currently no idle sockets, :meth:`alive` will + attempt to actually connect to the server. + + A more certain way to determine server availability is:: + + client.admin.command('ping') + + .. _select: http://docs.python.org/2/library/select.html#select.select + """ + # In the common case, a socket is available and was used recently, so + # calling select() on it is a reasonable attempt to see if the OS has + # reported an error. Note this can be wasteful: __socket implicitly + # calls select() if the socket hasn't been checked in the last second, + # or it may create a new socket, in which case calling select() is + # redundant. + sock_info = None + try: + try: + sock_info = self.__socket() + return not pool._closed(sock_info.sock) + except (socket.error, ConnectionFailure): + return False + finally: + self.__pool.maybe_return_socket(sock_info) + + def set_cursor_manager(self, manager_class): + """Set this client's cursor manager. + + Raises :class:`TypeError` if `manager_class` is not a subclass of + :class:`~pymongo.cursor_manager.CursorManager`. A cursor manager + handles closing cursors. 
Different managers can implement different + policies in terms of when to actually kill a cursor that has + been closed. + + :Parameters: + - `manager_class`: cursor manager to use + + .. versionchanged:: 2.1+ + Deprecated support for external cursor managers. + """ + warnings.warn("Support for external cursor managers is deprecated " + "and will be removed in PyMongo 3.0.", + DeprecationWarning, stacklevel=2) + manager = manager_class(self) + if not isinstance(manager, CursorManager): + raise TypeError("manager_class must be a subclass of " + "CursorManager") + + self.__cursor_manager = manager + + def __check_response_to_last_error(self, response): + """Check a response to a lastError message for errors. + + `response` is a byte string representing a response to the message. + If it represents an error response we raise OperationFailure. + + Return the response as a document. + """ + response = helpers._unpack_response(response) + + assert response["number_returned"] == 1 + error = response["data"][0] + + helpers._check_command_response(error, self.disconnect) + + error_msg = error.get("err", "") + if error_msg is None: + return error + if error_msg.startswith("not master"): + self.disconnect() + raise AutoReconnect(error_msg) + + details = error + # mongos returns the error code in an error object + # for some errors. + if "errObjects" in error: + for errobj in error["errObjects"]: + if errobj["err"] == error_msg: + details = errobj + break + + if "code" in details: + if details["code"] in (11000, 11001, 12582): + raise DuplicateKeyError(details["err"], details["code"]) + else: + raise OperationFailure(details["err"], details["code"]) + else: + raise OperationFailure(details["err"]) + + def __check_bson_size(self, message): + """Make sure the message doesn't include BSON documents larger + than the connected server will accept. + + :Parameters: + - `message`: message to check + """ + if len(message) == 3: + (request_id, data, max_doc_size) = message + if max_doc_size > self.__max_bson_size: + raise InvalidDocument("BSON document too large (%d bytes)" + " - the connected server supports" + " BSON document sizes up to %d" + " bytes." % + (max_doc_size, self.__max_bson_size)) + return (request_id, data) + else: + # get_more and kill_cursors messages + # don't include BSON documents. + return message + + def _send_message(self, message, with_last_error=False, check_primary=True): + """Say something to Mongo. + + Raises ConnectionFailure if the message cannot be sent. Raises + OperationFailure if `with_last_error` is ``True`` and the + response to the getLastError call returns an error. Return the + response from lastError, or ``None`` if `with_last_error` + is ``False``. + + :Parameters: + - `message`: message to send + - `with_last_error`: check getLastError status after sending the + message + - `check_primary`: don't try to write to a non-primary; see + kill_cursors for an exception to this rule + """ + if check_primary and not with_last_error and not self.is_primary: + # The write won't succeed, bail as if we'd done a getLastError + raise AutoReconnect("not master") + + sock_info = self.__socket() + try: + try: + (request_id, data) = self.__check_bson_size(message) + sock_info.sock.sendall(data) + # Safe mode. We pack the message together with a lastError + # message and send both. We then get the response (to the + # lastError) and raise OperationFailure if it is an error + # response. 
+ rv = None + if with_last_error: + response = self.__receive_message_on_socket(1, request_id, + sock_info) + rv = self.__check_response_to_last_error(response) + + return rv + except OperationFailure: + raise + except (ConnectionFailure, socket.error) as e: + self.disconnect() + raise AutoReconnect(str(e)) + except: + sock_info.close() + raise + finally: + self.__pool.maybe_return_socket(sock_info) + + def __receive_data_on_socket(self, length, sock_info): + """Lowest level receive operation. + + Takes length to receive and repeatedly calls recv until able to + return a buffer of that length, raising ConnectionFailure on error. + """ + message = EMPTY + while length: + chunk = sock_info.sock.recv(length) + if chunk == EMPTY: + raise ConnectionFailure("connection closed") + length -= len(chunk) + message += chunk + return message + + def __receive_message_on_socket(self, operation, rqst_id, sock_info): + """Receive a message in response to `rqst_id` on `sock`. + + Returns the response data with the header removed. + """ + header = self.__receive_data_on_socket(16, sock_info) + length = struct.unpack(">> client = pymongo.MongoClient(auto_start_request=False) + >>> db = client.test + >>> _id = db.test_collection.insert({}) + >>> with client.start_request(): + ... for i in range(100): + ... db.test_collection.update({'_id': _id}, {'$set': {'i':i}}) + ... + ... # Definitely read the document after the final update completes + ... print(db.test_collection.find({'_id': _id})) + + If a thread or greenlet calls start_request multiple times, an equal + number of calls to :meth:`end_request` is required to end the request. + + .. versionchanged:: 2.4 + Now counts the number of calls to start_request and doesn't end + request until an equal number of calls to end_request. + + .. versionadded:: 2.2 + The :class:`~pymongo.pool.Request` return value. + :meth:`start_request` previously returned None + """ + self.__pool.start_request() + return pool.Request(self) + + def in_request(self): + """True if this thread is in a request, meaning it has a socket + reserved for its exclusive use. + """ + return self.__pool.in_request() + + def end_request(self): + """Undo :meth:`start_request`. If :meth:`end_request` is called as many + times as :meth:`start_request`, the request is over and this thread's + connection returns to the pool. Extra calls to :meth:`end_request` have + no effect. + + Ending a request allows the :class:`~socket.socket` that has + been reserved for this thread by :meth:`start_request` to be returned to + the pool. Other threads will then be able to re-use that + :class:`~socket.socket`. If your application uses many threads, or has + long-running threads that infrequently perform MongoDB operations, then + judicious use of this method can lead to performance gains. Care should + be taken, however, to make sure that :meth:`end_request` is not called + in the middle of a sequence of operations in which ordering is + important. This could lead to unexpected results. + """ + self.__pool.end_request() + + def __eq__(self, other): + if isinstance(other, self.__class__): + us = (self.__host, self.__port) + them = (other.__host, other.__port) + return us == them + return NotImplemented + + def __ne__(self, other): + return not self == other + + def __repr__(self): + if len(self.__nodes) == 1: + return "MongoClient(%r, %r)" % (self.__host, self.__port) + else: + return "MongoClient(%r)" % ["%s:%d" % n for n in self.__nodes] + + def __getattr__(self, name): + """Get a database by name. 
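The receive path above lost its "<i" struct specifiers, and some of the following lines, to the same angle-bracket stripping. This is the intended shape of the length-prefixed read, under the 16-byte header layout noted earlier; `sock` is any connected socket object::

    import struct

    def receive_message(sock, request_id):
        def recv_exact(n):
            buf = b""
            while len(buf) < n:
                chunk = sock.recv(n - len(buf))
                if not chunk:
                    raise ConnectionError("connection closed")
                buf += chunk
            return buf

        header = recv_exact(16)
        length, _, response_to, _ = struct.unpack("<iiii", header)
        assert response_to == request_id, "server reply out of order"
        return recv_exact(length - 16)       # body with the header removed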
+ + Raises :class:`~pymongo.errors.InvalidName` if an invalid + database name is used. + + :Parameters: + - `name`: the name of the database to get + """ + return database.Database(self, name) + + def __getitem__(self, name): + """Get a database by name. + + Raises :class:`~pymongo.errors.InvalidName` if an invalid + database name is used. + + :Parameters: + - `name`: the name of the database to get + """ + return self.__getattr__(name) + + def close_cursor(self, cursor_id): + """Close a single database cursor. + + Raises :class:`TypeError` if `cursor_id` is not an instance of + ``(int, long)``. What closing the cursor actually means + depends on this client's cursor manager. + + :Parameters: + - `cursor_id`: id of cursor to close + """ + if not isinstance(cursor_id, int): + raise TypeError("cursor_id must be an instance of (int, long)") + + self.__cursor_manager.close(cursor_id) + + def kill_cursors(self, cursor_ids): + """Send a kill cursors message with the given ids. + + Raises :class:`TypeError` if `cursor_ids` is not an instance of + ``list``. + + :Parameters: + - `cursor_ids`: list of cursor ids to kill + """ + if not isinstance(cursor_ids, list): + raise TypeError("cursor_ids must be a list") + return self._send_message( + message.kill_cursors(cursor_ids), check_primary=False) + + def server_info(self): + """Get information about the MongoDB server we're connected to. + """ + return self.admin.command("buildinfo") + + def database_names(self): + """Get a list of the names of all databases on the connected server. + """ + return [db["name"] for db in + self.admin.command("listDatabases")["databases"]] + + def drop_database(self, name_or_database): + """Drop a database. + + Raises :class:`TypeError` if `name_or_database` is not an instance of + :class:`basestring` (:class:`str` in python 3) or Database. + + :Parameters: + - `name_or_database`: the name of a database to drop, or a + :class:`~pymongo.database.Database` instance representing the + database to drop + """ + name = name_or_database + if isinstance(name, database.Database): + name = name.name + + if not isinstance(name, str): + raise TypeError("name_or_database must be an instance of " + "%s or Database" % (str.__name__,)) + + self._purge_index(name) + self[name].command("dropDatabase") + + def copy_database(self, from_name, to_name, + from_host=None, username=None, password=None): + """Copy a database, potentially from another host. + + Raises :class:`TypeError` if `from_name` or `to_name` is not + an instance of :class:`basestring` (:class:`str` in python 3). + Raises :class:`~pymongo.errors.InvalidName` if `to_name` is + not a valid database name. + + If `from_host` is ``None`` the current host is used as the + source. Otherwise the database is copied from `from_host`. + + If the source database requires authentication, `username` and + `password` must be specified. + + :Parameters: + - `from_name`: the name of the source database + - `to_name`: the name of the target database + - `from_host` (optional): host name to copy from + - `username` (optional): username for source database + - `password` (optional): password for source database + + .. note:: Specifying `username` and `password` requires server + version **>= 1.3.3+**. + + .. 
versionadded:: 1.5 + """ + if not isinstance(from_name, str): + raise TypeError("from_name must be an instance " + "of %s" % (str.__name__,)) + if not isinstance(to_name, str): + raise TypeError("to_name must be an instance " + "of %s" % (str.__name__,)) + + database._check_name(to_name) + + command = {"fromdb": from_name, "todb": to_name} + + if from_host is not None: + command["fromhost"] = from_host + + try: + self.start_request() + + if username is not None: + nonce = self.admin.command("copydbgetnonce", + fromhost=from_host)["nonce"] + command["username"] = username + command["nonce"] = nonce + command["key"] = auth._auth_key(nonce, username, password) + + return self.admin.command("copydb", **command) + finally: + self.end_request() + + def get_default_database(self): + """Get the database named in the MongoDB connection URI. + + >>> uri = 'mongodb://host/my_database' + >>> client = MongoClient(uri) + >>> db = client.get_default_database() + >>> assert db.name == 'my_database' + + Useful in scripts where you want to choose which database to use + based only on the URI in a configuration file. + """ + if self.__default_database_name is None: + raise ConfigurationError('No default database defined') + + return self[self.__default_database_name] + + @property + def is_locked(self): + """Is this server locked? While locked, all write operations + are blocked, although read operations may still be allowed. + Use :meth:`unlock` to unlock. + + .. versionadded:: 2.0 + """ + ops = self.admin.current_op() + return bool(ops.get('fsyncLock', 0)) + + def fsync(self, **kwargs): + """Flush all pending writes to datafiles. + + :Parameters: + + Optional parameters can be passed as keyword arguments: + + - `lock`: If True lock the server to disallow writes. + - `async`: If True don't block while synchronizing. + + .. warning:: `async` and `lock` can not be used together. + + .. warning:: MongoDB does not support the `async` option + on Windows and will raise an exception on that + platform. + + .. versionadded:: 2.0 + """ + self.admin.command("fsync", **kwargs) + + def unlock(self): + """Unlock a previously locked server. + + .. versionadded:: 2.0 + """ + self.admin['$cmd'].sys.unlock.find_one() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.disconnect() + + def __iter__(self): + return self + + def __next__(self): + raise TypeError("'MongoClient' object is not iterable") + diff --git a/asyncio_mongo/_pymongo/mongo_replica_set_client.py b/asyncio_mongo/_pymongo/mongo_replica_set_client.py new file mode 100644 index 0000000..763aa72 --- /dev/null +++ b/asyncio_mongo/_pymongo/mongo_replica_set_client.py @@ -0,0 +1,1855 @@ +# Copyright 2011-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Tools for connecting to a MongoDB replica set. + +.. seealso:: :doc:`/examples/high_availability` for more examples of + how to connect to a replica set. 
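mongo_client.py closes just above; the convenience surface it leaves behind is small enough to demo in a few lines. The URI and database names here are placeholders::

    from asyncio_mongo._pymongo.mongo_client import MongoClient

    client = MongoClient("mongodb://db1.example.com/my_database")
    db = client.get_default_database()    # the database named in the URI
    assert db.name == "my_database"

    client.drop_database("scratch")       # by name, or pass a Database object

    # MongoClient is a context manager; __exit__ calls disconnect():
    with MongoClient("mongodb://db1.example.com/my_database") as c:
        c.server_info()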
+ +To get a :class:`~pymongo.database.Database` instance from a +:class:`MongoReplicaSetClient` use either dictionary-style or +attribute-style access: + +.. doctest:: + + >>> from asyncio_mongo._pymongo import MongoReplicaSetClient + >>> c = MongoReplicaSetClient('localhost:27017', replicaSet='repl0') + >>> c.test_database + Database(MongoReplicaSetClient([u'...', u'...']), u'test_database') + >>> c['test_database'] + Database(MongoReplicaSetClient([u'...', u'...']), u'test_database') +""" + +import atexit +import datetime +import socket +import struct +import threading +import time +import warnings +import weakref + +from asyncio_mongo._bson.py3compat import b +from asyncio_mongo._pymongo import (auth, + common, + database, + helpers, + message, + pool, + thread_util, + uri_parser) +from asyncio_mongo._pymongo.read_preferences import ( + ReadPreference, select_member, modes, MovingAverage) +from asyncio_mongo._pymongo.errors import (AutoReconnect, + ConfigurationError, + ConnectionFailure, + DuplicateKeyError, + InvalidDocument, + OperationFailure, + InvalidOperation) + +EMPTY = b("") +MAX_BSON_SIZE = 4 * 1024 * 1024 +MAX_RETRY = 3 + +# Member states +PRIMARY = 1 +SECONDARY = 2 +OTHER = 3 + +MONITORS = set() + +def register_monitor(monitor): + ref = weakref.ref(monitor, _on_monitor_deleted) + MONITORS.add(ref) + +def _on_monitor_deleted(ref): + """Remove the weakreference from the set + of active MONITORS. We no longer + care about keeping track of it + """ + MONITORS.remove(ref) + +def shutdown_monitors(): + # Keep a local copy of MONITORS as + # shutting down threads has a side effect + # of removing them from the MONITORS set() + monitors = list(MONITORS) + for ref in monitors: + monitor = ref() + if monitor: + monitor.shutdown() + monitor.join() +atexit.register(shutdown_monitors) + +def _partition_node(node): + """Split a host:port string returned from mongod/s into + a (host, int(port)) pair needed for socket.connect(). + """ + host = node + port = 27017 + idx = node.rfind(':') + if idx != -1: + host, port = node[:idx], int(node[idx + 1:]) + if host.startswith('['): + host = host[1:-1] + return host, port + + +# Concurrency notes: A MongoReplicaSetClient keeps its view of the replica-set +# state in an RSState instance. RSStates are immutable, except for +# host-pinning. Pools, which are internally thread / greenlet safe, can be +# copied from old to new RSStates safely. The client updates its view of the +# set's state not by modifying its RSState but by replacing it with an updated +# copy. + +# In __init__, MongoReplicaSetClient gets a list of potential members called +# 'seeds' from its initial parameters, and calls refresh(). refresh() iterates +# over the the seeds in arbitrary order looking for a member it can connect to. +# Once it finds one, it calls 'ismaster' and sets self.__hosts to the list of +# members in the response, and connects to the rest of the members. refresh() +# sets the MongoReplicaSetClient's RSState. Finally, __init__ launches the +# replica-set monitor. + +# The monitor calls refresh() every 30 seconds, or whenever the client has +# encountered an error that prompts it to wake the monitor. + +# Every method that accesses the RSState multiple times within the method makes +# a local reference first and uses that throughout, so it's isolated from a +# concurrent method replacing the RSState with an updated copy. This technique +# avoids the need to lock around accesses to the RSState. 
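The MONITORS registry above holds weak references so a collected client cannot be kept alive by its own monitor, while atexit can still shut down whatever remains. The same pattern in miniature; FakeMonitor stands in for the real monitor classes below::

    import weakref

    REGISTRY = set()

    class FakeMonitor(object):
        def shutdown(self):
            pass

    def register(monitor):
        # The callback fires when the monitor is collected and removes the
        # now-dead reference, mirroring _on_monitor_deleted.
        REGISTRY.add(weakref.ref(monitor, REGISTRY.discard))

    m = FakeMonitor()
    register(m)
    del m               # CPython refcounting fires the callback immediately
    assert len(REGISTRY) == 0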
+ + +class RSState(object): + def __init__( + self, threadlocal, host_to_member=None, arbiters=None, writer=None, + error_message='No primary available'): + """An immutable snapshot of the client's view of the replica set state. + + :Parameters: + - `threadlocal`: Thread- or greenlet-local storage + - `host_to_member`: Optional dict: (host, port) -> Member instance + - `arbiters`: Optional sequence of arbiters as (host, port) + - `writer`: Optional (host, port) of primary + - `error_message`: Optional error if `writer` is None + """ + self._threadlocal = threadlocal # threading.local or gevent local + self._arbiters = frozenset(arbiters or []) # set of (host, port) + self._writer = writer # (host, port) of the primary, or None + self._error_message = error_message + self._host_to_member = host_to_member or {} + self._hosts = frozenset(self._host_to_member) + self._members = frozenset(list(self._host_to_member.values())) + + if writer and self._host_to_member[writer].up: + self._primary_member = self._host_to_member[writer] + else: + self._primary_member = None + + def clone_with_host_down(self, host, error_message): + """Get a clone, marking as "down" the member with the given (host, port) + """ + members = self._host_to_member.copy() + down_member = members.pop(host, None) + if down_member: + members[host] = down_member.clone_down() + + if host == self.writer: + # The primary went down; record the error message. + return RSState( + self._threadlocal, members, self._arbiters, + None, error_message) + else: + # Some other host went down. Keep our current primary or, if it's + # already down, keep our current error message. + return RSState( + self._threadlocal, members, self._arbiters, + self._writer, self._error_message) + + def clone_without_writer(self, threadlocal): + """Get a clone without a primary. Unpins all threads. + + :Parameters: + - `threadlocal`: Thread- or greenlet-local storage + """ + return RSState( + threadlocal, self._host_to_member.copy(), self._arbiters, None) + + @property + def arbiters(self): + """Set of (host, port) pairs.""" + return self._arbiters + + @property + def writer(self): + """(host, port) of primary, or None.""" + return self._writer + + @property + def primary_member(self): + return self._primary_member + + @property + def hosts(self): + """Set of (host, port) tuples of data members of the replica set.""" + return self._hosts + + @property + def members(self): + """Set of Member instances.""" + return self._members + + @property + def error_message(self): + """The error, if any, raised when trying to connect to the primary""" + return self._error_message + + @property + def secondaries(self): + """Set of (host, port) pairs.""" + # Unlike the other properties, this isn't cached because it isn't used + # in regular operations. + return set([ + host for host, member in list(self._host_to_member.items()) + if member.is_secondary]) + + def get(self, host): + """Return a Member instance or None for the given (host, port).""" + return self._host_to_member.get(host) + + def pin_host(self, host, mode, tag_sets, latency): + """Pin this thread / greenlet to a member. + + `host` is a (host, port) pair. The remaining parameters are a read + preference. + """ + # Fun fact: Unlike in thread_util.ThreadIdent, we needn't lock around + # assignment here. Assignment to a threadlocal is only unsafe if it + # can cause other Python code to run implicitly. 
+        self._threadlocal.host = host
+        self._threadlocal.read_preference = (mode, tag_sets, latency)
+
+    def keep_pinned_host(self, mode, tag_sets, latency):
+        """Does a read pref match the last used by this thread / greenlet?"""
+        return self._threadlocal.read_preference == (mode, tag_sets, latency)
+
+    @property
+    def pinned_host(self):
+        """The (host, port) last used by this thread / greenlet, or None."""
+        return getattr(self._threadlocal, 'host', None)
+
+    def unpin_host(self):
+        """Forget this thread / greenlet's last used member."""
+        self._threadlocal.host = self._threadlocal.read_preference = None
+
+    @property
+    def threadlocal(self):
+        return self._threadlocal
+
+    def __str__(self):
+        return '<RSState [%s] writer="%s">' % (
+            ', '.join(str(member) for member in self._host_to_member.values()),
+            self.writer and '%s:%s' % self.writer or None)
+
+
+class Monitor(object):
+    """Base class for replica set monitors.
+    """
+    _refresh_interval = 30
+
+    def __init__(self, rsc, event_class):
+        self.rsc = weakref.proxy(rsc, self.shutdown)
+        self.timer = event_class()
+        self.refreshed = event_class()
+        self.started_event = event_class()
+        self.stopped = False
+
+    def start_sync(self):
+        """Start the Monitor and block until it's really started.
+        """
+        self.start()  # Implemented in subclasses.
+        self.started_event.wait(5)
+
+    def shutdown(self, dummy=None):
+        """Signal the monitor to shutdown.
+        """
+        self.stopped = True
+        self.timer.set()
+
+    def schedule_refresh(self):
+        """Refresh immediately
+        """
+        if not self.isAlive():
+            raise InvalidOperation(
+                "Monitor thread is dead: Perhaps started before a fork?")
+
+        self.refreshed.clear()
+        self.timer.set()
+
+    def wait_for_refresh(self, timeout_seconds):
+        """Block until a scheduled refresh completes
+        """
+        self.refreshed.wait(timeout_seconds)
+
+    def monitor(self):
+        """Run until the RSC is collected or an
+        unexpected error occurs.
+        """
+        self.started_event.set()
+        while True:
+            self.timer.wait(Monitor._refresh_interval)
+            if self.stopped:
+                break
+            self.timer.clear()
+
+            try:
+                try:
+                    self.rsc.refresh()
+                finally:
+                    self.refreshed.set()
+            except AutoReconnect:
+                pass
+
+            # RSC has been collected or there
+            # was an unexpected error.
+            except:
+                break
+
+    def isAlive(self):
+        raise NotImplementedError()
+
+
+class MonitorThread(threading.Thread, Monitor):
+    """Thread based replica set monitor.
+    """
+    def __init__(self, rsc):
+        Monitor.__init__(self, rsc, threading.Event)
+        threading.Thread.__init__(self)
+        self.setName("ReplicaSetMonitorThread")
+
+        # Track whether the thread has started. (Greenlets track this already.)
+        self.started = False
+
+    def start(self):
+        self.started = True
+        super(MonitorThread, self).start()
+
+    def run(self):
+        """Override Thread's run method.
+        """
+        self.monitor()
+
+
+have_gevent = False
+try:
+    from gevent import Greenlet
+    from gevent.event import Event
+
+    # Used by ReplicaSetConnection
+    from gevent.local import local as gevent_local
+    have_gevent = True
+
+    class MonitorGreenlet(Monitor, Greenlet):
+        """Greenlet based replica set monitor.
+        """
+        def __init__(self, rsc):
+            Monitor.__init__(self, rsc, Event)
+            Greenlet.__init__(self)
+
+        # Don't override `run` in a Greenlet. Add _run instead.
+        # Refer to gevent's Greenlet docs and source for more
+        # information.
+        def _run(self):
+            """Define Greenlet's _run method.
+            """
+            self.monitor()
+
+        def isAlive(self):
+            # Gevent defines bool(Greenlet) as True if it's alive.
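+            # (Editor's note, added for clarity: a Greenlet is truthy from
+            # start() until it dies, so bool(self) mirrors the
+            # threading.Thread.isAlive() contract Monitor relies on.)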
+            return bool(self)
+
+except ImportError:
+    pass
+
+
+class Member(object):
+    """Immutable representation of one member of a replica set.
+
+    :Parameters:
+      - `host`: A (host, port) pair
+      - `connection_pool`: A Pool instance
+      - `ismaster_response`: A dict, MongoDB's ismaster response
+      - `ping_time`: A MovingAverage instance
+      - `up`: Whether we think this member is available
+    """
+    # For unittesting only. Use under no circumstances!
+    _host_to_ping_time = {}
+
+    def __init__(self, host, connection_pool, ismaster_response, ping_time, up):
+        self.host = host
+        self.pool = connection_pool
+        self.ismaster_response = ismaster_response
+        self.ping_time = ping_time
+        self.up = up
+
+        if ismaster_response['ismaster']:
+            self.state = PRIMARY
+        elif ismaster_response.get('secondary'):
+            self.state = SECONDARY
+        else:
+            self.state = OTHER
+
+        self.tags = ismaster_response.get('tags', {})
+        self.max_bson_size = ismaster_response.get(
+            'maxBsonObjectSize', MAX_BSON_SIZE)
+        self.max_message_size = ismaster_response.get(
+            'maxMessageSizeBytes', 2 * self.max_bson_size)
+
+    def clone_with(self, ismaster_response, ping_time_sample):
+        """Get a clone updated with ismaster response and a single ping time.
+        """
+        ping_time = self.ping_time.clone_with(ping_time_sample)
+        return Member(self.host, self.pool, ismaster_response, ping_time, True)
+
+    def clone_down(self):
+        """Get a clone of this Member, but with up=False.
+        """
+        return Member(
+            self.host, self.pool, self.ismaster_response, self.ping_time,
+            False)
+
+    @property
+    def is_primary(self):
+        return self.state == PRIMARY
+
+    @property
+    def is_secondary(self):
+        return self.state == SECONDARY
+
+    def get_avg_ping_time(self):
+        """Get a moving average of this member's ping times.
+        """
+        if self.host in Member._host_to_ping_time:
+            # Simulate ping times for unittesting
+            return Member._host_to_ping_time[self.host]
+
+        return self.ping_time.get()
+
+    def matches_mode(self, mode):
+        if mode == ReadPreference.PRIMARY and not self.is_primary:
+            return False
+
+        if mode == ReadPreference.SECONDARY and not self.is_secondary:
+            return False
+
+        # If we're not primary or secondary, then we're in a state like
+        # RECOVERING and we don't match any mode
+        return self.is_primary or self.is_secondary
+
+    def matches_tags(self, tags):
+        """Return True if this member's tags are a superset of the passed-in
+        tags. E.g., if this member is tagged {'dc': 'ny', 'rack': '1'},
+        then it matches {'dc': 'ny'}.
+        """
+        for key, value in list(tags.items()):
+            if key not in self.tags or self.tags[key] != value:
+                return False
+
+        return True
+
+    def matches_tag_sets(self, tag_sets):
+        """Return True if this member matches any of the tag sets, e.g.
+        [{'dc': 'ny'}, {'dc': 'la'}, {}]
+        """
+        for tags in tag_sets:
+            if self.matches_tags(tags):
+                return True
+
+        return False
+
+    def __str__(self):
+        return '<Member %s:%s primary=%r up=%r>' % (
+            self.host[0], self.host[1], self.is_primary, self.up)
+
+
+class MongoReplicaSetClient(common.BaseObject):
+    """Connection to a MongoDB replica set.
+    """
+
+    def __init__(self, hosts_or_uri=None, max_pool_size=100,
+                 document_class=dict, tz_aware=False, _connect=True, **kwargs):
+        """Create a new connection to a MongoDB replica set.
+
+        The resultant client object has connection-pooling built
+        in. It also performs auto-reconnection when necessary. If an
+        operation fails because of a connection error,
+        :class:`~pymongo.errors.ConnectionFailure` is raised. If
+        auto-reconnection will be performed,
+        :class:`~pymongo.errors.AutoReconnect` will be
+        raised. Application code should handle this exception
+        (recognizing that the operation failed) and then continue to
+        execute.
+
+        Raises :class:`~pymongo.errors.ConnectionFailure` if
+        the connection cannot be made.
+
+        The `hosts_or_uri` parameter can be a full `mongodb URI
+        <http://dochub.mongodb.org/core/connections>`_, in addition to
+        a string of `host:port` pairs (e.g. 'host1:port1,host2:port2').
+        If `hosts_or_uri` is None 'localhost:27017' will be used.
+
+        .. note:: Instances of :class:`MongoReplicaSetClient` start a
+           background task to monitor the state of the replica set. This allows
+           it to quickly respond to changes in replica set configuration.
+           Before discarding an instance of :class:`MongoReplicaSetClient` make
+           sure you call :meth:`~close` to ensure that the monitor task is
+           cleanly shut down.
+
+        .. note:: A :class:`MongoReplicaSetClient` created before a call to
+           ``os.fork()`` is invalid after the fork. Applications should either
+           fork before creating the client, or recreate the client after a
+           fork.
+
+        :Parameters:
+          - `hosts_or_uri` (optional): A MongoDB URI or string of `host:port`
+            pairs. If a host is an IPv6 literal it must be enclosed in '[' and
+            ']' characters following the RFC2732 URL syntax (e.g. '[::1]' for
+            localhost)
+          - `max_pool_size` (optional): The maximum number of connections
+            each pool will open simultaneously. If this is set, operations
+            will block if there are `max_pool_size` outstanding connections
+            from the pool. Defaults to 100.
+          - `document_class` (optional): default class to use for
+            documents returned from queries on this client
+          - `tz_aware` (optional): if ``True``,
+            :class:`~datetime.datetime` instances returned as values
+            in a document by this :class:`MongoReplicaSetClient` will be timezone
+            aware (otherwise they will be naive)
+          - `replicaSet`: (required) The name of the replica set to connect to.
+            The driver will verify that each host it connects to is a member of
+            this replica set. Can be passed as a keyword argument or as a
+            MongoDB URI option.
+
+          | **Other optional parameters can be passed as keyword arguments:**
+
+          - `host`: For compatibility with :class:`~mongo_client.MongoClient`.
+            If both `host` and `hosts_or_uri` are specified `host` takes
+            precedence.
+          - `port`: For compatibility with :class:`~mongo_client.MongoClient`.
+            The default port number to use for hosts.
+          - `socketTimeoutMS`: (integer) How long (in milliseconds) a send or
+            receive on a socket can take before timing out.
+          - `connectTimeoutMS`: (integer) How long (in milliseconds) a
+            connection can take to be opened before timing out.
+          - `waitQueueTimeoutMS`: (integer) How long (in milliseconds) a
+            thread will wait for a socket from the pool if the pool has no
+            free sockets. Defaults to ``None`` (no timeout).
+          - `waitQueueMultiple`: (integer) Multiplied by max_pool_size to give
+            the number of threads allowed to wait for a socket at one time.
+            Defaults to ``None`` (no waiters).
+          - `auto_start_request`: If ``True``, each thread that accesses
+            this :class:`MongoReplicaSetClient` has a socket allocated to it
+            for the thread's lifetime, for each member of the set. For
+            :class:`~pymongo.read_preferences.ReadPreference` PRIMARY,
+            auto_start_request=True ensures consistent reads, even if you read
+            after an unacknowledged write. For read preferences other than
+            PRIMARY, there are no consistency guarantees. Defaults to ``False``.
+          - `use_greenlets`: If ``True``, use a background Greenlet instead of
+            a background thread to monitor state of replica set. Additionally,
+            :meth:`start_request()` assigns a greenlet-local, rather than
+            thread-local, socket.
+            `use_greenlets` with :class:`MongoReplicaSetClient` requires
+            `Gevent <http://gevent.org/>`_ to be installed.
+
+          | **Write Concern options:**
+
+          - `w`: (integer or string) Write operations will block until they have
+            been replicated to the specified number or tagged set of servers.
+            `w=<int>` always includes the replica set primary (e.g. w=3 means
+            write to the primary and wait until replicated to **two**
+            secondaries). Passing w=0 **disables write acknowledgement** and all
+            other write concern options.
+          - `wtimeout`: (integer) Used in conjunction with `w`. Specify a value
+            in milliseconds to control how long to wait for write propagation
+            to complete. If replication does not complete in the given
+            timeframe, a timeout exception is raised.
+          - `j`: If ``True`` block until write operations have been committed
+            to the journal. Ignored if the server is running without journaling.
+          - `fsync`: If ``True`` force the database to fsync all files before
+            returning. When used with `j` the server awaits the next group
+            commit before returning.
+
+          | **Read preference options:**
+
+          - `read_preference`: The read preference for this client.
+            See :class:`~pymongo.read_preferences.ReadPreference` for available
+            options.
+          - `tag_sets`: Read from replica-set members with these tags.
+            To specify a priority-order for tag sets, provide a list of
+            tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag
+            set, ``{}``, means "read from any member that matches the mode,
+            ignoring tags." :class:`MongoReplicaSetClient` tries each set of
+            tags in turn until it finds a set of tags with at least one matching
+            member.
+          - `secondary_acceptable_latency_ms`: (integer) Any replica-set member
+            whose ping time is within secondary_acceptable_latency_ms of the
+            nearest member may accept reads. Default 15 milliseconds.
+            **Ignored by mongos** and must be configured on the command line.
+            See the localThreshold_ option for more information.
+
+          | **SSL configuration:**
+
+          - `ssl`: If ``True``, create the connection to the servers using SSL.
+          - `ssl_keyfile`: The private keyfile used to identify the local
+            connection against mongod. If included with the ``certfile`` then
+            only the ``ssl_certfile`` is needed. Implies ``ssl=True``.
+          - `ssl_certfile`: The certificate file used to identify the local
+            connection against mongod. Implies ``ssl=True``.
+          - `ssl_cert_reqs`: Specifies whether a certificate is required from
+            the other side of the connection, and whether it will be validated
+            if provided. It must be one of the three values ``ssl.CERT_NONE``
+            (certificates ignored), ``ssl.CERT_OPTIONAL``
+            (not required, but validated if provided), or ``ssl.CERT_REQUIRED``
+            (required and validated). If the value of this parameter is not
+            ``ssl.CERT_NONE``, then the ``ssl_ca_certs`` parameter must point
+            to a file of CA certificates. Implies ``ssl=True``.
+          - `ssl_ca_certs`: The ca_certs file contains a set of concatenated
+            "certification authority" certificates, which are used to validate
+            certificates passed from the other end of the connection.
+            Implies ``ssl=True``.
+
+        .. versionchanged:: 2.5
+           Added additional ssl options
+        .. versionadded:: 2.4
+
+        ..
_localThreshold: http://docs.mongodb.org/manual/reference/mongos/#cmdoption-mongos--localThreshold + """ + self.__opts = {} + self.__seeds = set() + self.__index_cache = {} + self.__auth_credentials = {} + + self.__max_pool_size = common.validate_positive_integer_or_none( + 'max_pool_size', max_pool_size) + self.__tz_aware = common.validate_boolean('tz_aware', tz_aware) + self.__document_class = document_class + self.__monitor = None + + # Compatibility with mongo_client.MongoClient + host = kwargs.pop('host', hosts_or_uri) + + port = kwargs.pop('port', 27017) + if not isinstance(port, int): + raise TypeError("port must be an instance of int") + + username = None + password = None + self.__default_database_name = None + options = {} + if host is None: + self.__seeds.add(('localhost', port)) + elif '://' in host: + res = uri_parser.parse_uri(host, port) + self.__seeds.update(res['nodelist']) + username = res['username'] + password = res['password'] + self.__default_database_name = res['database'] + options = res['options'] + else: + self.__seeds.update(uri_parser.split_hosts(host, port)) + + # _pool_class and _monitor_class are for deep customization of PyMongo, + # e.g. Motor. SHOULD NOT BE USED BY DEVELOPERS EXTERNAL TO 10GEN. + self.pool_class = kwargs.pop('_pool_class', pool.Pool) + monitor_class = kwargs.pop('_monitor_class', None) + + for option, value in kwargs.items(): + option, value = common.validate(option, value) + self.__opts[option] = value + self.__opts.update(options) + + self.__use_greenlets = self.__opts.get('use_greenlets', False) + if self.__use_greenlets and not have_gevent: + raise ConfigurationError( + "The gevent module is not available. " + "Install the gevent package from PyPI.") + + self.__rs_state = RSState(self.__make_threadlocal()) + + self.__request_counter = thread_util.Counter(self.__use_greenlets) + + self.__auto_start_request = self.__opts.get('auto_start_request', False) + if self.__auto_start_request: + self.start_request() + + self.__name = self.__opts.get('replicaset') + if not self.__name: + raise ConfigurationError("the replicaSet " + "keyword parameter is required.") + + self.__net_timeout = self.__opts.get('sockettimeoutms') + self.__conn_timeout = self.__opts.get('connecttimeoutms') + self.__wait_queue_timeout = self.__opts.get('waitqueuetimeoutms') + self.__wait_queue_multiple = self.__opts.get('waitqueuemultiple') + self.__use_ssl = self.__opts.get('ssl', None) + self.__ssl_keyfile = self.__opts.get('ssl_keyfile', None) + self.__ssl_certfile = self.__opts.get('ssl_certfile', None) + self.__ssl_cert_reqs = self.__opts.get('ssl_cert_reqs', None) + self.__ssl_ca_certs = self.__opts.get('ssl_ca_certs', None) + + ssl_kwarg_keys = [k for k in list(kwargs.keys()) if k.startswith('ssl_')] + if not self.__use_ssl and ssl_kwarg_keys: + raise ConfigurationError("ssl has not been enabled but the " + "following ssl parameters have been set: " + "%s. Please set `ssl=True` or remove." + % ', '.join(ssl_kwarg_keys)) + + if self.__ssl_cert_reqs and not self.__ssl_ca_certs: + raise ConfigurationError("If `ssl_cert_reqs` is not " + "`ssl.CERT_NONE` then you must " + "include `ssl_ca_certs` to be able " + "to validate the server.") + + if ssl_kwarg_keys and self.__use_ssl is None: + # ssl options imply ssl = True + self.__use_ssl = True + + if self.__use_ssl and not common.HAS_SSL: + raise ConfigurationError("The ssl module is not available. 
If you " + "are using a python version previous to " + "2.6 you must install the ssl package " + "from PyPI.") + + super(MongoReplicaSetClient, self).__init__(**self.__opts) + if self.slave_okay: + warnings.warn("slave_okay is deprecated. Please " + "use read_preference instead.", DeprecationWarning, + stacklevel=2) + + if _connect: + try: + self.refresh() + except AutoReconnect as e: + # ConnectionFailure makes more sense here than AutoReconnect + raise ConnectionFailure(str(e)) + + if username: + mechanism = options.get('authmechanism', 'MONGODB-CR') + source = ( + options.get('authsource') + or self.__default_database_name + or 'admin') + + credentials = auth._build_credentials_tuple(mechanism, + source, + str(username), + str(password), + options) + try: + self._cache_credentials(source, credentials, _connect) + except OperationFailure as exc: + raise ConfigurationError(str(exc)) + + # Start the monitor after we know the configuration is correct. + if monitor_class: + self.__monitor = monitor_class(self) + elif self.__use_greenlets: + self.__monitor = MonitorGreenlet(self) + else: + self.__monitor = MonitorThread(self) + self.__monitor.setDaemon(True) + register_monitor(self.__monitor) + + if _connect: + # Wait for the monitor to really start. Otherwise if we return to + # caller and caller forks immediately, the monitor could think it's + # still alive in the child process when it really isn't. + # See http://bugs.python.org/issue18418. + self.__monitor.start_sync() + + def _cached(self, dbname, coll, index): + """Test if `index` is cached. + """ + cache = self.__index_cache + now = datetime.datetime.utcnow() + return (dbname in cache and + coll in cache[dbname] and + index in cache[dbname][coll] and + now < cache[dbname][coll][index]) + + def _cache_index(self, dbase, collection, index, cache_for): + """Add an index to the index cache for ensure_index operations. + """ + now = datetime.datetime.utcnow() + expire = datetime.timedelta(seconds=cache_for) + now + + if dbase not in self.__index_cache: + self.__index_cache[dbase] = {} + self.__index_cache[dbase][collection] = {} + self.__index_cache[dbase][collection][index] = expire + + elif collection not in self.__index_cache[dbase]: + self.__index_cache[dbase][collection] = {} + self.__index_cache[dbase][collection][index] = expire + + else: + self.__index_cache[dbase][collection][index] = expire + + def _purge_index(self, database_name, + collection_name=None, index_name=None): + """Purge an index from the index cache. + + If `index_name` is None purge an entire collection. + + If `collection_name` is None purge an entire database. + """ + if not database_name in self.__index_cache: + return + + if collection_name is None: + del self.__index_cache[database_name] + return + + if not collection_name in self.__index_cache[database_name]: + return + + if index_name is None: + del self.__index_cache[database_name][collection_name] + return + + if index_name in self.__index_cache[database_name][collection_name]: + del self.__index_cache[database_name][collection_name][index_name] + + def _cache_credentials(self, source, credentials, connect=True): + """Add credentials to the database authentication cache + for automatic login when a socket is created. If `connect` is True, + verify the credentials on the server first. + + Raises OperationFailure if other credentials are already stored for + this source. + """ + if source in self.__auth_credentials: + # Nothing to do if we already have these credentials. 
+ if credentials == self.__auth_credentials[source]: + return + raise OperationFailure('Another user is already authenticated ' + 'to this database. You must logout first.') + + if connect: + # Try to authenticate even during failover. + member = select_member( + self.__rs_state.members, ReadPreference.PRIMARY_PREFERRED) + + if not member: + raise AutoReconnect( + "No replica set members available for authentication") + + sock_info = self.__socket(member) + try: + # Since __check_auth was called in __socket + # there is no need to call it here. + auth.authenticate(credentials, sock_info, self.__simple_command) + sock_info.authset.add(credentials) + finally: + member.pool.maybe_return_socket(sock_info) + + self.__auth_credentials[source] = credentials + + def _purge_credentials(self, source): + """Purge credentials from the database authentication cache. + """ + if source in self.__auth_credentials: + del self.__auth_credentials[source] + + def __check_auth(self, sock_info): + """Authenticate using cached database credentials. + """ + if self.__auth_credentials or sock_info.authset: + cached = set(self.__auth_credentials.values()) + + authset = sock_info.authset.copy() + + # Logout any credentials that no longer exist in the cache. + for credentials in authset - cached: + self.__simple_command(sock_info, credentials[1], {'logout': 1}) + sock_info.authset.discard(credentials) + + for credentials in cached - authset: + auth.authenticate(credentials, + sock_info, self.__simple_command) + sock_info.authset.add(credentials) + + @property + def seeds(self): + """The seed list used to connect to this replica set. + + A sequence of (host, port) pairs. + """ + return self.__seeds + + @property + def hosts(self): + """All active and passive (priority 0) replica set + members known to this client. This does not include + hidden or slaveDelay members, or arbiters. + + A sequence of (host, port) pairs. + """ + return self.__rs_state.hosts + + @property + def primary(self): + """The (host, port) of the current primary of the replica set. + + Returns None if there is no primary. + """ + return self.__rs_state.writer + + @property + def secondaries(self): + """The secondary members known to this client. + + A sequence of (host, port) pairs. + """ + return self.__rs_state.secondaries + + @property + def arbiters(self): + """The arbiters known to this client. + + A sequence of (host, port) pairs. + """ + return self.__rs_state.arbiters + + @property + def is_mongos(self): + """If this instance is connected to mongos (always False). + + .. versionadded:: 2.3 + """ + return False + + @property + def max_pool_size(self): + """The maximum number of sockets the pool will open concurrently. + + When the pool has reached `max_pool_size`, operations block waiting for + a socket to be returned to the pool. If ``waitQueueTimeoutMS`` is set, + a blocked operation will raise :exc:`~pymongo.errors.ConnectionFailure` + after a timeout. By default ``waitQueueTimeoutMS`` is not set. + + .. warning:: SIGNIFICANT BEHAVIOR CHANGE in 2.6. Previously, this + parameter would limit only the idle sockets the pool would hold + onto, not the number of open sockets. The default has also changed + to 100. + + .. versionchanged:: 2.6 + """ + return self.__max_pool_size + + @property + def use_greenlets(self): + """Whether calling :meth:`start_request` assigns greenlet-local, + rather than thread-local, sockets. + + .. 
versionadded:: 2.4.2 + """ + return self.__use_greenlets + + def get_document_class(self): + """document_class getter""" + return self.__document_class + + def set_document_class(self, klass): + """document_class setter""" + self.__document_class = klass + + document_class = property(get_document_class, set_document_class, + doc="""Default class to use for documents + returned from this client. + """) + + @property + def tz_aware(self): + """Does this client return timezone-aware datetimes? + """ + return self.__tz_aware + + @property + def max_bson_size(self): + """Returns the maximum size BSON object the connected primary + accepts in bytes. Defaults to 4MB in server < 1.7.4. Returns + 0 if no primary is available. + """ + rs_state = self.__rs_state + if rs_state.primary_member: + return rs_state.primary_member.max_bson_size + return 0 + + @property + def max_message_size(self): + """Returns the maximum message size the connected primary + accepts in bytes. Returns 0 if no primary is available. + """ + rs_state = self.__rs_state + if rs_state.primary_member: + return rs_state.primary_member.max_message_size + return 0 + + @property + def auto_start_request(self): + """Is auto_start_request enabled? + """ + return self.__auto_start_request + + def __simple_command(self, sock_info, dbname, spec): + """Send a command to the server. + Returns (response, ping_time in seconds). + """ + rqst_id, msg, _ = message.query(0, dbname + '.$cmd', 0, -1, spec) + start = time.time() + try: + sock_info.sock.sendall(msg) + response = self.__recv_msg(1, rqst_id, sock_info) + except: + sock_info.close() + raise + + end = time.time() + response = helpers._unpack_response(response)['data'][0] + msg = "command %r failed: %%s" % spec + helpers._check_command_response(response, None, msg) + return response, end - start + + def __is_master(self, host): + """Directly call ismaster. + Returns (response, connection_pool, ping_time in seconds). + """ + connection_pool = self.pool_class( + host, + self.__max_pool_size, + self.__net_timeout, + self.__conn_timeout, + self.__use_ssl, + wait_queue_timeout=self.__wait_queue_timeout, + wait_queue_multiple=self.__wait_queue_multiple, + use_greenlets=self.__use_greenlets, + ssl_keyfile=self.__ssl_keyfile, + ssl_certfile=self.__ssl_certfile, + ssl_cert_reqs=self.__ssl_cert_reqs, + ssl_ca_certs=self.__ssl_ca_certs) + + if self.in_request(): + connection_pool.start_request() + + sock_info = connection_pool.get_socket() + try: + response, ping_time = self.__simple_command( + sock_info, 'admin', {'ismaster': 1} + ) + + connection_pool.maybe_return_socket(sock_info) + return response, connection_pool, ping_time + except (ConnectionFailure, socket.error): + connection_pool.discard_socket(sock_info) + raise + + def __schedule_refresh(self, sync=False): + """Awake the monitor to update our view of the replica set's state. + + If `sync` is True, block until the refresh completes. + + If multiple application threads call __schedule_refresh while refresh + is in progress, the work of refreshing the state is only performed + once. + """ + self.__monitor.schedule_refresh() + if sync: + self.__monitor.wait_for_refresh(timeout_seconds=5) + + def __make_threadlocal(self): + if self.__use_greenlets: + return gevent_local() + else: + return threading.local() + + def refresh(self): + """Iterate through the existing host list, or possibly the + seed list, to update the list of hosts and arbiters in this + replica set. 
+ """ + # Only one thread / greenlet calls refresh() at a time: the one + # running __init__() or the monitor. We won't modify the state, only + # replace it at the end. + rs_state = self.__rs_state + errors = [] + if rs_state.hosts: + # Try first those hosts we think are up, then the down ones. + nodes = sorted( + rs_state.hosts, key=lambda host: rs_state.get(host).up) + else: + nodes = self.__seeds + + hosts = set() + + # This will become the new RSState. + members = {} + arbiters = set() + writer = None + + # Look for first member from which we can get a list of all members. + for node in nodes: + member, sock_info = rs_state.get(node), None + try: + if member: + sock_info = self.__socket(member, force=True) + response, ping_time = self.__simple_command( + sock_info, 'admin', {'ismaster': 1}) + member.pool.maybe_return_socket(sock_info) + new_member = member.clone_with(response, ping_time) + else: + response, pool, ping_time = self.__is_master(node) + new_member = Member( + node, pool, response, MovingAverage([ping_time]), True) + + # Check that this host is part of the given replica set. + set_name = response.get('setName') + # The 'setName' field isn't returned by mongod before 1.6.2 + # so we can't assume that if it's missing this host isn't in + # the specified set. + if set_name and set_name != self.__name: + host, port = node + raise ConfigurationError("%s:%d is not a member of " + "replica set %s" + % (host, port, self.__name)) + if "arbiters" in response: + arbiters = set([ + _partition_node(h) for h in response["arbiters"]]) + if "hosts" in response: + hosts.update([_partition_node(h) + for h in response["hosts"]]) + if "passives" in response: + hosts.update([_partition_node(h) + for h in response["passives"]]) + + # Start off the new 'members' dict with this member + # but don't add seed list members. + if node in hosts: + members[node] = new_member + if response['ismaster']: + writer = node + + except (ConnectionFailure, socket.error) as why: + if member: + member.pool.discard_socket(sock_info) + errors.append("%s:%d: %s" % (node[0], node[1], str(why))) + if hosts: + break + else: + if errors: + raise AutoReconnect(', '.join(errors)) + raise ConfigurationError('No suitable hosts found') + + # Ensure we have a pool for each member, and find the primary. + for host in hosts: + if host in members: + # This member was the first we connected to, in the loop above. + continue + + member, sock_info = rs_state.get(host), None + try: + if member: + sock_info = self.__socket(member, force=True) + res, ping_time = self.__simple_command( + sock_info, 'admin', {'ismaster': 1}) + member.pool.maybe_return_socket(sock_info) + new_member = member.clone_with(res, ping_time) + else: + res, connection_pool, ping_time = self.__is_master(host) + new_member = Member( + host, connection_pool, res, MovingAverage([ping_time]), + True) + + members[host] = new_member + + except (ConnectionFailure, socket.error): + if member: + member.pool.discard_socket(sock_info) + continue + + if res['ismaster']: + writer = host + + if writer == rs_state.writer: + threadlocal = self.__rs_state.threadlocal + else: + # We unpin threads from members if the primary has changed, since + # no monotonic consistency can be promised now anyway. + threadlocal = self.__make_threadlocal() + + # Replace old state with new. + self.__rs_state = RSState(threadlocal, members, arbiters, writer) + + def __find_primary(self): + """Returns a connection to the primary of this replica set, + if one exists, or raises AutoReconnect. 
+ """ + primary = self.__rs_state.primary_member + if primary: + return primary + + # We had a failover. + self.__schedule_refresh(sync=True) + + # Try again. This time copy the RSState reference so we're guaranteed + # primary_member and error_message are from the same state. + rs_state = self.__rs_state + if rs_state.primary_member: + return rs_state.primary_member + + # Couldn't find the primary. + raise AutoReconnect(rs_state.error_message) + + def __socket(self, member, force=False): + """Get a SocketInfo from the pool. + """ + if self.auto_start_request and not self.in_request(): + self.start_request() + + sock_info = member.pool.get_socket(force=force) + + try: + self.__check_auth(sock_info) + except OperationFailure: + member.pool.maybe_return_socket(sock_info) + raise + return sock_info + + def _ensure_connected(self, sync=False): + """Ensure this client instance is connected to a primary. + """ + # This may be the first time we're connecting to the set. + if self.__monitor and not self.__monitor.started: + try: + self.__monitor.start() + # Minor race condition. It's possible that two (or more) + # threads could call monitor.start() consecutively. Just pass. + except RuntimeError: + pass + if sync: + rs_state = self.__rs_state + if not rs_state.primary_member: + self.__schedule_refresh(sync) + + def disconnect(self): + """Disconnect from the replica set primary, unpin all members, and + refresh our view of the replica set. + """ + rs_state = self.__rs_state + if rs_state.primary_member: + rs_state.primary_member.pool.reset() + + threadlocal = self.__make_threadlocal() + self.__rs_state = rs_state.clone_without_writer(threadlocal) + self.__schedule_refresh() + + def close(self): + """Close this client instance. + + This method first terminates the replica set monitor, then disconnects + from all members of the replica set. + + .. warning:: This method stops the replica set monitor task. The + replica set monitor is required to properly handle replica set + configuration changes, including a failure of the primary. + Once :meth:`~close` is called this client instance must not be reused. + + .. versionchanged:: 2.2.1 + The :meth:`close` method now terminates the replica set monitor. + """ + if self.__monitor: + self.__monitor.shutdown() + # Use a reasonable timeout. + self.__monitor.join(1.0) + self.__monitor = None + + self.__rs_state = RSState(self.__make_threadlocal()) + + def alive(self): + """Return ``False`` if there has been an error communicating with the + primary, else ``True``. + + This method attempts to check the status of the primary with minimal + I/O. The current thread / greenlet retrieves a socket (its request + socket if it's in a request, or a random idle socket if it's not in a + request) from the primary's connection pool and checks whether calling + select_ on it raises an error. If there are currently no idle sockets, + or if there is no known primary, :meth:`alive` will attempt to actually + find and connect to the primary. + + A more certain way to determine primary availability is to ping it:: + + client.admin.command('ping') + + .. _select: http://docs.python.org/2/library/select.html#select.select + """ + # In the common case, a socket is available and was used recently, so + # calling select() on it is a reasonable attempt to see if the OS has + # reported an error. 
Note this can be wasteful: __socket implicitly
+        # calls select() if the socket hasn't been checked in the last second,
+        # or it may create a new socket, in which case calling select() is
+        # redundant.
+        member, sock_info = None, None
+        try:
+            try:
+                member = self.__find_primary()
+                sock_info = self.__socket(member)
+                return not pool._closed(sock_info.sock)
+            except (socket.error, ConnectionFailure):
+                return False
+        finally:
+            if member and sock_info:
+                member.pool.maybe_return_socket(sock_info)
+
+    def __check_response_to_last_error(self, response):
+        """Check a response to a lastError message for errors.
+
+        `response` is a byte string representing a response to the message.
+        If it represents an error response we raise OperationFailure.
+
+        Return the response as a document.
+        """
+        response = helpers._unpack_response(response)
+
+        assert response["number_returned"] == 1
+        error = response["data"][0]
+
+        helpers._check_command_response(error, self.disconnect)
+
+        error_msg = error.get("err", "")
+        if error_msg is None:
+            return error
+        if error_msg.startswith("not master"):
+            self.disconnect()
+            raise AutoReconnect(error_msg)
+
+        if "code" in error:
+            if error["code"] in (11000, 11001, 12582):
+                raise DuplicateKeyError(error["err"], error["code"])
+            else:
+                raise OperationFailure(error["err"], error["code"])
+        else:
+            raise OperationFailure(error["err"])
+
+    def __recv_data(self, length, sock_info):
+        """Lowest level receive operation.
+
+        Takes length to receive and repeatedly calls recv until able to
+        return a buffer of that length, raising ConnectionFailure on error.
+        """
+        message = EMPTY
+        while length:
+            chunk = sock_info.sock.recv(length)
+            if chunk == EMPTY:
+                raise ConnectionFailure("connection closed")
+            length -= len(chunk)
+            message += chunk
+        return message
+
+    def __recv_msg(self, operation, rqst_id, sock):
+        """Receive a message in response to `rqst_id` on `sock`.
+
+        Returns the response data with the header removed.
+        """
+        header = self.__recv_data(16, sock)
+        length = struct.unpack("<i", header[:4])[0]
+
+        # No request id for exhaust cursor "getMore".
+        if rqst_id is not None:
+            resp_id = struct.unpack("<i", header[8:12])[0]
+            assert rqst_id == resp_id, "ids don't match %r %r" % (
+                rqst_id, resp_id)
+
+        assert operation == struct.unpack("<i", header[12:16])[0]
+        return self.__recv_data(length - 16, sock)
+
+    def __check_bson_size(self, msg, max_size):
+        """Make sure the message doesn't include BSON documents larger
+        than the connected server will accept.
+
+        :Parameters:
+          - `msg`: message to check
+        """
+        if len(msg) == 3:
+            request_id, data, max_doc_size = msg
+            if max_doc_size > max_size:
+                raise InvalidDocument("BSON document too large (%d bytes)"
+                                      " - the connected server supports"
+                                      " BSON document sizes up to %d"
+                                      " bytes." %
+                                      (max_doc_size, max_size))
+            return (request_id, data)
+        # get_more and kill_cursors messages
+        # don't include BSON documents.
+        return msg
+
+    def _send_message(self, msg,
+                      with_last_error=False, _connection_to_use=None):
+        """Say something to Mongo.
+
+        Raises ConnectionFailure if the message cannot be sent. Raises
+        OperationFailure if `with_last_error` is ``True`` and the
+        response to the getLastError call returns an error. Return the
+        response from lastError, or ``None`` if `with_last_error`
+        is ``False``.
+
+        :Parameters:
+          - `msg`: message to send
+          - `with_last_error`: check getLastError status after sending the
+            message
+        """
+        self._ensure_connected()
+
+        if _connection_to_use in (None, -1):
+            member = self.__find_primary()
+        else:
+            member = self.__rs_state.get(_connection_to_use)
+
+        sock_info = None
+        try:
+            try:
+                sock_info = self.__socket(member)
+                rqst_id, data = self.__check_bson_size(
+                    msg, member.max_bson_size)
+
+                sock_info.sock.sendall(data)
+                # Safe mode. We pack the message together with a lastError
+                # message and send both. We then get the response (to the
+                # lastError) and raise OperationFailure if it is an error
+                # response.
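+                # (Editor's note, illustrative: for an acknowledged write,
+                # `msg` was built as the write operation followed by a
+                # {'getlasterror': 1} query, so the reply read below answers
+                # that trailing getLastError rather than the write itself.)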
+ rv = None + if with_last_error: + response = self.__recv_msg(1, rqst_id, sock_info) + rv = self.__check_response_to_last_error(response) + return rv + except OperationFailure: + raise + except(ConnectionFailure, socket.error) as why: + member.pool.discard_socket(sock_info) + if _connection_to_use in (None, -1): + self.disconnect() + raise AutoReconnect(str(why)) + except: + sock_info.close() + raise + finally: + member.pool.maybe_return_socket(sock_info) + + def __send_and_receive(self, member, msg, **kwargs): + """Send a message on the given socket and return the response data. + + Can raise socket.error. + """ + sock_info = None + exhaust = kwargs.get('exhaust') + rqst_id, data = self.__check_bson_size(msg, member.max_bson_size) + try: + sock_info = self.__socket(member) + + if not exhaust and "network_timeout" in kwargs: + sock_info.sock.settimeout(kwargs['network_timeout']) + + sock_info.sock.sendall(data) + response = self.__recv_msg(1, rqst_id, sock_info) + + if not exhaust: + if "network_timeout" in kwargs: + sock_info.sock.settimeout(self.__net_timeout) + member.pool.maybe_return_socket(sock_info) + + return response, sock_info, member.pool + except: + if sock_info is not None: + sock_info.close() + member.pool.maybe_return_socket(sock_info) + raise + + def __try_read(self, member, msg, **kwargs): + """Attempt a read from a member; on failure mark the member "down" and + wake up the monitor thread to refresh as soon as possible. + """ + try: + return self.__send_and_receive(member, msg, **kwargs) + except socket.timeout as e: + # Could be one slow query, don't refresh. + host, port = member.host + raise AutoReconnect("%s:%d: %s" % (host, port, e)) + except (socket.error, ConnectionFailure) as why: + # Try to replace our RSState with a clone where this member is + # marked "down", to reduce exceptions on other threads, or repeated + # exceptions on this thread. We accept that there's a race + # condition (another thread could be replacing our state with a + # different version concurrently) but this approach is simple and + # lock-free. + self.__rs_state = self.__rs_state.clone_with_host_down( + member.host, str(why)) + + self.__schedule_refresh() + host, port = member.host + raise AutoReconnect("%s:%d: %s" % (host, port, why)) + + def _send_message_with_response(self, msg, _connection_to_use=None, + _must_use_master=False, **kwargs): + """Send a message to Mongo and return the response. + + Sends the given message and returns (host used, response). + + :Parameters: + - `msg`: (request_id, data) pair making up the message to send + - `_connection_to_use`: Optional (host, port) of member for message, + used by Cursor for getMore and killCursors messages. + - `_must_use_master`: If True, send to primary. + """ + self._ensure_connected() + + rs_state = self.__rs_state + tag_sets = kwargs.get('tag_sets', [{}]) + mode = kwargs.get('read_preference', ReadPreference.PRIMARY) + if _must_use_master: + mode = ReadPreference.PRIMARY + tag_sets = [{}] + + if not rs_state.primary_member: + # Primary was down last we checked. Start a refresh if one is not + # already in progress. If caller requested the primary, wait to + # see if it's up, otherwise continue with known-good members. 
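+            # (Editor's note, added for clarity: `sync` makes reads with
+            # ReadPreference.PRIMARY block on the refresh, while
+            # secondary-capable reads proceed with members already known good.)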
+ sync = (mode == ReadPreference.PRIMARY) + self.__schedule_refresh(sync=sync) + rs_state = self.__rs_state + + latency = kwargs.get( + 'secondary_acceptable_latency_ms', + self.secondary_acceptable_latency_ms) + + try: + if _connection_to_use is not None: + if _connection_to_use == -1: + member = rs_state.primary_member + error_message = rs_state.error_message + else: + member = rs_state.get(_connection_to_use) + error_message = '%s:%s not available' % _connection_to_use + + if not member: + raise AutoReconnect(error_message) + + return member.pool.pair, self.__try_read( + member, msg, **kwargs) + except AutoReconnect: + if _connection_to_use in (-1, rs_state.writer): + # Primary's down. Refresh. + self.disconnect() + raise + + # To provide some monotonic consistency, we use the same member as + # long as this thread is in a request and all reads use the same + # mode, tags, and latency. The member gets unpinned if pref changes, + # if member changes state, if we detect a failover, or if this thread + # calls end_request(). + errors = [] + + pinned_host = rs_state.pinned_host + pinned_member = rs_state.get(pinned_host) + if (pinned_member + and pinned_member.matches_mode(mode) + and pinned_member.matches_tag_sets(tag_sets) # TODO: REMOVE? + and rs_state.keep_pinned_host(mode, tag_sets, latency)): + try: + return ( + pinned_member.host, + self.__try_read(pinned_member, msg, **kwargs)) + except AutoReconnect as why: + if _must_use_master or mode == ReadPreference.PRIMARY: + self.disconnect() + raise + else: + errors.append(str(why)) + + # No pinned member, or pinned member down or doesn't match read pref + rs_state.unpin_host() + + members = list(rs_state.members) + while len(errors) < MAX_RETRY: + member = select_member( + members=members, + mode=mode, + tag_sets=tag_sets, + latency=latency) + + if not member: + # Ran out of members to try + break + + try: + # Sets member.up False on failure, so select_member won't try + # it again. + response = self.__try_read(member, msg, **kwargs) + + # Success + if self.in_request(): + # Keep reading from this member in this thread / greenlet + # unless read preference changes + rs_state.pin_host(member.host, mode, tag_sets, latency) + return member.host, response + except AutoReconnect as why: + errors.append(str(why)) + members.remove(member) + + # Ran out of tries + if mode == ReadPreference.PRIMARY: + msg = "No replica set primary available for query" + elif mode == ReadPreference.SECONDARY: + msg = "No replica set secondary available for query" + else: + msg = "No replica set members available for query" + + msg += " with ReadPreference %s" % modes[mode] + + if tag_sets != [{}]: + msg += " and tags " + repr(tag_sets) + + raise AutoReconnect(msg, errors) + + def _exhaust_next(self, sock_info): + """Used with exhaust cursors to get the next batch off the socket. + """ + return self.__recv_msg(1, None, sock_info) + + def start_request(self): + """Ensure the current thread or greenlet always uses the same socket + until it calls :meth:`end_request`. For + :class:`~pymongo.read_preferences.ReadPreference` PRIMARY, + auto_start_request=True ensures consistent reads, even if you read + after an unacknowledged write. For read preferences other than PRIMARY, + there are no consistency guarantees. 
+ + In Python 2.6 and above, or in Python 2.5 with + "from __future__ import with_statement", :meth:`start_request` can be + used as a context manager: + + >>> client = pymongo.MongoReplicaSetClient() + >>> db = client.test + >>> _id = db.test_collection.insert({}) + >>> with client.start_request(): + ... for i in range(100): + ... db.test_collection.update({'_id': _id}, {'$set': {'i':i}}) + ... + ... # Definitely read the document after the final update completes + ... print(db.test_collection.find({'_id': _id})) + + .. versionadded:: 2.2 + The :class:`~pymongo.pool.Request` return value. + :meth:`start_request` previously returned None + """ + # We increment our request counter's thread- or greenlet-local value + # for every call to start_request; however, we only call each pool's + # start_request once to start a request, and call each pool's + # end_request once to end it. We don't let pools' request counters + # exceed 1. This keeps things sane when we create and delete pools + # within a request. + if 1 == self.__request_counter.inc(): + for member in self.__rs_state.members: + member.pool.start_request() + + return pool.Request(self) + + def in_request(self): + """True if :meth:`start_request` has been called, but not + :meth:`end_request`, or if `auto_start_request` is True and + :meth:`end_request` has not been called in this thread or greenlet. + """ + return bool(self.__request_counter.get()) + + def end_request(self): + """Undo :meth:`start_request` and allow this thread's connections to + replica set members to return to the pool. + + Calling :meth:`end_request` allows the :class:`~socket.socket` that has + been reserved for this thread by :meth:`start_request` to be returned + to the pool. Other threads will then be able to re-use that + :class:`~socket.socket`. If your application uses many threads, or has + long-running threads that infrequently perform MongoDB operations, then + judicious use of this method can lead to performance gains. Care should + be taken, however, to make sure that :meth:`end_request` is not called + in the middle of a sequence of operations in which ordering is + important. This could lead to unexpected results. + """ + rs_state = self.__rs_state + if 0 == self.__request_counter.dec(): + for member in rs_state.members: + # No effect if not in a request + member.pool.end_request() + + rs_state.unpin_host() + + def __eq__(self, other): + # XXX: Implement this? + return NotImplemented + + def __ne__(self, other): + return NotImplemented + + def __repr__(self): + return "MongoReplicaSetClient(%r)" % (["%s:%d" % n + for n in self.hosts],) + + def __getattr__(self, name): + """Get a database by name. + + Raises :class:`~pymongo.errors.InvalidName` if an invalid + database name is used. + + :Parameters: + - `name`: the name of the database to get + """ + return database.Database(self, name) + + def __getitem__(self, name): + """Get a database by name. + + Raises :class:`~pymongo.errors.InvalidName` if an invalid + database name is used. + + :Parameters: + - `name`: the name of the database to get + """ + return self.__getattr__(name) + + def close_cursor(self, cursor_id, _conn_id): + """Close a single database cursor. + + Raises :class:`TypeError` if `cursor_id` is not an instance of + ``(int, long)``. What closing the cursor actually means + depends on this client's cursor manager. 
+ + :Parameters: + - `cursor_id`: id of cursor to close + """ + if not isinstance(cursor_id, int): + raise TypeError("cursor_id must be an instance of (int, long)") + + self._send_message(message.kill_cursors([cursor_id]), + _connection_to_use=_conn_id) + + def server_info(self): + """Get information about the MongoDB primary we're connected to. + """ + return self.admin.command("buildinfo") + + def database_names(self): + """Get a list of the names of all databases on the connected server. + """ + return [db["name"] for db in + self.admin.command("listDatabases")["databases"]] + + def drop_database(self, name_or_database): + """Drop a database. + + Raises :class:`TypeError` if `name_or_database` is not an instance of + :class:`basestring` (:class:`str` in python 3) or Database + + :Parameters: + - `name_or_database`: the name of a database to drop, or a + :class:`~pymongo.database.Database` instance representing the + database to drop + """ + name = name_or_database + if isinstance(name, database.Database): + name = name.name + + if not isinstance(name, str): + raise TypeError("name_or_database must be an instance of " + "%s or Database" % (str.__name__,)) + + self._purge_index(name) + self[name].command("dropDatabase") + + def copy_database(self, from_name, to_name, + from_host=None, username=None, password=None): + """Copy a database, potentially from another host. + + Raises :class:`TypeError` if `from_name` or `to_name` is not + an instance of :class:`basestring` (:class:`str` in python 3). + Raises :class:`~pymongo.errors.InvalidName` if `to_name` is + not a valid database name. + + If `from_host` is ``None`` the current host is used as the + source. Otherwise the database is copied from `from_host`. + + If the source database requires authentication, `username` and + `password` must be specified. + + :Parameters: + - `from_name`: the name of the source database + - `to_name`: the name of the target database + - `from_host` (optional): host name to copy from + - `username` (optional): username for source database + - `password` (optional): password for source database + + .. note:: Specifying `username` and `password` requires server + version **>= 1.3.3+**. + """ + if not isinstance(from_name, str): + raise TypeError("from_name must be an instance " + "of %s" % (str.__name__,)) + if not isinstance(to_name, str): + raise TypeError("to_name must be an instance " + "of %s" % (str.__name__,)) + + database._check_name(to_name) + + command = {"fromdb": from_name, "todb": to_name} + + if from_host is not None: + command["fromhost"] = from_host + + try: + self.start_request() + + if username is not None: + nonce = self.admin.command("copydbgetnonce", + fromhost=from_host)["nonce"] + command["username"] = username + command["nonce"] = nonce + command["key"] = auth._auth_key(nonce, username, password) + + return self.admin.command("copydb", **command) + finally: + self.end_request() + + def get_default_database(self): + """Get the database named in the MongoDB connection URI. + + >>> uri = 'mongodb://host/my_database' + >>> client = MongoReplicaSetClient(uri) + >>> db = client.get_default_database() + >>> assert db.name == 'my_database' + + Useful in scripts where you want to choose which database to use + based only on the URI in a configuration file. 
+ """ + if self.__default_database_name is None: + raise ConfigurationError('No default database defined') + + return self[self.__default_database_name] + diff --git a/asyncio_mongo/_pymongo/pool.py b/asyncio_mongo/_pymongo/pool.py new file mode 100644 index 0000000..bc56312 --- /dev/null +++ b/asyncio_mongo/_pymongo/pool.py @@ -0,0 +1,555 @@ +# Copyright 2011-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +import os +import socket +import sys +import time +import threading +import weakref + +from asyncio_mongo._pymongo import thread_util +from asyncio_mongo._pymongo.common import HAS_SSL +from asyncio_mongo._pymongo.errors import ConnectionFailure, ConfigurationError + +try: + from ssl import match_hostname +except ImportError: + from asyncio_mongo._pymongo.ssl_match_hostname import match_hostname + +if HAS_SSL: + import ssl + +if sys.platform.startswith('java'): + from select import cpython_compatible_select as select +else: + from select import select + + +NO_REQUEST = None +NO_SOCKET_YET = -1 + + +def _closed(sock): + """Return True if we know socket has been closed, False otherwise. + """ + try: + rd, _, _ = select([sock], [], [], 0) + # Any exception here is equally bad (select.error, ValueError, etc.). + except: + return True + return len(rd) > 0 + + +class SocketInfo(object): + """Store a socket with some metadata + """ + def __init__(self, sock, pool_id, host=None): + self.sock = sock + self.host = host + self.authset = set() + self.closed = False + self.last_checkout = time.time() + self.forced = False + + # The pool's pool_id changes with each reset() so we can close sockets + # created before the last reset. + self.pool_id = pool_id + + def close(self): + self.closed = True + # Avoid exceptions on interpreter shutdown. + try: + self.sock.close() + except: + pass + + def __eq__(self, other): + # Need to check if other is NO_REQUEST or NO_SOCKET_YET, and then check + # if its sock is the same as ours + return hasattr(other, 'sock') and self.sock == other.sock + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash(self.sock) + + def __repr__(self): + return "SocketInfo(%s)%s at %s" % ( + repr(self.sock), + self.closed and " CLOSED" or "", + id(self) + ) + + +# Do *not* explicitly inherit from object or Jython won't call __del__ +# http://bugs.jython.org/issue1057 +class Pool: + def __init__(self, pair, max_size, net_timeout, conn_timeout, use_ssl, + use_greenlets, ssl_keyfile=None, ssl_certfile=None, + ssl_cert_reqs=None, ssl_ca_certs=None, + wait_queue_timeout=None, wait_queue_multiple=None): + """ + :Parameters: + - `pair`: a (hostname, port) tuple + - `max_size`: The maximum number of open sockets. Calls to + `get_socket` will block if this is set, this pool has opened + `max_size` sockets, and there are none idle. Set to `None` to + disable. 
+          - `net_timeout`: timeout in seconds for operations on open connection
+          - `conn_timeout`: timeout in seconds for establishing connection
+          - `use_ssl`: bool, if True use an encrypted connection
+          - `use_greenlets`: bool, if True then start_request() assigns a
+            socket to the current greenlet - otherwise it is assigned to the
+            current thread
+          - `ssl_keyfile`: The private keyfile used to identify the local
+            connection against mongod. If included with the ``certfile`` then
+            only the ``ssl_certfile`` is needed. Implies ``ssl=True``.
+          - `ssl_certfile`: The certificate file used to identify the local
+            connection against mongod. Implies ``ssl=True``.
+          - `ssl_cert_reqs`: Specifies whether a certificate is required from
+            the other side of the connection, and whether it will be validated
+            if provided. It must be one of the three values ``ssl.CERT_NONE``
+            (certificates ignored), ``ssl.CERT_OPTIONAL``
+            (not required, but validated if provided), or ``ssl.CERT_REQUIRED``
+            (required and validated). If the value of this parameter is not
+            ``ssl.CERT_NONE``, then the ``ssl_ca_certs`` parameter must point
+            to a file of CA certificates. Implies ``ssl=True``.
+          - `ssl_ca_certs`: The ca_certs file contains a set of concatenated
+            "certification authority" certificates, which are used to validate
+            certificates passed from the other end of the connection.
+            Implies ``ssl=True``.
+          - `wait_queue_timeout`: (integer) How long (in seconds) a
+            thread will wait for a socket from the pool if the pool has no
+            free sockets.
+          - `wait_queue_multiple`: (integer) Multiplied by max_pool_size to give
+            the number of threads allowed to wait for a socket at one time.
+        """
+        # Only check a socket's health with _closed() every once in a while.
+        # Can override for testing: 0 to always check, None to never check.
+        self._check_interval_seconds = 1
+
+        self.sockets = set()
+        self.lock = threading.Lock()
+
+        # Keep track of resets, so we notice sockets created before the most
+        # recent reset and close them.
+        self.pool_id = 0
+        self.pid = os.getpid()
+        self.pair = pair
+        self.max_size = max_size
+        self.net_timeout = net_timeout
+        self.conn_timeout = conn_timeout
+        self.wait_queue_timeout = wait_queue_timeout
+        self.wait_queue_multiple = wait_queue_multiple
+        self.use_ssl = use_ssl
+        self.ssl_keyfile = ssl_keyfile
+        self.ssl_certfile = ssl_certfile
+        self.ssl_cert_reqs = ssl_cert_reqs
+        self.ssl_ca_certs = ssl_ca_certs
+
+        if HAS_SSL and use_ssl and not ssl_cert_reqs:
+            self.ssl_cert_reqs = ssl.CERT_NONE
+
+        # Map self._ident.get() -> request socket
+        self._tid_to_sock = {}
+
+        if use_greenlets and not thread_util.have_gevent:
+            raise ConfigurationError(
+                "The Gevent module is not available. "
+                "Install the gevent package from PyPI."
+            )
+
+        self._ident = thread_util.create_ident(use_greenlets)
+
+        # Count the number of calls to start_request() per thread or greenlet
+        self._request_counter = thread_util.Counter(use_greenlets)
+
+        if self.wait_queue_multiple is None or self.max_size is None:
+            max_waiters = None
+        else:
+            max_waiters = self.max_size * self.wait_queue_multiple
+
+        self._socket_semaphore = thread_util.create_semaphore(
+            self.max_size, max_waiters, use_greenlets)
+
+    def reset(self):
+        # Ignore this race condition -- if many threads are resetting at once,
+        # the pool_id will definitely change, which is all we care about.
+        self.pool_id += 1
+        self.pid = os.getpid()
+
+        sockets = None
+        try:
+            # Swapping variables is not atomic.
We need to ensure no other + # thread is modifying self.sockets, or replacing it, in this + # critical section. + self.lock.acquire() + sockets, self.sockets = self.sockets, set() + finally: + self.lock.release() + + for sock_info in sockets: + sock_info.close() + + def create_connection(self, pair): + """Connect to *pair* and return the socket object. + + This is a modified version of create_connection from + CPython >=2.6. + """ + host, port = pair or self.pair + + # Check if dealing with a unix domain socket + if host.endswith('.sock'): + if not hasattr(socket, "AF_UNIX"): + raise ConnectionFailure("UNIX-sockets are not supported " + "on this system") + sock = socket.socket(socket.AF_UNIX) + try: + sock.connect(host) + return sock + except socket.error as e: + if sock is not None: + sock.close() + raise e + + # Don't try IPv6 if we don't support it. Also skip it if host + # is 'localhost' (::1 is fine). Avoids slow connect issues + # like PYTHON-356. + family = socket.AF_INET + if socket.has_ipv6 and host != 'localhost': + family = socket.AF_UNSPEC + + err = None + for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM): + af, socktype, proto, dummy, sa = res + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + sock.settimeout(self.conn_timeout or 20.0) + sock.connect(sa) + return sock + except socket.error as e: + err = e + if sock is not None: + sock.close() + + if err is not None: + raise err + else: + # This likely means we tried to connect to an IPv6 only + # host with an OS/kernel or Python interpreter that doesn't + # support IPv6. The test case is Jython2.5.1 which doesn't + # support IPv6 at all. + raise socket.error('getaddrinfo failed') + + def connect(self, pair): + """Connect to Mongo and return a new (connected) socket. Note that the + pool does not keep a reference to the socket -- you must call + return_socket() when you're done with it. + """ + sock = self.create_connection(pair) + hostname = (pair or self.pair)[0] + + if self.use_ssl: + try: + sock = ssl.wrap_socket(sock, + certfile=self.ssl_certfile, + keyfile=self.ssl_keyfile, + ca_certs=self.ssl_ca_certs, + cert_reqs=self.ssl_cert_reqs) + if self.ssl_cert_reqs: + match_hostname(sock.getpeercert(), hostname) + + except ssl.SSLError: + sock.close() + raise ConnectionFailure("SSL handshake failed. MongoDB may " + "not be configured with SSL support.") + + sock.settimeout(self.net_timeout) + return SocketInfo(sock, self.pool_id, hostname) + + def get_socket(self, pair=None, force=False): + """Get a socket from the pool. + + Returns a :class:`SocketInfo` object wrapping a connected + :class:`socket.socket`, and a bool saying whether the socket was from + the pool or freshly created. + + :Parameters: + - `pair`: optional (hostname, port) tuple + - `force`: optional boolean, forces a connection to be returned + without blocking, even if `max_size` has been reached. + """ + # We use the pid here to avoid issues with fork / multiprocessing. + # See test.test_client:TestClient.test_fork for an example of + # what could go wrong otherwise + if self.pid != os.getpid(): + self.reset() + + # Have we opened a socket for this request? 
+ req_state = self._get_request_state() + if req_state not in (NO_SOCKET_YET, NO_REQUEST): + # There's a socket for this request, check it and return it + checked_sock = self._check(req_state, pair) + if checked_sock != req_state: + self._set_request_state(checked_sock) + + checked_sock.last_checkout = time.time() + return checked_sock + + forced = False + # We're not in a request, just get any free socket or create one + if force: + # If we're doing an internal operation, attempt to play nicely with + # max_size, but if there is no open "slot" force the connection + # and mark it as forced so we don't release the semaphore without + # having acquired it for this socket. + if not self._socket_semaphore.acquire(False): + forced = True + elif not self._socket_semaphore.acquire(True, self.wait_queue_timeout): + self._raise_wait_queue_timeout() + + # We've now acquired the semaphore and must release it on error. + try: + sock_info, from_pool = None, None + try: + try: + # set.pop() isn't atomic in Jython less than 2.7, see + # http://bugs.jython.org/issue1854 + self.lock.acquire() + sock_info, from_pool = self.sockets.pop(), True + finally: + self.lock.release() + except KeyError: + sock_info, from_pool = self.connect(pair), False + + if from_pool: + sock_info = self._check(sock_info, pair) + + sock_info.forced = forced + + if req_state == NO_SOCKET_YET: + # start_request has been called but we haven't assigned a + # socket to the request yet. Let's use this socket for this + # request until end_request. + self._set_request_state(sock_info) + except: + if not forced: + self._socket_semaphore.release() + raise + + sock_info.last_checkout = time.time() + return sock_info + + def start_request(self): + if self._get_request_state() == NO_REQUEST: + # Add a placeholder value so we know we're in a request, but we + # have no socket assigned to the request yet. + self._set_request_state(NO_SOCKET_YET) + + self._request_counter.inc() + + def in_request(self): + return bool(self._request_counter.get()) + + def end_request(self): + # Check if start_request has ever been called in this thread / greenlet + count = self._request_counter.get() + if count: + self._request_counter.dec() + if count == 1: + # End request + sock_info = self._get_request_state() + self._set_request_state(NO_REQUEST) + if sock_info not in (NO_REQUEST, NO_SOCKET_YET): + self._return_socket(sock_info) + + def discard_socket(self, sock_info): + """Close and discard the active socket. + """ + if sock_info not in (NO_REQUEST, NO_SOCKET_YET): + sock_info.close() + + if sock_info == self._get_request_state(): + # Discarding request socket; prepare to use a new request + # socket on next get_socket(). + self._set_request_state(NO_SOCKET_YET) + + def maybe_return_socket(self, sock_info): + """Return the socket to the pool unless it's the request socket. + """ + # These sentinel values should only be used internally. + assert sock_info not in (NO_REQUEST, NO_SOCKET_YET) + + if self.pid != os.getpid(): + if not sock_info.forced: + self._socket_semaphore.release() + self.reset() + else: + if sock_info.closed: + if sock_info.forced: + sock_info.forced = False + elif sock_info != self._get_request_state(): + self._socket_semaphore.release() + return + + if sock_info != self._get_request_state(): + self._return_socket(sock_info) + + def _return_socket(self, sock_info): + """Return socket to the pool. If pool is full the socket is discarded. 
+ """ + try: + self.lock.acquire() + too_many_sockets = (self.max_size is not None + and len(self.sockets) >= self.max_size) + + if not too_many_sockets and sock_info.pool_id == self.pool_id: + self.sockets.add(sock_info) + else: + sock_info.close() + finally: + self.lock.release() + + if sock_info.forced: + sock_info.forced = False + else: + self._socket_semaphore.release() + + def _check(self, sock_info, pair): + """This side-effecty function checks if this pool has been reset since + the last time this socket was used, or if the socket has been closed by + some external network error, and if so, attempts to create a new socket. + If this connection attempt fails we reset the pool and reraise the + error. + + Checking sockets lets us avoid seeing *some* + :class:`~pymongo.errors.AutoReconnect` exceptions on server + hiccups, etc. We only do this if it's been > 1 second since + the last socket checkout, to keep performance reasonable - we + can't avoid AutoReconnects completely anyway. + """ + error = False + + # How long since socket was last checked out. + age = time.time() - sock_info.last_checkout + + if sock_info.closed: + error = True + + elif self.pool_id != sock_info.pool_id: + sock_info.close() + error = True + + elif (self._check_interval_seconds is not None + and ( + 0 == self._check_interval_seconds + or age > self._check_interval_seconds)): + if _closed(sock_info.sock): + sock_info.close() + error = True + + if not error: + return sock_info + else: + try: + return self.connect(pair) + except socket.error: + self.reset() + raise + + def _set_request_state(self, sock_info): + ident = self._ident + tid = ident.get() + + if sock_info == NO_REQUEST: + # Ending a request + ident.unwatch(tid) + self._tid_to_sock.pop(tid, None) + else: + self._tid_to_sock[tid] = sock_info + + if not ident.watching(): + # Closure over tid, poolref, and ident. Don't refer directly to + # self, otherwise there's a cycle. + + # Do not access threadlocals in this function, or any + # function it calls! In the case of the Pool subclass and + # mod_wsgi 2.x, on_thread_died() is triggered when mod_wsgi + # calls PyThreadState_Clear(), which deferences the + # ThreadVigil and triggers the weakref callback. Accessing + # thread locals in this function, while PyThreadState_Clear() + # is in progress can cause leaks, see PYTHON-353. + poolref = weakref.ref(self) + + def on_thread_died(ref): + try: + ident.unwatch(tid) + pool = poolref() + if pool: + # End the request + request_sock = pool._tid_to_sock.pop(tid, None) + + # Was thread ever assigned a socket before it died? + if request_sock not in (NO_REQUEST, NO_SOCKET_YET): + pool._return_socket(request_sock) + except: + # Random exceptions on interpreter shutdown. + pass + + ident.watch(on_thread_died) + + def _get_request_state(self): + tid = self._ident.get() + return self._tid_to_sock.get(tid, NO_REQUEST) + + def _raise_wait_queue_timeout(self): + raise ConnectionFailure( + 'Timed out waiting for socket from pool with max_size %r and' + ' wait_queue_timeout %r' % ( + self.max_size, self.wait_queue_timeout)) + + def __del__(self): + # Avoid ResourceWarnings in Python 3 + for sock_info in self.sockets: + sock_info.close() + + for request_sock in list(self._tid_to_sock.values()): + if request_sock not in (NO_REQUEST, NO_SOCKET_YET): + request_sock.close() + + +class Request(object): + """ + A context manager returned by :meth:`start_request`, so you can do + `with client.start_request(): do_something()` in Python 2.5+. 
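+
+    A minimal usage sketch (here ``client`` stands for any object whose
+    ``start_request()`` returns a ``Request`` and which also exposes
+    ``end_request()``; the name is assumed for illustration):
+
+        request = client.start_request()
+        try:
+            do_something()   # runs with a socket pinned to this thread
+        finally:
+            request.end()    # delegates to client.end_request()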
+ """ + def __init__(self, connection): + self.connection = connection + + def end(self): + self.connection.end_request() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end() + # Returning False means, "Don't suppress exceptions if any were + # thrown within the block" + return False diff --git a/asyncio_mongo/_pymongo/read_preferences.py b/asyncio_mongo/_pymongo/read_preferences.py new file mode 100644 index 0000000..697f6d6 --- /dev/null +++ b/asyncio_mongo/_pymongo/read_preferences.py @@ -0,0 +1,211 @@ +# Copyright 2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License", +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for choosing which member of a replica set to read from.""" + +import random + +from asyncio_mongo._pymongo.errors import ConfigurationError + + +class ReadPreference: + """An enum that defines the read preference modes supported by PyMongo. + Used in three cases: + + :class:`~pymongo.mongo_client.MongoClient` connected to a single host: + + * `PRIMARY`: Queries are allowed if the host is standalone or the replica + set primary. + * All other modes allow queries to standalone servers, to the primary, or + to secondaries. + + :class:`~pymongo.mongo_client.MongoClient` connected to a mongos, with a + sharded cluster of replica sets: + + * `PRIMARY`: Queries are sent to the primary of a shard. + * `PRIMARY_PREFERRED`: Queries are sent to the primary if available, + otherwise a secondary. + * `SECONDARY`: Queries are distributed among shard secondaries. An error + is raised if no secondaries are available. + * `SECONDARY_PREFERRED`: Queries are distributed among shard secondaries, + or the primary if no secondary is available. + * `NEAREST`: Queries are distributed among all members of a shard. + + :class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient`: + + * `PRIMARY`: Queries are sent to the primary of the replica set. + * `PRIMARY_PREFERRED`: Queries are sent to the primary if available, + otherwise a secondary. + * `SECONDARY`: Queries are distributed among secondaries. An error + is raised if no secondaries are available. + * `SECONDARY_PREFERRED`: Queries are distributed among secondaries, + or the primary if no secondary is available. + * `NEAREST`: Queries are distributed among all members. 
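+
+    For illustration only (``client`` is assumed to be an existing
+    MongoClient or MongoReplicaSetClient instance):
+
+        client.read_preference = ReadPreference.SECONDARY_PREFERRED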
+ """ + + PRIMARY = 0 + PRIMARY_PREFERRED = 1 + SECONDARY = 2 + SECONDARY_ONLY = 2 + SECONDARY_PREFERRED = 3 + NEAREST = 4 + +# For formatting error messages +modes = { + ReadPreference.PRIMARY: 'PRIMARY', + ReadPreference.PRIMARY_PREFERRED: 'PRIMARY_PREFERRED', + ReadPreference.SECONDARY: 'SECONDARY', + ReadPreference.SECONDARY_PREFERRED: 'SECONDARY_PREFERRED', + ReadPreference.NEAREST: 'NEAREST', +} + +_mongos_modes = [ + 'primary', + 'primaryPreferred', + 'secondary', + 'secondaryPreferred', + 'nearest', +] + +def mongos_mode(mode): + return _mongos_modes[mode] + +def mongos_enum(enum): + return _mongos_modes.index(enum) + +def select_primary(members): + for member in members: + if member.is_primary: + if member.up: + return member + else: + return None + + return None + + +def select_member_with_tags(members, tags, secondary_only, latency): + candidates = [] + + for candidate in members: + if not candidate.up: + continue + + if secondary_only and candidate.is_primary: + continue + + if not (candidate.is_primary or candidate.is_secondary): + # In RECOVERING or similar state + continue + + if candidate.matches_tags(tags): + candidates.append(candidate) + + if not candidates: + return None + + # ping_time is in seconds + fastest = min([candidate.get_avg_ping_time() for candidate in candidates]) + near_candidates = [ + candidate for candidate in candidates + if candidate.get_avg_ping_time() - fastest < latency / 1000.] + + return random.choice(near_candidates) + + +def select_member( + members, + mode=ReadPreference.PRIMARY, + tag_sets=None, + latency=15 +): + """Return a Member or None. + """ + if tag_sets is None: + tag_sets = [{}] + + # For brevity + PRIMARY = ReadPreference.PRIMARY + PRIMARY_PREFERRED = ReadPreference.PRIMARY_PREFERRED + SECONDARY = ReadPreference.SECONDARY + SECONDARY_PREFERRED = ReadPreference.SECONDARY_PREFERRED + NEAREST = ReadPreference.NEAREST + + if mode == PRIMARY: + if tag_sets != [{}]: + raise ConfigurationError("PRIMARY cannot be combined with tags") + return select_primary(members) + + elif mode == PRIMARY_PREFERRED: + # Recurse. + candidate_primary = select_member(members, PRIMARY, [{}], latency) + if candidate_primary: + return candidate_primary + else: + return select_member(members, SECONDARY, tag_sets, latency) + + elif mode == SECONDARY: + for tags in tag_sets: + candidate = select_member_with_tags(members, tags, True, latency) + if candidate: + return candidate + + return None + + elif mode == SECONDARY_PREFERRED: + # Recurse. + candidate_secondary = select_member( + members, SECONDARY, tag_sets, latency) + if candidate_secondary: + return candidate_secondary + else: + return select_member(members, PRIMARY, [{}], latency) + + elif mode == NEAREST: + for tags in tag_sets: + candidate = select_member_with_tags(members, tags, False, latency) + if candidate: + return candidate + + # Ran out of tags. + return None + + else: + raise ConfigurationError("Invalid mode %s" % repr(mode)) + + +"""Commands that may be sent to replica-set secondaries, depending on + ReadPreference and tags. All other commands are always run on the primary. +""" +secondary_ok_commands = frozenset([ + "group", "aggregate", "collstats", "dbstats", "count", "distinct", + "geonear", "geosearch", "geowalk", "mapreduce", "getnonce", "authenticate", + "text", +]) + + +class MovingAverage(object): + def __init__(self, samples): + """Immutable structure to track a 5-sample moving average. 
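+
+        A small illustration (sample values invented):
+
+            MovingAverage([1.0, 2.0]).get()                  # -> 1.5
+            MovingAverage([1.0, 2.0]).clone_with(6.0).get()  # -> 3.0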
+ """ + self.samples = samples[-5:] + assert self.samples + self.average = sum(self.samples) / float(len(self.samples)) + + def clone_with(self, sample): + """Get a copy of this instance plus a new sample""" + return MovingAverage(self.samples + [sample]) + + def get(self): + return self.average diff --git a/asyncio_mongo/_pymongo/replica_set_connection.py b/asyncio_mongo/_pymongo/replica_set_connection.py new file mode 100644 index 0000000..261eb74 --- /dev/null +++ b/asyncio_mongo/_pymongo/replica_set_connection.py @@ -0,0 +1,222 @@ +# Copyright 2011-2012 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Tools for connecting to a MongoDB replica set. + +.. warning:: + **DEPRECATED:** Please use :mod:`~pymongo.mongo_replica_set_client` instead. + +.. seealso:: :doc:`/examples/high_availability` for more examples of + how to connect to a replica set. + +To get a :class:`~pymongo.database.Database` instance from a +:class:`ReplicaSetConnection` use either dictionary-style or +attribute-style access: + +.. doctest:: + + >>> from asyncio_mongo._pymongo import ReplicaSetConnection + >>> c = ReplicaSetConnection('localhost:27017', replicaSet='repl0') + >>> c.test_database + Database(ReplicaSetConnection([u'...', u'...']), u'test_database') + >>> c['test_database'] + Database(ReplicaSetConnection([u'...', u'...']), u'test_database') +""" +from asyncio_mongo._pymongo.mongo_replica_set_client import MongoReplicaSetClient +from asyncio_mongo._pymongo.errors import ConfigurationError + + +class ReplicaSetConnection(MongoReplicaSetClient): + """Connection to a MongoDB replica set. + """ + + def __init__(self, hosts_or_uri=None, max_pool_size=None, + document_class=dict, tz_aware=False, **kwargs): + """Create a new connection to a MongoDB replica set. + + .. warning:: + **DEPRECATED:** :class:`ReplicaSetConnection` is deprecated. Please + use :class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient` + instead + + The resultant connection object has connection-pooling built + in. It also performs auto-reconnection when necessary. If an + operation fails because of a connection error, + :class:`~pymongo.errors.ConnectionFailure` is raised. If + auto-reconnection will be performed, + :class:`~pymongo.errors.AutoReconnect` will be + raised. Application code should handle this exception + (recognizing that the operation failed) and then continue to + execute. + + Raises :class:`~pymongo.errors.ConnectionFailure` if + the connection cannot be made. + + The `hosts_or_uri` parameter can be a full `mongodb URI + `_, in addition to + a string of `host:port` pairs (e.g. 'host1:port1,host2:port2'). + If `hosts_or_uri` is None 'localhost:27017' will be used. + + .. note:: Instances of :class:`~ReplicaSetConnection` start a + background task to monitor the state of the replica set. This allows + it to quickly respond to changes in replica set configuration. 
+ Before discarding an instance of :class:`~ReplicaSetConnection` make + sure you call :meth:`~close` to ensure that the monitor task is + cleanly shut down. + + :Parameters: + - `hosts_or_uri` (optional): A MongoDB URI or string of `host:port` + pairs. If a host is an IPv6 literal it must be enclosed in '[' and + ']' characters following the RFC2732 URL syntax (e.g. '[::1]' for + localhost) + - `max_pool_size` (optional): The maximum number of connections + each pool will open simultaneously. If this is set, operations + will block if there are `max_pool_size` outstanding connections + from the pool. By default the pool size is unlimited. + - `document_class` (optional): default class to use for + documents returned from queries on this connection + - `tz_aware` (optional): if ``True``, + :class:`~datetime.datetime` instances returned as values + in a document by this :class:`ReplicaSetConnection` will be timezone + aware (otherwise they will be naive) + - `replicaSet`: (required) The name of the replica set to connect to. + The driver will verify that each host it connects to is a member of + this replica set. Can be passed as a keyword argument or as a + MongoDB URI option. + + | **Other optional parameters can be passed as keyword arguments:** + + - `host`: For compatibility with connection.Connection. If both + `host` and `hosts_or_uri` are specified `host` takes precedence. + - `port`: For compatibility with connection.Connection. The default + port number to use for hosts. + - `network_timeout`: For compatibility with connection.Connection. + The timeout (in seconds) to use for socket operations - default + is no timeout. If both `network_timeout` and `socketTimeoutMS` are + specified `network_timeout` takes precedence, matching + connection.Connection. + - `socketTimeoutMS`: (integer) How long (in milliseconds) a send or + receive on a socket can take before timing out. + - `connectTimeoutMS`: (integer) How long (in milliseconds) a + connection can take to be opened before timing out. + - `waitQueueTimeoutMS`: (integer) How long (in milliseconds) a + thread will wait for a socket from the pool if the pool has no + free sockets. Defaults to ``None`` (no timeout). + - `waitQueueMultiple`: (integer) Multiplied by max_pool_size to give + the number of threads allowed to wait for a socket at one time. + Defaults to ``None`` (no waiters). + - `auto_start_request`: If ``True`` (the default), each thread that + accesses this :class:`ReplicaSetConnection` has a socket allocated + to it for the thread's lifetime, for each member of the set. For + :class:`~pymongo.read_preferences.ReadPreference` PRIMARY, + auto_start_request=True ensures consistent reads, even if you read + after an unsafe write. For read preferences other than PRIMARY, + there are no consistency guarantees. + - `use_greenlets`: if ``True``, use a background Greenlet instead of + a background thread to monitor state of replica set. Additionally, + :meth:`start_request()` will ensure that the current greenlet uses + the same socket for all operations until :meth:`end_request()`. + `use_greenlets` with ReplicaSetConnection requires `Gevent + `_ to be installed. + + | **Write Concern options:** + + - `safe`: :class:`ReplicaSetConnection` **disables** acknowledgement + of write operations. Use ``safe=True`` to enable write + acknowledgement. + - `w`: (integer or string) Write operations will block until they have + been replicated to the specified number or tagged set of servers. 
+           `w=<int>` always includes the replica set primary (e.g. w=3 means
+           write to the primary and wait until replicated to **two**
+           secondaries). Implies safe=True.
+          - `wtimeout`: (integer) Used in conjunction with `w`. Specify a value
+           in milliseconds to control how long to wait for write propagation
+           to complete. If replication does not complete in the given
+           timeframe, a timeout exception is raised. Implies safe=True.
+          - `j`: If ``True`` block until write operations have been committed
+           to the journal. Ignored if the server is running without journaling.
+           Implies safe=True.
+          - `fsync`: If ``True`` force the database to fsync all files before
+           returning. When used with `j` the server awaits the next group
+           commit before returning. Implies safe=True.
+
+          | **Read preference options:**
+
+          - `slave_okay` or `slaveOk` (deprecated): Use `read_preference`
+            instead.
+          - `read_preference`: The read preference for this connection.
+            See :class:`~pymongo.read_preferences.ReadPreference` for
+            available options.
+          - `tag_sets`: Read from replica-set members with these tags.
+            To specify a priority-order for tag sets, provide a list of
+            tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag
+            set, ``{}``, means "read from any member that matches the mode,
+            ignoring tags." :class:`MongoReplicaSetClient` tries each set of
+            tags in turn until it finds a set of tags with at least one matching
+            member.
+          - `secondary_acceptable_latency_ms`: (integer) Any replica-set member
+            whose ping time is within secondary_acceptable_latency_ms of the
+            nearest member may accept reads. Default 15 milliseconds.
+            **Ignored by mongos** and must be configured on the command line.
+            See the localThreshold_ option for more information.
+
+          | **SSL configuration:**
+
+          - `ssl`: If ``True``, create the connection to the servers using SSL.
+          - `ssl_keyfile`: The private keyfile used to identify the local
+            connection against mongod. If included with the ``certfile`` then
+            only the ``ssl_certfile`` is needed. Implies ``ssl=True``.
+          - `ssl_certfile`: The certificate file used to identify the local
+            connection against mongod. Implies ``ssl=True``.
+          - `ssl_cert_reqs`: Specifies whether a certificate is required from
+            the other side of the connection, and whether it will be validated
+            if provided. It must be one of the three values ``ssl.CERT_NONE``
+            (certificates ignored), ``ssl.CERT_OPTIONAL``
+            (not required, but validated if provided), or ``ssl.CERT_REQUIRED``
+            (required and validated). If the value of this parameter is not
+            ``ssl.CERT_NONE``, then the ``ssl_ca_certs`` parameter must point
+            to a file of CA certificates. Implies ``ssl=True``.
+          - `ssl_ca_certs`: The ca_certs file contains a set of concatenated
+            "certification authority" certificates, which are used to validate
+            certificates passed from the other end of the connection.
+            Implies ``ssl=True``.
+
+        .. versionchanged:: 2.5
+           Added additional ssl options
+        .. versionchanged:: 2.3
+           Added `tag_sets` and `secondary_acceptable_latency_ms` options.
+        .. versionchanged:: 2.2
+           Added `auto_start_request` and `use_greenlets` options.
+           Added support for `host`, `port`, and `network_timeout` keyword
+           arguments for compatibility with connection.Connection.
+        .. versionadded:: 2.1
+
+        .. _localThreshold: http://docs.mongodb.org/manual/reference/mongos/#cmdoption-mongos--localThreshold
+        """
+        network_timeout = kwargs.pop('network_timeout', None)
+        if network_timeout is not None:
+            if (not isinstance(network_timeout, (int, float)) or
+                    network_timeout <= 0):
+                raise ConfigurationError("network_timeout must "
+                                         "be a positive integer")
+            kwargs['socketTimeoutMS'] = network_timeout * 1000
+
+        kwargs['auto_start_request'] = kwargs.get('auto_start_request', True)
+        kwargs['safe'] = kwargs.get('safe', False)
+
+        super(ReplicaSetConnection, self).__init__(
+            hosts_or_uri, max_pool_size, document_class, tz_aware, **kwargs)
+
+    def __repr__(self):
+        return "ReplicaSetConnection(%r)" % (["%s:%d" % n
+                                              for n in self.hosts],)
diff --git a/asyncio_mongo/_pymongo/son_manipulator.py b/asyncio_mongo/_pymongo/son_manipulator.py
new file mode 100644
index 0000000..2430253
--- /dev/null
+++ b/asyncio_mongo/_pymongo/son_manipulator.py
@@ -0,0 +1,177 @@
+# Copyright 2009-2012 10gen, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Manipulators that can edit SON objects as they enter and exit a database.
+
+New manipulators should be defined as subclasses of SONManipulator and can be
+installed on a database by calling
+`pymongo.database.Database.add_son_manipulator`."""
+
+from asyncio_mongo._bson.dbref import DBRef
+from asyncio_mongo._bson.objectid import ObjectId
+from asyncio_mongo._bson.son import SON
+
+
+class SONManipulator(object):
+    """A base son manipulator.
+
+    This manipulator just saves and restores objects without changing them.
+    """
+
+    def will_copy(self):
+        """Will this SON manipulator make a copy of the incoming document?
+
+        Derived classes that do need to make a copy should override this
+        method, returning True instead of False. All non-copying manipulators
+        will be applied first (so that the user's document will be updated
+        appropriately), followed by copying manipulators.
+        """
+        return False
+
+    def transform_incoming(self, son, collection):
+        """Manipulate an incoming SON object.
+
+        :Parameters:
+          - `son`: the SON object to be inserted into the database
+          - `collection`: the collection the object is being inserted into
+        """
+        if self.will_copy():
+            return SON(son)
+        return son
+
+    def transform_outgoing(self, son, collection):
+        """Manipulate an outgoing SON object.
+
+        :Parameters:
+          - `son`: the SON object being retrieved from the database
+          - `collection`: the collection this object was stored in
+        """
+        if self.will_copy():
+            return SON(son)
+        return son
+
+
+class ObjectIdInjector(SONManipulator):
+    """A son manipulator that adds the _id field if it is missing.
+    """
+
+    def transform_incoming(self, son, collection):
+        """Add an _id field if it is missing.
+        """
+        if "_id" not in son:
+            son["_id"] = ObjectId()
+        return son
+
+
+# This is now handled during BSON encoding (for performance reasons),
+# but I'm keeping this here as a reference for those implementing new
+# SONManipulators.
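+#
+# A sketch of what such a subclass can look like (illustrative only; the
+# "_created" key is made up, and `import datetime` is assumed at module
+# scope):
+#
+#     class CreatedInjector(SONManipulator):
+#         """Stamp documents with a _created datetime on their way in."""
+#
+#         def transform_incoming(self, son, collection):
+#             son.setdefault("_created", datetime.datetime.utcnow())
+#             return son
+#
+# It would be installed with db.add_son_manipulator(CreatedInjector()).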
+class ObjectIdShuffler(SONManipulator): + """A son manipulator that moves _id to the first position. + """ + + def will_copy(self): + """We need to copy to be sure that we are dealing with SON, not a dict. + """ + return True + + def transform_incoming(self, son, collection): + """Move _id to the front if it's there. + """ + if not "_id" in son: + return son + transformed = SON({"_id": son["_id"]}) + transformed.update(son) + return transformed + + +class NamespaceInjector(SONManipulator): + """A son manipulator that adds the _ns field. + """ + + def transform_incoming(self, son, collection): + """Add the _ns field to the incoming object + """ + son["_ns"] = collection.name + return son + + +class AutoReference(SONManipulator): + """Transparently reference and de-reference already saved embedded objects. + + This manipulator should probably only be used when the NamespaceInjector is + also being used, otherwise it doesn't make too much sense - documents can + only be auto-referenced if they have an *_ns* field. + + NOTE: this will behave poorly if you have a circular reference. + + TODO: this only works for documents that are in the same database. To fix + this we'll need to add a DatabaseInjector that adds *_db* and then make + use of the optional *database* support for DBRefs. + """ + + def __init__(self, db): + self.database = db + + def will_copy(self): + """We need to copy so the user's document doesn't get transformed refs. + """ + return True + + def transform_incoming(self, son, collection): + """Replace embedded documents with DBRefs. + """ + + def transform_value(value): + if isinstance(value, dict): + if "_id" in value and "_ns" in value: + return DBRef(value["_ns"], transform_value(value["_id"])) + else: + return transform_dict(SON(value)) + elif isinstance(value, list): + return [transform_value(v) for v in value] + return value + + def transform_dict(object): + for (key, value) in list(object.items()): + object[key] = transform_value(value) + return object + + return transform_dict(SON(son)) + + def transform_outgoing(self, son, collection): + """Replace DBRefs with embedded documents. + """ + + def transform_value(value): + if isinstance(value, DBRef): + return self.database.dereference(value) + elif isinstance(value, list): + return [transform_value(v) for v in value] + elif isinstance(value, dict): + return transform_dict(SON(value)) + return value + + def transform_dict(object): + for (key, value) in list(object.items()): + object[key] = transform_value(value) + return object + + return transform_dict(SON(son)) + +# TODO make a generic translator for custom types. Take encode, decode, +# should_encode and should_decode functions and just encode and decode where +# necessary. See examples/custom_type.py for where this would be useful. +# Alternatively it could take a should_encode, to_binary, from_binary and +# binary subtype. diff --git a/asyncio_mongo/_pymongo/ssl_match_hostname.py b/asyncio_mongo/_pymongo/ssl_match_hostname.py new file mode 100644 index 0000000..24cdd41 --- /dev/null +++ b/asyncio_mongo/_pymongo/ssl_match_hostname.py @@ -0,0 +1,69 @@ +# Backport of the match_hostname logic introduced in python 3.2 +# http://svn.python.org/projects/python/branches/release32-maint/Lib/ssl.py + +import re + + +class CertificateError(ValueError): + pass + + +def _dnsname_to_pat(dn, max_wildcards=1): + pats = [] + for frag in dn.split(r'.'): + if frag.count('*') > max_wildcards: + # Issue #17980: avoid denials of service by refusing more + # than one wildcard per fragment. 
A survey of established
+            # policy among SSL implementations showed it to be a
+            # reasonable choice.
+            raise CertificateError(
+                "too many wildcards in certificate DNS name: " + repr(dn))
+        if frag == '*':
+            # When '*' is a fragment by itself, it matches a non-empty dotless
+            # fragment.
+            pats.append('[^.]+')
+        else:
+            # Otherwise, '*' matches any dotless fragment.
+            frag = re.escape(frag)
+            pats.append(frag.replace(r'\*', '[^.]*'))
+    return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE)
+
+
+def match_hostname(cert, hostname):
+    """Verify that *cert* (in decoded format as returned by
+    SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 rules
+    are mostly followed, but IP addresses are not accepted for *hostname*.
+
+    CertificateError is raised on failure. On success, the function
+    returns nothing.
+    """
+    if not cert:
+        raise ValueError("empty or no certificate")
+    dnsnames = []
+    san = cert.get('subjectAltName', ())
+    for key, value in san:
+        if key == 'DNS':
+            if _dnsname_to_pat(value).match(hostname):
+                return
+            dnsnames.append(value)
+    if not san:
+        # The subject is only checked when subjectAltName is empty
+        for sub in cert.get('subject', ()):
+            for key, value in sub:
+                # XXX according to RFC 2818, the most specific Common Name
+                # must be used.
+                if key == 'commonName':
+                    if _dnsname_to_pat(value).match(hostname):
+                        return
+                    dnsnames.append(value)
+    if len(dnsnames) > 1:
+        raise CertificateError("hostname %r "
+                               "doesn't match either of %s"
+                               % (hostname, ', '.join(map(repr, dnsnames))))
+    elif len(dnsnames) == 1:
+        raise CertificateError("hostname %r "
+                               "doesn't match %r"
+                               % (hostname, dnsnames[0]))
+    else:
+        raise CertificateError("no appropriate commonName or "
+                               "subjectAltName fields were found")
diff --git a/asyncio_mongo/_pymongo/thread_util.py b/asyncio_mongo/_pymongo/thread_util.py
new file mode 100644
index 0000000..f310f53
--- /dev/null
+++ b/asyncio_mongo/_pymongo/thread_util.py
@@ -0,0 +1,303 @@
+# Copyright 2012 10gen, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities to abstract the differences between threads and greenlets."""
+
+import threading
+import sys
+import weakref
+try:
+    from time import monotonic as _time
+except ImportError:
+    from time import time as _time
+
+have_gevent = True
+try:
+    import greenlet
+
+    try:
+        # gevent-1.0rc2 and later.
+        from gevent.lock import BoundedSemaphore as GeventBoundedSemaphore
+    except ImportError:
+        from gevent.coros import BoundedSemaphore as GeventBoundedSemaphore
+
+    from gevent.greenlet import SpawnedLink
+
+except ImportError:
+    have_gevent = False
+
+from asyncio_mongo._pymongo.errors import ExceededMaxWaiters
+
+
+# Do we have to work around http://bugs.python.org/issue1868?
+issue1868 = (sys.version_info[:3] <= (2, 7, 0)) + + +class Ident(object): + def __init__(self): + self._refs = {} + + def watching(self): + """Is the current thread or greenlet being watched for death?""" + return self.get() in self._refs + + def unwatch(self, tid): + self._refs.pop(tid, None) + + def get(self): + """An id for this thread or greenlet""" + raise NotImplementedError + + def watch(self, callback): + """Run callback when this thread or greenlet dies. callback takes + one meaningless argument. + """ + raise NotImplementedError + + +class ThreadIdent(Ident): + class _DummyLock(object): + def acquire(self): + pass + + def release(self): + pass + + def __init__(self): + super(ThreadIdent, self).__init__() + self._local = threading.local() + if issue1868: + self._lock = threading.Lock() + else: + self._lock = ThreadIdent._DummyLock() + + # We watch for thread-death using a weakref callback to a thread local. + # Weakrefs are permitted on subclasses of object but not object() itself. + class ThreadVigil(object): + pass + + def _make_vigil(self): + # Threadlocals in Python <= 2.7.0 have race conditions when setting + # attributes and possibly when getting them, too, leading to weakref + # callbacks not getting called later. + self._lock.acquire() + try: + vigil = getattr(self._local, 'vigil', None) + if not vigil: + self._local.vigil = vigil = ThreadIdent.ThreadVigil() + finally: + self._lock.release() + + return vigil + + def get(self): + return id(self._make_vigil()) + + def watch(self, callback): + vigil = self._make_vigil() + self._refs[id(vigil)] = weakref.ref(vigil, callback) + + +class GreenletIdent(Ident): + def get(self): + return id(greenlet.getcurrent()) + + def watch(self, callback): + current = greenlet.getcurrent() + tid = self.get() + + if hasattr(current, 'link'): + # This is a Gevent Greenlet (capital G), which inherits from + # greenlet and provides a 'link' method to detect when the + # Greenlet exits. + link = SpawnedLink(callback) + current.rawlink(link) + self._refs[tid] = link + else: + # This is a non-Gevent greenlet (small g), or it's the main + # greenlet. + self._refs[tid] = weakref.ref(current, callback) + + def unwatch(self, tid): + """ call unlink if link before """ + link = self._refs.pop(tid, None) + current = greenlet.getcurrent() + if hasattr(current, 'unlink'): + # This is a Gevent enhanced Greenlet. Remove the SpawnedLink we + # linked to it. + current.unlink(link) + + +def create_ident(use_greenlets): + if use_greenlets: + return GreenletIdent() + else: + return ThreadIdent() + + +class Counter(object): + """A thread- or greenlet-local counter. 
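+
+    A usage sketch (plain threads, so use_greenlets=False):
+
+        counter = Counter(use_greenlets=False)
+        counter.inc()    # -> 1 for the calling thread
+        counter.get()    # -> 1
+        counter.dec()    # -> 0; other threads see only their own counts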
+ """ + def __init__(self, use_greenlets): + self.ident = create_ident(use_greenlets) + self._counters = {} + + def inc(self): + # Copy these references so on_thread_died needn't close over self + ident = self.ident + _counters = self._counters + + tid = ident.get() + _counters.setdefault(tid, 0) + _counters[tid] += 1 + + if not ident.watching(): + # Before the tid is possibly reused, remove it from _counters + def on_thread_died(ref): + ident.unwatch(tid) + _counters.pop(tid, None) + + ident.watch(on_thread_died) + + return _counters[tid] + + def dec(self): + tid = self.ident.get() + if self._counters.get(tid, 0) > 0: + self._counters[tid] -= 1 + return self._counters[tid] + else: + return 0 + + def get(self): + return self._counters.get(self.ident.get(), 0) + + +### Begin backport from CPython 3.2 for timeout support for Semaphore.acquire +class Semaphore: + + # After Tim Peters' semaphore class, but not quite the same (no maximum) + + def __init__(self, value=1): + if value < 0: + raise ValueError("semaphore initial value must be >= 0") + self._cond = threading.Condition(threading.Lock()) + self._value = value + + def acquire(self, blocking=True, timeout=None): + if not blocking and timeout is not None: + raise ValueError("can't specify timeout for non-blocking acquire") + rc = False + endtime = None + self._cond.acquire() + while self._value == 0: + if not blocking: + break + if timeout is not None: + if endtime is None: + endtime = _time() + timeout + else: + timeout = endtime - _time() + if timeout <= 0: + break + self._cond.wait(timeout) + else: + self._value = self._value - 1 + rc = True + self._cond.release() + return rc + + __enter__ = acquire + + def release(self): + self._cond.acquire() + self._value = self._value + 1 + self._cond.notify() + self._cond.release() + + def __exit__(self, t, v, tb): + self.release() + + @property + def counter(self): + return self._value + + +class BoundedSemaphore(Semaphore): + """Semaphore that checks that # releases is <= # acquires""" + def __init__(self, value=1): + Semaphore.__init__(self, value) + self._initial_value = value + + def release(self): + if self._value >= self._initial_value: + raise ValueError("Semaphore released too many times") + return Semaphore.release(self) +### End backport from CPython 3.2 + + +class DummySemaphore(object): + def __init__(self, value=None): + pass + + def acquire(self, blocking=True, timeout=None): + return True + + def release(self): + pass + + +class MaxWaitersBoundedSemaphore(object): + def __init__(self, semaphore_class, value=1, max_waiters=1): + self.waiter_semaphore = semaphore_class(max_waiters) + self.semaphore = semaphore_class(value) + + def acquire(self, blocking=True, timeout=None): + if not self.waiter_semaphore.acquire(False): + raise ExceededMaxWaiters() + try: + return self.semaphore.acquire(blocking, timeout) + finally: + self.waiter_semaphore.release() + + def __getattr__(self, name): + return getattr(self.semaphore, name) + + +class MaxWaitersBoundedSemaphoreThread(MaxWaitersBoundedSemaphore): + def __init__(self, value=1, max_waiters=1): + MaxWaitersBoundedSemaphore.__init__( + self, BoundedSemaphore, value, max_waiters) + + +if have_gevent: + class MaxWaitersBoundedSemaphoreGevent(MaxWaitersBoundedSemaphore): + def __init__(self, value=1, max_waiters=1): + MaxWaitersBoundedSemaphore.__init__( + self, GeventBoundedSemaphore, value, max_waiters) + + +def create_semaphore(max_size, max_waiters, use_greenlets): + if max_size is None: + return DummySemaphore() + elif use_greenlets: + if 
max_waiters is None:
+            return GeventBoundedSemaphore(max_size)
+        else:
+            return MaxWaitersBoundedSemaphoreGevent(max_size, max_waiters)
+    else:
+        if max_waiters is None:
+            return BoundedSemaphore(max_size)
+        else:
+            return MaxWaitersBoundedSemaphoreThread(max_size, max_waiters)
diff --git a/asyncio_mongo/_pymongo/uri_parser.py b/asyncio_mongo/_pymongo/uri_parser.py
new file mode 100644
index 0000000..7696087
--- /dev/null
+++ b/asyncio_mongo/_pymongo/uri_parser.py
@@ -0,0 +1,301 @@
+# Copyright 2011-2012 10gen, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you
+# may not use this file except in compliance with the License. You
+# may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+"""Tools to parse and validate a MongoDB URI."""
+
+from urllib.parse import unquote_plus
+
+from asyncio_mongo._pymongo.common import validate
+from asyncio_mongo._pymongo.errors import (ConfigurationError,
+                                           InvalidURI,
+                                           UnsupportedOption)
+
+SCHEME = 'mongodb://'
+SCHEME_LEN = len(SCHEME)
+DEFAULT_PORT = 27017
+
+
+def _partition(entity, sep):
+    """Python2.4 doesn't have a partition method so we provide
+    our own that mimics str.partition from later releases.
+
+    Split the string at the first occurrence of sep, and return a
+    3-tuple containing the part before the separator, the separator
+    itself, and the part after the separator. If the separator is not
+    found, return a 3-tuple containing the string itself, followed
+    by two empty strings.
+    """
+    parts = entity.split(sep, 1)
+    if len(parts) == 2:
+        return parts[0], sep, parts[1]
+    else:
+        return entity, '', ''
+
+
+def _rpartition(entity, sep):
+    """Python2.4 doesn't have an rpartition method so we provide
+    our own that mimics str.rpartition from later releases.
+
+    Split the string at the last occurrence of sep, and return a
+    3-tuple containing the part before the separator, the separator
+    itself, and the part after the separator. If the separator is not
+    found, return a 3-tuple containing two empty strings, followed
+    by the string itself.
+    """
+    idx = entity.rfind(sep)
+    if idx == -1:
+        return '', '', entity
+    return entity[:idx], sep, entity[idx + 1:]
+
+
+def parse_userinfo(userinfo):
+    """Validates the format of user information in a MongoDB URI.
+    Reserved characters like ':', '/', '+' and '@' must be escaped
+    following RFC 2396.
+
+    Returns a 2-tuple containing the unescaped username followed
+    by the unescaped password.
+
+    :Parameters:
+        - `userinfo`: A string of the form <username>:<password>
+
+    .. versionchanged:: 2.2
+       Now uses `urllib.unquote_plus` so `+` characters must be escaped.
+    """
+    if '@' in userinfo or userinfo.count(':') > 1:
+        raise InvalidURI("':' or '@' characters in a username or password "
+                         "must be escaped according to RFC 2396.")
+    user, _, passwd = _partition(userinfo, ":")
+    # No password is expected with GSSAPI authentication.
+    if not user:
+        raise InvalidURI("The empty string is not a valid username.")
+    user = unquote_plus(user)
+    passwd = unquote_plus(passwd)
+
+    return user, passwd
+
+
+def parse_ipv6_literal_host(entity, default_port):
+    """Validates an IPv6 literal host:port string.
+
+    Returns a 2-tuple of IPv6 literal followed by port where
+    port is default_port if it wasn't specified in entity.
+
+    :Parameters:
+        - `entity`: A string that represents an IPv6 literal enclosed
+          in braces (e.g. '[::1]' or '[::1]:27017').
+        - `default_port`: The port number to use when one wasn't
+          specified in entity.
+    """
+    if entity.find(']') == -1:
+        raise ConfigurationError("an IPv6 address literal must be "
+                                 "enclosed in '[' and ']' according "
+                                 "to RFC 2732.")
+    i = entity.find(']:')
+    if i == -1:
+        return entity[1:-1], default_port
+    return entity[1: i], entity[i + 2:]
+
+
+def parse_host(entity, default_port=DEFAULT_PORT):
+    """Validates a host string
+
+    Returns a 2-tuple of host followed by port where port is default_port
+    if it wasn't specified in the string.
+
+    :Parameters:
+        - `entity`: A host or host:port string where host could be a
+          hostname or IP address.
+        - `default_port`: The port number to use when one wasn't
+          specified in entity.
+    """
+    host = entity
+    port = default_port
+    if entity[0] == '[':
+        host, port = parse_ipv6_literal_host(entity, default_port)
+    elif entity.find(':') != -1:
+        if entity.count(':') > 1:
+            raise ConfigurationError("Reserved characters such as ':' must be "
+                                     "escaped according to RFC 2396. An IPv6 "
+                                     "address literal must be enclosed in '[' "
+                                     "and ']' according to RFC 2732.")
+        host, port = host.split(':', 1)
+    if isinstance(port, str):
+        if not port.isdigit():
+            raise ConfigurationError("Port number must be an integer.")
+        port = int(port)
+    return host, port
+
+
+def validate_options(opts):
+    """Validates and normalizes options passed in a MongoDB URI.
+
+    Returns a new dictionary of validated and normalized options.
+
+    :Parameters:
+        - `opts`: A dict of MongoDB URI options.
+    """
+    normalized = {}
+    for option, value in opts.items():
+        option, value = validate(option, value)
+        # str(option) to ensure that a unicode URI results in plain 'str'
+        # option names. 'normalized' is then suitable to be passed as kwargs
+        # in all Python versions.
+        normalized[str(option)] = value
+    return normalized
+
+
+def split_options(opts):
+    """Takes the options portion of a MongoDB URI, validates each option
+    and returns the options in a dictionary. The option names will be returned
+    lowercase even if camelCase options are used.
+
+    :Parameters:
+        - `opt`: A string representing MongoDB URI options.
+    """
+    and_idx = opts.find("&")
+    semi_idx = opts.find(";")
+    try:
+        if and_idx >= 0 and semi_idx >= 0:
+            raise InvalidURI("Can not mix '&' and ';' for option separators.")
+        elif and_idx >= 0:
+            options = dict([kv.split("=") for kv in opts.split("&")])
+        elif semi_idx >= 0:
+            options = dict([kv.split("=") for kv in opts.split(";")])
+        elif opts.find("=") != -1:
+            options = dict([opts.split("=")])
+        else:
+            raise ValueError
+    except ValueError:
+        raise InvalidURI("MongoDB URI options are key=value pairs.")
+
+    return validate_options(options)
+
+
+def split_hosts(hosts, default_port=DEFAULT_PORT):
+    """Takes a string of the form host1[:port],host2[:port]... and
+    splits it into (host, port) tuples. If [:port] isn't present the
+    default_port is used.
+
+    Returns a list of 2-tuples containing the host name (or IP) followed by
+    port number.
+
+    :Parameters:
+        - `hosts`: A string of the form host1[:port],host2[:port],...
+        - `default_port`: The port number to use when one wasn't specified
+          for a host.
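+
+    For example (illustrative values):
+
+        split_hosts('db1.example.com,db2.example.com:27018')
+        # -> [('db1.example.com', 27017), ('db2.example.com', 27018)]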
+ """ + nodes = [] + for entity in hosts.split(','): + if not entity: + raise ConfigurationError("Empty host " + "(or extra comma in host list).") + port = default_port + # Unix socket entities don't have ports + if entity.endswith('.sock'): + port = None + nodes.append(parse_host(entity, port)) + return nodes + + +def parse_uri(uri, default_port=DEFAULT_PORT): + """Parse and validate a MongoDB URI. + + Returns a dict of the form:: + + { + 'nodelist': , + 'username': or None, + 'password': or None, + 'database': or None, + 'collection': or None, + 'options': + } + + :Parameters: + - `uri`: The MongoDB URI to parse. + - `default_port`: The port number to use when one wasn't specified + for a host in the URI. + """ + if not uri.startswith(SCHEME): + raise InvalidURI("Invalid URI scheme: URI " + "must begin with '%s'" % (SCHEME,)) + + scheme_free = uri[SCHEME_LEN:] + + if not scheme_free: + raise InvalidURI("Must provide at least one hostname or IP.") + + nodes = None + user = None + passwd = None + dbase = None + collection = None + options = {} + + # Check for unix domain sockets in the uri + if '.sock' in scheme_free: + host_part, _, path_part = _rpartition(scheme_free, '/') + try: + parse_uri('%s%s' % (SCHEME, host_part)) + except (ConfigurationError, InvalidURI): + host_part = scheme_free + path_part = "" + else: + host_part, _, path_part = _partition(scheme_free, '/') + + if not path_part and '?' in host_part: + raise InvalidURI("A '/' is required between " + "the host list and any options.") + + if '@' in host_part: + userinfo, _, hosts = _rpartition(host_part, '@') + user, passwd = parse_userinfo(userinfo) + else: + hosts = host_part + + nodes = split_hosts(hosts, default_port=default_port) + + if path_part: + + if path_part[0] == '?': + opts = path_part[1:] + else: + dbase, _, opts = _partition(path_part, '?') + if '.' in dbase: + dbase, collection = dbase.split('.', 1) + + if opts: + options = split_options(opts) + + return { + 'nodelist': nodes, + 'username': user, + 'password': passwd, + 'database': dbase, + 'collection': collection, + 'options': options + } + + +if __name__ == '__main__': + import pprint + import sys + try: + pprint.pprint(parse_uri(sys.argv[1])) + except (InvalidURI, UnsupportedOption) as e: + print(e) + sys.exit(0) + diff --git a/asyncio_mongo/collection.py b/asyncio_mongo/collection.py new file mode 100644 index 0000000..b1ff6b8 --- /dev/null +++ b/asyncio_mongo/collection.py @@ -0,0 +1,383 @@ +# coding: utf-8 +# Copyright 2009 Alexandre Fiori +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from asyncio import coroutine +from asyncio_mongo import filter as qf +from asyncio_mongo._bson import SON, ObjectId, Code +from asyncio_mongo._pymongo import errors + + +class Collection(object): + def __init__(self, database, name): + if not isinstance(name, str): + raise TypeError("name must be an instance of str") + + if not name or ".." 
+            raise errors.InvalidName("collection names cannot be empty")
+        if "$" in name and not (name.startswith("oplog.$main") or
+                                name.startswith("$cmd")):
+            raise errors.InvalidName("collection names must not "
+                                     "contain '$': %r" % name)
+        if name[0] == "." or name[-1] == ".":
+            raise errors.InvalidName("collection names must not start "
+                                     "or end with '.': %r" % name)
+        if "\x00" in name:
+            raise errors.InvalidName("collection names must not contain the "
+                                     "null character")
+
+        self._database = database
+        self._collection_name = name
+
+    def __str__(self):
+        return "%s.%s" % (str(self._database), self._collection_name)
+
+    def __repr__(self):
+        return "<mongodb Collection: %s>" % str(self)
+
+    def __getitem__(self, collection_name):
+        return Collection(self._database,
+                          "%s.%s" % (self._collection_name, collection_name))
+
+    def __eq__(self, other):
+        if isinstance(other, Collection):
+            return (self._database, self._collection_name) == \
+                   (other._database, other._collection_name)
+        return NotImplemented
+
+    def __hash__(self):
+        return self._collection_name.__hash__()
+
+    def __getattr__(self, collection_name):
+        return self[collection_name]
+
+    def __call__(self, collection_name):
+        return self[collection_name]
+
+    def _fields_list_to_dict(self, fields):
+        """
+        Transform a list of fields from ["a", "b"] to {"a": 1, "b": 1}.
+        """
+        as_dict = {}
+        for field in fields:
+            if not isinstance(field, str):
+                raise TypeError("fields must be a list of key names")
+            as_dict[field] = 1
+        return as_dict
+
+    def _gen_index_name(self, keys):
+        return u"_".join([u"%s_%s" % item for item in keys])
+
+    @coroutine
+    def options(self):
+        result = yield from self._database.system.namespaces.find_one({"name": str(self)})
+        if result:
+            options = result.get("options", {})
+            if "create" in options:
+                del options["create"]
+            return options
+        return {}
+
+    @coroutine
+    def find(self, spec=None, skip=0, limit=0, fields=None, filter=None, _proto=None):
+        if spec is None:
+            spec = SON()
+
+        if not isinstance(spec, dict):
+            raise TypeError("spec must be an instance of dict")
+        if fields is not None and not isinstance(fields, (dict, list)):
+            raise TypeError("fields must be an instance of dict or list")
+        if not isinstance(skip, int):
+            raise TypeError("skip must be an instance of int")
+        if not isinstance(limit, int):
+            raise TypeError("limit must be an instance of int")
+
+        if fields is not None:
+            if not isinstance(fields, dict):
+                if not fields:
+                    fields = ["_id"]
+                fields = self._fields_list_to_dict(fields)
+
+        if isinstance(filter, (qf.sort, qf.hint, qf.explain, qf.snapshot)):
+            spec = SON(dict(query=spec))
+            for k, v in filter.items():
+                spec[k] = isinstance(v, tuple) and SON(v) or v
+
+        # Send the command through a specific connection; this is required
+        # for the connection pool to work when safe=True.
+        if _proto is None:
+            proto = self._database._protocol
+        else:
+            proto = _proto
+        return (yield from proto.OP_QUERY(str(self), spec, skip, limit, fields))
+
+    @coroutine
+    def find_one(self, spec=None, fields=None, _proto=None):
+        if isinstance(spec, ObjectId):
+            spec = SON(dict(_id=spec))
+
+        docs = yield from self.find(spec, limit=-1, fields=fields, _proto=_proto)
+        doc = docs and docs[0] or {}
+        if doc.get("err") is not None:
+            if doc.get("code") == 11000:
+                raise errors.DuplicateKeyError
+            else:
+                raise errors.OperationFailure(doc)
+        else:
+            return doc
+
+    @coroutine
+    def count(self, spec=None, fields=None):
+        if fields is not None:
+            if not fields:
+                fields = ["_id"]
+            fields = self._fields_list_to_dict(fields)
+
+        spec = SON([("count", self._collection_name),
+                    ("query", spec or SON()),
+                    ("fields", fields)])
+        result = yield from self._database["$cmd"].find_one(spec)
+        return result["n"]
+
+    @coroutine
+    def group(self, keys, initial, reduce, condition=None, finalize=None):
+        body = {
+            "ns": self._collection_name,
+            "key": self._fields_list_to_dict(keys),
+            "initial": initial,
+            "$reduce": Code(reduce),
+        }
+
+        if condition:
+            body["cond"] = condition
+        if finalize:
+            body["finalize"] = Code(finalize)
+
+        return (yield from self._database["$cmd"].find_one({"group": body}))
+
+    @coroutine
+    def filemd5(self, spec):
+        if not isinstance(spec, ObjectId):
+            raise ValueError("filemd5 expected an objectid for its "
+                             "non-keyword argument")
+
+        spec = SON([("filemd5", spec),
+                    ("root", self._collection_name)])
+
+        result = yield from self._database['$cmd'].find_one(spec)
+        return result.get('md5')
+
+    @coroutine
+    def __safe_operation(self, proto, safe=False, ids=None):
+        # When safe=True, ask the server for the outcome of the preceding
+        # write with getlasterror; otherwise there is nothing to wait for.
+        result = None
+        if safe is True:
+            result = yield from self._database["$cmd"].find_one({"getlasterror": 1}, _proto=proto)
+
+        if ids is not None:
+            return ids
+        return result
+
+    @coroutine
+    def insert(self, docs, safe=False):
+        if isinstance(docs, dict):
+            ids = docs.get('_id', ObjectId())
+            docs["_id"] = ids
+            docs = [docs]
+        elif isinstance(docs, list):
+            ids = []
+            for doc in docs:
+                if isinstance(doc, dict):
+                    id = doc.get('_id', ObjectId())
+                    ids.append(id)
+                    doc["_id"] = id
+                else:
+                    raise TypeError("insert takes a document or a list of documents")
+        else:
+            raise TypeError("insert takes a document or a list of documents")
+        proto = self._database._protocol
+        proto.OP_INSERT(str(self), docs)
+        result = yield from self.__safe_operation(proto, safe, ids)
+        return result
+
+    @coroutine
+    def update(self, spec, document, upsert=False, multi=False, safe=False):
+        if not isinstance(spec, dict):
+            raise TypeError("spec must be an instance of dict")
+        if not isinstance(document, dict):
+            raise TypeError("document must be an instance of dict")
+        if not isinstance(upsert, bool):
+            raise TypeError("upsert must be an instance of bool")
+        proto = self._database._protocol
+        proto.OP_UPDATE(str(self), spec, document, upsert, multi)
+        return (yield from self.__safe_operation(proto, safe))
+
+    @coroutine
+    def save(self, doc, safe=False):
+        if not isinstance(doc, dict):
+            raise TypeError("cannot save objects of type %s" % type(doc))
+
+        objid = doc.get("_id")
+        if objid:
+            return (yield from self.update({"_id": objid}, doc, safe=safe, upsert=True))
+        else:
+            return (yield from self.insert(doc, safe=safe))
+
+    @coroutine
+    def remove(self, spec, safe=False):
+        if isinstance(spec, ObjectId):
+            spec = SON(dict(_id=spec))
+        if not isinstance(spec, dict):
+            raise TypeError("spec must be an instance of dict, not %s" % type(spec))
+
+        proto = self._database._protocol
+        proto.OP_DELETE(str(self), spec)
+        return (yield from self.__safe_operation(proto, safe))
+
+    @coroutine
+    def drop(self, safe=False):
+        return (yield from self.remove({}, safe))
+
+    @coroutine
+    def create_index(self, sort_fields, **kwargs):
+        if not isinstance(sort_fields, qf.sort):
+            raise TypeError("sort_fields must be an instance of filter.sort")
+
+        if "name" not in kwargs:
+            name = self._gen_index_name(sort_fields["orderby"])
+        else:
+            name = kwargs.pop("name")
+
+        key = SON()
+        for k, v in sort_fields["orderby"]:
+            key.update({k: v})
+
+        index = SON(dict(
+            ns=str(self),
+            name=name,
+            key=key
+        ))
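+        # The document built above is what gets inserted into the special
+        # system.indexes collection below (the pre-2.6 index-creation
+        # mechanism); for a single ascending key it looks roughly like
+        # {"ns": "foo.test", "name": "something_1", "key": {"something": 1}}.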
+ if "drop_dups" in kwargs: + kwargs["dropDups"] = kwargs.pop("drop_dups") + + if "bucket_size" in kwargs: + kwargs["bucketSize"] = kwargs.pop("bucket_size") + + index.update(kwargs) + yield from self._database.system.indexes.insert(index, safe=True) + return name + + @coroutine + def ensure_index(self, sort_fields, **kwargs): + # ensure_index is an alias of create_index since we are not + # keep an index cache same way pymongo does + return (yield from self.create_index(sort_fields, **kwargs)) + + @coroutine + def drop_index(self, index_identifier): + if isinstance(index_identifier, str): + name = index_identifier + elif isinstance(index_identifier, qf.sort): + name = self._gen_index_name(index_identifier["orderby"]) + else: + raise TypeError("index_identifier must be a name or instance of filter.sort") + + cmd = SON([("deleteIndexes", self._collection_name), ("index", name)]) + return (yield from self._database["$cmd"].find_one(cmd)) + + @coroutine + def drop_indexes(self): + return (yield from self.drop_index("*")) + + @coroutine + def index_information(self): + raw = yield from self._database.system.indexes.find({"ns": str(self)}) + info = {} + for idx in raw: + info[idx["name"]] = idx["key"].items() + return info + + @coroutine + def rename(self, new_name): + cmd = SON([("renameCollection", str(self)), ("to", "%s.%s" % \ + (str(self._database), new_name))]) + return (yield from self._database("admin")["$cmd"].find_one(cmd)) + + @coroutine + def distinct(self, key, spec=None): + + cmd = SON([("distinct", self._collection_name), ("key", key)]) + if spec: + cmd["query"] = spec + + result = yield from self._database["$cmd"].find_one(cmd) + if result: + return result.get("values") + return {} + + @coroutine + def aggregate(self, pipeline, full_response=False): + + cmd = SON([("aggregate", self._collection_name), + ("pipeline", pipeline)]) + + result = yield from self._database["$cmd"].find_one(cmd) + if full_response: + return result + return result.get("result") + + @coroutine + def map_reduce(self, map, reduce, full_response=False, **kwargs): + + cmd = SON([("mapreduce", self._collection_name), ("map", map), ("reduce", reduce)]) + cmd.update(**kwargs) + result = yield from self._database["$cmd"].find_one(cmd) + if full_response: + return result + return result.get("result") + + @coroutine + def find_and_modify(self, query=None, update=None, upsert=False, **kwargs): + if not update and not kwargs.get('remove', None): + raise ValueError("Must either update or remove") + + if update and kwargs.get('remove', None): + raise ValueError("Can't do both update and remove") + + cmd = SON([("findAndModify", self._collection_name)]) + cmd.update(kwargs) + # No need to include empty args + if query: + cmd['query'] = query + if update: + cmd['update'] = update + if upsert: + cmd['upsert'] = upsert + + result = yield from self._database["$cmd"].find_one(cmd) + no_obj_error = "No matching object found" + if not result['ok']: + if result["errmsg"] == no_obj_error: + return None + else: + raise ValueError("Unexpected Error: %s" % (result,)) + return result.get('value') \ No newline at end of file diff --git a/asyncio_mongo/connection.py b/asyncio_mongo/connection.py new file mode 100644 index 0000000..e3c1cfa --- /dev/null +++ b/asyncio_mongo/connection.py @@ -0,0 +1,93 @@ +from asyncio_mongo.database import Database +from .protocol import MongoProtocol +from asyncio.log import logger +import asyncio +import logging + +__all__ = ['Connection'] + + +class Connection: + """ + Wrapper around the protocol and 
transport which takes care of establishing
+    the connection and reconnecting it.
+
+    ::
+
+        connection = yield from Connection.create(host='localhost', port=27017)
+        doc = yield from connection.foo.test.find_one({'key': 'value'})
+    """
+    protocol = MongoProtocol
+    """
+    The :class:`MongoProtocol` class to be used for this connection.
+    """
+
+    @classmethod
+    @asyncio.coroutine
+    def create(cls, host='localhost', port=27017, loop=None, password=None, db=0, auto_reconnect=True):
+        connection = cls()
+
+        connection.host = host
+        connection.port = port
+        connection._loop = loop
+        connection._retry_interval = .5
+
+        # Create protocol instance
+        protocol_factory = type('MongoProtocol', (cls.protocol,), {'password': password, 'db': db})
+
+        if auto_reconnect:
+            # Subclass the protocol so that a lost connection schedules a
+            # reconnect attempt.
+            class protocol_factory(protocol_factory):
+                def connection_lost(self, exc):
+                    super().connection_lost(exc)
+                    asyncio.Task(connection._reconnect())
+
+        connection.protocol = protocol_factory()
+
+        # Connect
+        yield from connection._reconnect()
+
+        return connection
+
+    @property
+    def transport(self):
+        """ The transport instance that the protocol is currently using. """
+        return self.protocol.transport
+
+    def _get_retry_interval(self):
+        """ Time to wait for a reconnect in seconds. """
+        return self._retry_interval
+
+    def _reset_retry_interval(self):
+        """ Set the initial retry interval. """
+        self._retry_interval = .5
+
+    def _increase_retry_interval(self):
+        """ Increase the retry interval after a failed connection attempt. """
+        self._retry_interval = min(60, 1.5 * self._retry_interval)
+
+    @asyncio.coroutine
+    def _reconnect(self):
+        """
+        Set up the Mongo connection, retrying until it succeeds.
+        """
+        loop = self._loop or asyncio.get_event_loop()
+        while True:
+            try:
+                logger.log(logging.INFO, 'Connecting to mongo')
+                yield from loop.create_connection(lambda: self.protocol, self.host, self.port)
+                self._reset_retry_interval()
+                return
+            except OSError:
+                # Sleep and try again
+                self._increase_retry_interval()
+                interval = self._get_retry_interval()
+                logger.log(logging.INFO, 'Connecting to mongo failed. Retrying in %.1f seconds' % interval)
+                yield from asyncio.sleep(interval)
+
+    def __getitem__(self, database_name):
+        return Database(self.protocol, database_name)
+
+    def __getattr__(self, database_name):
+        return self[database_name]
+
+    def __repr__(self):
+        return 'Connection(host=%r, port=%r)' % (self.host, self.port)
diff --git a/asyncio_mongo/database.py b/asyncio_mongo/database.py
new file mode 100644
index 0000000..8989c32
--- /dev/null
+++ b/asyncio_mongo/database.py
@@ -0,0 +1,123 @@
+# coding: utf-8
+# Copyright 2009 Alexandre Fiori
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
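+# Usage sketch (assuming an established ``Connection`` instance ``mongo``):
+# a Database is normally reached by attribute or item access on the
+# connection rather than constructed directly, e.g.
+#
+#     db = mongo.foo                            # Database(protocol, 'foo')
+#     names = yield from db.collection_names()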
+from asyncio_mongo._bson import SON
+
+from asyncio_mongo._pymongo import helpers
+from asyncio_mongo.collection import Collection
+from asyncio import coroutine
+from asyncio_mongo.exceptions import ErrorReply
+
+
+class Database(object):
+    def __init__(self, protocol, database_name):
+        self.__protocol = protocol
+        self._database_name = database_name
+
+    def __str__(self):
+        return self._database_name
+
+    def __repr__(self):
+        return "<mongodb Database: %s>" % self._database_name
+
+    def __call__(self, database_name):
+        return Database(self.__protocol, database_name)
+
+    def __getitem__(self, collection_name):
+        return Collection(self, collection_name)
+
+    def __getattr__(self, collection_name):
+        return self[collection_name]
+
+    @property
+    def _protocol(self):
+        return self.__protocol
+
+    @coroutine
+    def create_collection(self, name, options=None):
+        collection = Collection(self, name)
+
+        if options:
+            if "size" in options:
+                options["size"] = float(options["size"])
+
+            command = SON({"create": name})
+            command.update(options)
+            result = yield from self["$cmd"].find_one(command)
+            if result.get("ok", 0.0):
+                return collection
+            else:
+                raise RuntimeError(result.get("errmsg", "unknown error"))
+        else:
+            return collection
+
+    @coroutine
+    def drop_collection(self, name_or_collection):
+        if isinstance(name_or_collection, Collection):
+            name = name_or_collection._collection_name
+        elif isinstance(name_or_collection, str):
+            name = name_or_collection
+        else:
+            raise TypeError("name must be an instance of str or Collection")
+
+        return (yield from self["$cmd"].find_one({"drop": name}))
+
+    @coroutine
+    def collection_names(self):
+        results = yield from self["system.namespaces"].find()
+        names = [r["name"] for r in results]
+        names = [n[len(str(self)) + 1:] for n in names
+                 if n.startswith(str(self) + ".")]
+        names = [n for n in names if "$" not in n]
+        return names
+
+    @coroutine
+    def authenticate(self, name, password):
+        """
+        Send an authentication command for this database.
+        Mostly stolen from asyncio_mongo._pymongo.
+        """
+        if not isinstance(name, str):
+            raise TypeError("name must be an instance of str")
+        if not isinstance(password, str):
+            raise TypeError("password must be an instance of str")
+
+        # First get the nonce
+        result = yield from self["$cmd"].find_one({"getnonce": 1})
+        return (yield from self.authenticate_with_nonce(result, name, password))
+
+    @coroutine
+    def authenticate_with_nonce(self, result, name, password):
+        nonce = result['nonce']
+        key = helpers._auth_key(nonce, name, password)
+
+        # Build the command field by field because the key order matters.
+        auth_command = SON(authenticate=1)
+        auth_command['user'] = name
+        auth_command['nonce'] = nonce
+        auth_command['key'] = key
+
+        # Now actually authenticate
+        result = yield from self["$cmd"].find_one(auth_command)
+        return (yield from self.authenticated(result))
+
+    @coroutine
+    def authenticated(self, result):
+        """Might want to just return 0.0 here instead of raising."""
+        ok = result['ok']
+        if ok:
+            return ok
+        else:
+            raise ErrorReply(result['errmsg'])
diff --git a/asyncio_mongo/exceptions.py b/asyncio_mongo/exceptions.py
new file mode 100644
index 0000000..28cb2f8
--- /dev/null
+++ b/asyncio_mongo/exceptions.py
@@ -0,0 +1,54 @@
+__all__ = (
+    'ConnectionLostError',
+    'Error',
+    'ErrorReply',
+    'NoAvailableConnectionsInPoolError',
+    'NoRunningScriptError',
+    'NotConnectedError',
+    'ScriptKilledError',
+    'TransactionError',
+)
+
+
+# See the following link for the proper way to create user-defined exceptions:
+# http://docs.python.org/3.3/tutorial/errors.html#user-defined-exceptions
+
+
+class Error(Exception):
+    """ Base exception. """
+
+
+class ErrorReply(Exception):
+    """ Exception when the mongo server returns an error. """
+
+
+class TransactionError(Error):
+    """ Transaction failed. """
+
+
+class NotConnectedError(Error):
+    """ Protocol is not connected. """
+    def __init__(self, message='Not connected'):
+        super().__init__(message)
+
+
+class ConnectionLostError(NotConnectedError):
+    """
+    Connection lost during query.
+    (Special case of ``NotConnectedError``.)
+    """
+    def __init__(self, exc):
+        self.exception = exc
+
+
+class NoAvailableConnectionsInPoolError(NotConnectedError):
+    """
+    When the connection pool has no available connections.
+    """
+
+
+class ScriptKilledError(Error):
+    """ Script was killed during an evalsha call. """
+
+
+class NoRunningScriptError(Error):
+    """ script_kill was called while no script was running. """
diff --git a/asyncio_mongo/filter.py b/asyncio_mongo/filter.py
new file mode 100644
index 0000000..8627bf4
--- /dev/null
+++ b/asyncio_mongo/filter.py
@@ -0,0 +1,119 @@
+# coding: utf-8
+# Copyright 2009 Alexandre Fiori
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
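+# Usage sketch (see also examples/query_filter.py and examples/index.py):
+# the filters defined below combine with ``+`` and are passed to
+# Collection.find() or Collection.create_index(), e.g.
+#
+#     f = sort(ASCENDING("something") + DESCENDING("else"))
+#     docs = yield from test.find(limit=10, filter=f)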
+ +from collections import defaultdict + +"""Query filters""" + + +def _DIRECTION(keys, direction): + if isinstance(keys, str): + return (keys, direction), + elif isinstance(keys, (list, tuple)): + return tuple([(k, direction) for k in keys]) + + +def ASCENDING(keys): + """Ascending sort order""" + return _DIRECTION(keys, 1) + + +def DESCENDING(keys): + """Descending sort order""" + return _DIRECTION(keys, -1) + + +def GEO2D(keys): + """ + Two-dimensional geospatial index + http://www.mongodb.org/display/DOCS/Geospatial+Indexing + """ + return _DIRECTION(keys, "2d") + + +def GEOHAYSTACK(keys): + """ + Bucket-based geospatial index + http://www.mongodb.org/display/DOCS/Geospatial+Haystack+Indexing + """ + return _DIRECTION(keys, "geoHaystack") + + + +class _QueryFilter(defaultdict): + def __init__(self): + defaultdict.__init__(self, lambda: ()) + + def __add__(self, obj): + for k, v in obj.items(): + if isinstance(v, tuple): + self[k] += v + else: + self[k] = v + return self + + def _index_document(self, operation, index_list): + name = self.__class__.__name__ + try: + assert isinstance(index_list, (list, tuple)) + for key, direction in index_list: + if not isinstance(key, str): + raise TypeError("Invalid %sing key: %s" % (name, repr(key))) + if direction not in (1, -1, "2d", "geoHaystack"): + raise TypeError("Invalid %sing direction: %s" % (name, direction)) + self[operation] += tuple(((key, direction),)) + except Exception: + raise TypeError("Invalid list of keys for %s: %s" % (name, repr(index_list))) + + def __repr__(self): + return "" % dict.__repr__(self) + + +class sort(_QueryFilter): + """Sorts the results of a query.""" + + def __init__(self, key_list): + _QueryFilter.__init__(self) + try: + assert isinstance(key_list[0], (list, tuple)) + except: + key_list = (key_list,) + self._index_document("orderby", key_list) + + +class hint(_QueryFilter): + """Adds a `hint`, telling Mongo the proper index to use for the query.""" + + def __init__(self, index_list): + _QueryFilter.__init__(self) + try: + assert isinstance(index_list[0], (list, tuple)) + except: + index_list = (index_list,) + self._index_document("$hint", index_list) + + +class explain(_QueryFilter): + """Returns an explain plan for the query.""" + + def __init__(self): + _QueryFilter.__init__(self) + self["explain"] = True + + +class snapshot(_QueryFilter): + def __init__(self): + _QueryFilter.__init__(self) + self["snapshot"] = True diff --git a/asyncio_mongo/log.py b/asyncio_mongo/log.py new file mode 100644 index 0000000..23a7074 --- /dev/null +++ b/asyncio_mongo/log.py @@ -0,0 +1,7 @@ +"""Logging configuration.""" + +import logging + + +# Name the logger after the package. +logger = logging.getLogger(__package__) diff --git a/asyncio_mongo/pool.py b/asyncio_mongo/pool.py new file mode 100644 index 0000000..8ac11de --- /dev/null +++ b/asyncio_mongo/pool.py @@ -0,0 +1,112 @@ +from .connection import Connection +from .exceptions import NoAvailableConnectionsInPoolError +from .protocol import MongoProtocol +import asyncio + + +__all__ = ('Pool', ) + + +class Pool: + """ + Pool of connections. Each + Takes care of setting up the connection and connection pooling. + + When poolsize > 1 and some connections are in use because of transactions + or blocking requests, the other are preferred. 
+
+    ::
+
+        pool = yield from Pool.create(host='localhost', port=27017, poolsize=10)
+        doc = yield from pool.foo.test.find_one({'key': 'value'})
+    """
+
+    protocol = MongoProtocol
+    """
+    The :class:`MongoProtocol` class to be used for each connection in this pool.
+    """
+
+    @classmethod
+    def get_connection_class(cls):
+        """
+        Return the :class:`Connection` class to be used for every connection in
+        this pool. Normally this is just a ``Connection`` using the defined ``protocol``.
+        """
+        class ConnectionClass(Connection):
+            protocol = cls.protocol
+        return ConnectionClass
+
+    @classmethod
+    @asyncio.coroutine
+    def create(cls, host='localhost', port=27017, loop=None, password=None, db=0, poolsize=1, auto_reconnect=True):
+        """
+        Create a new pool of connections.
+        """
+        self = cls()
+        self._host = host
+        self._port = port
+        self._poolsize = poolsize
+
+        # Create connections
+        self._connections = []
+
+        for i in range(poolsize):
+            connection_class = cls.get_connection_class()
+            connection = yield from connection_class.create(host=host, port=port, loop=loop,
+                                                            password=password, db=db, auto_reconnect=auto_reconnect)
+            self._connections.append(connection)
+
+        return self
+
+    def __repr__(self):
+        return 'Pool(host=%r, port=%r, poolsize=%r)' % (self._host, self._port, self._poolsize)
+
+    @property
+    def poolsize(self):
+        """ Number of parallel connections in the pool. """
+        return self._poolsize
+
+    @property
+    def connections_in_use(self):
+        """
+        Return how many connections are currently in use.
+        """
+        return sum([1 for c in self._connections if c.protocol.in_use])
+
+    @property
+    def connections_connected(self):
+        """
+        The number of open TCP connections.
+        """
+        return sum([1 for c in self._connections if c.protocol.is_connected])
+
+    def _get_free_connection(self):
+        """
+        Return the next connection that's not in use. (A protocol that is
+        still busy with a blocking request is considered busy and can't be
+        used for anything else.)
+        """
+        self._shuffle_connections()
+
+        for c in self._connections:
+            if c.protocol.is_connected and not c.protocol.in_use:
+                return c
+
+    def _shuffle_connections(self):
+        """
+        Rotate the connection list, so that the load is divided equally
+        among the connections.
+        """
+        self._connections = self._connections[1:] + self._connections[:1]
+
+    def __getattr__(self, name):
+        """
+        Proxy to a connection. (This will choose a connection that's not
+        busy with a blocking request.)
+        """
+        connection = self._get_free_connection()
+
+        if connection:
+            return getattr(connection, name)
+        else:
+            raise NoAvailableConnectionsInPoolError('No available connections in the pool: size=%s, in_use=%s, connected=%s' % (
+                self.poolsize, self.connections_in_use, self.connections_connected))
diff --git a/asyncio_mongo/protocol.py b/asyncio_mongo/protocol.py
new file mode 100644
index 0000000..5f8e90f
--- /dev/null
+++ b/asyncio_mongo/protocol.py
@@ -0,0 +1,191 @@
+# coding: utf-8
+# Copyright 2009 Alexandre Fiori
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
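+# Wire-format note (sketch): every MongoDB wire-protocol message starts with
+# a 16-byte header of four little-endian int32s -- messageLength, requestID,
+# responseTo and opCode. data_received() below unpacks these into datalen,
+# request, response and operation (presumably with a '<iiii' struct format)
+# before reassembling the full message from the buffered data.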
+import logging + +import struct +import asyncio +from asyncio_mongo.exceptions import ConnectionLostError +import asyncio_mongo._bson as bson +from asyncio_mongo.log import logger + +_ONE = b"\x01\x00\x00\x00" +_ZERO = b"\x00\x00\x00\x00" + +"""Low level connection to Mongo.""" + + +class _MongoQuery(object): + def __init__(self, id, collection, limit): + self.id = id + self.limit = limit + self.collection = collection + self.documents = [] + self.future = asyncio.Future() + + +class MongoProtocol(asyncio.Protocol): + def __init__(self): + self.__id = 0 + self.__buffer = b"" + self.__queries = {} + self.__datalen = None + self.__response = 0 + self.__waiting_header = True + self._pipelined_calls = set() # Set of all the pipelined calls. + self.transport = None + self._is_connected = False + + def connection_made(self, transport): + self.transport = transport + self._is_connected = True + logger.log(logging.INFO, 'Mongo connection made with %') + + def connection_lost(self, exc): + self._is_connected = False + self.transport = None + + # Raise exception on all waiting futures. + for f in self.__queries: + f.set_exception(ConnectionLostError(exc)) + + logger.log(logging.INFO, 'Mongo connection lost') + + def data_received(self, data): + while self.__waiting_header: + self.__buffer += data + if len(self.__buffer) < 16: + break + + # got full header, 16 bytes (or more) + header, extra = self.__buffer[:16], self.__buffer[16:] + self.__buffer = b"" + self.__waiting_header = False + datalen, request, response, operation = struct.unpack("= 0, "Unexpected number of documents received!" + if not next_batch: + self.OP_KILL_CURSORS([cursor_id]) + query.future.set_result(query.documents) + return + self.__queries[self.__id] = query + self.OP_GET_MORE(query.collection, next_batch, cursor_id) + else: + query.future.set_result(query.documents) \ No newline at end of file diff --git a/examples/aggregate.py b/examples/aggregate.py new file mode 100644 index 0000000..d7f2177 --- /dev/null +++ b/examples/aggregate.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python +# coding: utf-8 +from asyncio import coroutine +import asyncio + +import asyncio_mongo + +@coroutine +def example(): + mongo = yield from asyncio_mongo.Connection.create('localhost', 27017) + + foo = mongo.foo # `foo` database + test = foo.test # `test` collection + + yield from test.insert({"src":"Twitter", "content":"bla bla"}, safe=True) + yield from test.insert({"src":"Twitter", "content":"more data"}, safe=True) + yield from test.insert({"src":"Wordpress", "content":"blog article 1"}, safe=True) + yield from test.insert({"src":"Wordpress", "content":"blog article 2"}, safe=True) + yield from test.insert({"src":"Wordpress", "content":"some comments"}, safe=True) + + # Read more about the aggregation pipeline in MongoDB's docs + pipeline = [ + {'$group': {'_id':'$src', 'content_list': {'$push': '$content'} } } + ] + result = yield from test.aggregate(pipeline) + + print("result:", result) + +if __name__ == '__main__': + asyncio.get_event_loop().run_until_complete(example()) \ No newline at end of file diff --git a/examples/dbref.py b/examples/dbref.py new file mode 100644 index 0000000..7ec461c --- /dev/null +++ b/examples/dbref.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python +# coding: utf-8 +import asyncio + +import asyncio_mongo +from asyncio_mongo._bson import DBRef + + +@asyncio.coroutine +def example(): + mongo = yield from asyncio_mongo.Connection.create('localhost', 27017) + + foo = mongo.foo # `foo` database + test = foo.test # `test` collection + 
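+    # A DBRef simply stores a collection name and an _id: doc_a is inserted
+    # first so its generated ObjectId can be referenced by doc_b below, and
+    # dereferencing is done manually via ref.collection / ref.id at the end.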
+ doc_a = {"username":"foo", "password":"bar"} + result = yield from test.insert(doc_a, safe=True) + + doc_b = {"settings":"foobar", "owner":DBRef("test", result)} + yield from test.insert(doc_b, safe=True) + + doc = yield from test.find_one({"settings":"foobar"}) + print("doc is:", doc) + + if isinstance(doc["owner"], DBRef): + ref = doc["owner"] + owner = yield from foo[ref.collection].find_one(ref.id) + print("owner:", owner) + +if __name__ == '__main__': + asyncio.get_event_loop().run_until_complete(example()) diff --git a/examples/drop.py b/examples/drop.py new file mode 100644 index 0000000..c38afbd --- /dev/null +++ b/examples/drop.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# coding: utf-8 +import asyncio + +import asyncio_mongo + +@asyncio.coroutine +def example(): + mongo = yield from asyncio_mongo.Connection.create('localhost', 27017) + + foo = mongo.foo # `foo` database + test = foo.test # `test` collection + + result = yield from test.drop(safe=True) + print(result) + +if __name__ == '__main__': + asyncio.get_event_loop().run_until_complete(example()) diff --git a/examples/group.py b/examples/group.py new file mode 100644 index 0000000..5041119 --- /dev/null +++ b/examples/group.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# coding: utf-8 +import asyncio + +import asyncio_mongo + +@asyncio.coroutine +def example(): + mongo = yield from asyncio_mongo.Connection.create('localhost', 27017) + + foo = mongo.foo # `foo` database + test = foo.test # `test` collection + + yield from test.insert({"src":"Twitter", "content":"bla bla"}, safe=True) + yield from test.insert({"src":"Twitter", "content":"more data"}, safe=True) + yield from test.insert({"src":"Wordpress", "content":"blog article 1"}, safe=True) + yield from test.insert({"src":"Wordpress", "content":"blog article 2"}, safe=True) + yield from test.insert({"src":"Wordpress", "content":"some comments"}, safe=True) + + result = yield from test.group(keys=["src"], + initial={"count":0}, reduce="function(obj,prev){prev.count++;}") + + print("result:", result) + +if __name__ == '__main__': + asyncio.get_event_loop().run_until_complete(example()) diff --git a/examples/index.py b/examples/index.py new file mode 100644 index 0000000..eef37d0 --- /dev/null +++ b/examples/index.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +# coding: utf-8 +import asyncio + +import asyncio_mongo +from asyncio_mongo import filter + +@asyncio.coroutine +def example(): + mongo = yield from asyncio_mongo.Connection.create('localhost', 27017) + + foo = mongo.foo # `foo` database + test = foo.test # `test` collection + + idx = filter.sort(filter.ASCENDING("something") + filter.DESCENDING("else")) + print("IDX:", idx) + + result = yield from test.create_index(idx) + print("create_index:", result) + + result = yield from test.index_information() + print("index_information:", result) + + result = yield from test.drop_index(idx) + print("drop_index:", result) + + # Geohaystack example + geoh_idx = filter.sort(filter.GEOHAYSTACK("loc") + filter.ASCENDING("type")) + print("IDX:", geoh_idx) + result = yield from test.create_index(geoh_idx, **{'bucketSize':1}) + print("index_information:", result) + + result = yield from test.drop_index(geoh_idx) + print("drop_index:", result) + + # 2D geospatial index + geo_idx = filter.sort(filter.GEO2D("pos")) + print("IDX:", geo_idx) + result = yield from test.create_index(geo_idx, **{ 'min':-100, 'max':100 }) + print("index_information:", result) + + result = yield from test.drop_index(geo_idx) + print("drop_index:", result) + + +if 
__name__ == '__main__': + asyncio.get_event_loop().run_until_complete(example()) diff --git a/examples/insert.py b/examples/insert.py new file mode 100644 index 0000000..d2b3284 --- /dev/null +++ b/examples/insert.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python +# coding: utf-8 + +import time +import asyncio +import asyncio_mongo + +@asyncio.coroutine +def example(): + mongo = yield from asyncio_mongo.Connection.create('localhost', 27017) + + foo = mongo.foo # `foo` database + test = foo.test # `test` collection + + # insert some data + for x in range(10000): + result = yield from test.insert({"something":x*time.time()}, safe=True) + print(result) + +if __name__ == '__main__': + asyncio.get_event_loop().run_until_complete(example()) \ No newline at end of file diff --git a/examples/query.py b/examples/query.py new file mode 100644 index 0000000..44672dd --- /dev/null +++ b/examples/query.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# coding: utf-8 +import asyncio + +import asyncio_mongo + +@asyncio.coroutine +def example(): + mongo = yield from asyncio_mongo.Connection.create('localhost', 27017) + + foo = mongo.foo # `foo` database + test = foo.test # `test` collection + + # fetch some documents + docs = yield from test.find(limit=10) + for doc in docs: + print(doc) + +if __name__ == '__main__': + asyncio.get_event_loop().run_until_complete(example()) \ No newline at end of file diff --git a/examples/query_fields.py b/examples/query_fields.py new file mode 100644 index 0000000..08124e1 --- /dev/null +++ b/examples/query_fields.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python +# coding: utf-8 +import asyncio + +import asyncio_mongo +import asyncio_mongo.filter + +@asyncio.coroutine +def example(): + mongo = yield from asyncio_mongo.Connection.create('localhost', 27017) + + foo = mongo.foo # `foo` database + test = foo.test # `test` collection + + # specify the fields to be returned by the query + # reference: http://www.mongodb.org/display/DOCS/Retrieving+a+Subset+of+Fields + whitelist = {'_id': 1, 'name': 1} + blacklist = {'_id': 0} + quickwhite = ['_id', 'name'] + + fields = blacklist + + # fetch some documents + docs = yield from test.find(limit=10, fields=fields) + for n, doc in enumerate(docs): + print(n, doc) + +if __name__ == '__main__': + asyncio.get_event_loop().run_until_complete(example()) diff --git a/examples/query_filter.py b/examples/query_filter.py new file mode 100644 index 0000000..8a6a503 --- /dev/null +++ b/examples/query_filter.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# coding: utf-8 +import asyncio + +import asyncio_mongo +import asyncio_mongo.filter + +@asyncio.coroutine +def example(): + mongo = yield from asyncio_mongo.Connection.create('localhost', 27017) + + foo = mongo.foo # `foo` database + test = foo.test # `test` collection + + # create the filter + f = asyncio_mongo.filter.sort(asyncio_mongo.filter.DESCENDING("something")) + #f += asyncio_mongo.filter.hint(asyncio_mongo.filter.DESCENDING("myindex")) + #f += asyncio_mongo.filter.explain() + + # fetch some documents + docs = yield from test.find(limit=10, filter=f) + for n, doc in enumerate(docs): + print(n, doc) + +if __name__ == '__main__': + asyncio.get_event_loop().run_until_complete(example()) diff --git a/examples/update.py b/examples/update.py new file mode 100644 index 0000000..69518b9 --- /dev/null +++ b/examples/update.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +# coding: utf-8 +import asyncio + +import asyncio_mongo + +@asyncio.coroutine +def example(): + mongo = yield from 
asyncio_mongo.Connection.create('localhost', 27017) + + foo = mongo.foo # `foo` database + test = foo.test # `test` collection + + # insert + yield from test.insert({"foo":"bar", "name":"bla"}, safe=True) + + # update + result = yield from test.update({"foo":"bar"}, {"$set": {"name":"john doe"}}, safe=True) + print("result:", result) + +if __name__ == '__main__': + asyncio.get_event_loop().run_until_complete(example()) diff --git a/setup.py b/setup.py new file mode 100755 index 0000000..026ab0a --- /dev/null +++ b/setup.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python + +import sys +import os +import shutil + +from setuptools import setup +from setuptools import Feature +from distutils.cmd import Command +from distutils.command.build_ext import build_ext +from distutils.errors import CCompilerError +from distutils.errors import DistutilsPlatformError, DistutilsExecError +from distutils.core import Extension + +requirements = ["asyncio"] +try: + import xml.etree.ElementTree +except ImportError: + requirements.append("elementtree") + + +if sys.platform == 'win32' and sys.version_info > (2, 6): + # 2.6's distutils.msvc9compiler can raise an IOError when failing to + # find the compiler + build_errors = (CCompilerError, DistutilsExecError, DistutilsPlatformError, + IOError) +else: + build_errors = (CCompilerError, DistutilsExecError, DistutilsPlatformError) + + +class custom_build_ext(build_ext): + """Allow C extension building to fail. + + The C extension speeds up BSON encoding, but is not essential. + """ + + warning_message = """ +************************************************************** +WARNING: %s could not +be compiled. No C extensions are essential for PyMongo to run, +although they do result in significant speed improvements. + +%s +************************************************************** +""" + + def run(self): + try: + build_ext.run(self) + except DistutilsPlatformError as e: + print(e) + print(self.warning_message % ("Extension modules", + "There was an issue with your platform configuration - see above.")) + + def build_extension(self, ext): + if sys.version_info[:3] >= (2, 4, 0): + try: + build_ext.build_extension(self, ext) + except build_errors as e: + print(e) + print(self.warning_message % ("The %s extension module" % ext.name, + "Above is the ouput showing how " + "the compilation failed.")) + else: + print(self.warning_message % ("The %s extension module" % ext.name, + "Please use Python >= 2.4 to take " + "advantage of the extension.")) + +c_ext = Feature( + "optional C extension", + standard=True, + ext_modules=[Extension('txmongo._pymongo._cbson', + include_dirs=['txmongo/_pymongo'], + sources=['txmongo/_pymongo/_cbsonmodule.c', + 'txmongo/_pymongo/time_helpers.c', + 'txmongo/_pymongo/encoding_helpers.c'])]) + +if "--no_ext" in sys.argv: + sys.argv = [x for x in sys.argv if x != "--no_ext"] + features = {} +else: + features = {"c-ext": c_ext} + +setup( + name="asyncio-mongo", + version="0.1.0", + description="Asynchronous Python 3.3+ driver for MongoDB ", + author="Alexandre Fiori, Don Brown", + author_email="mrdon@twdata.org", + url="https://bitbucket.org/mrdon/asyncio-mongo", + keywords=["mongo", "mongodb", "pymongo", "gridfs", "asyncio_mongo", "asyncio"], + packages=["asyncio_mongo", "asyncio_mongo._pymongo", "asyncio_mongo._gridfs", "asyncio_mongo._bson"], + install_requires=requirements, + features=features, + license="Apache License, Version 2.0", + test_suite="nose.collector", + classifiers=[ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: 
Developers", + "License :: OSI Approved :: Apache Software License", + "Operating System :: MacOS :: MacOS X", + "Operating System :: Microsoft :: Windows", + "Operating System :: POSIX", + "Programming Language :: Python", + "Topic :: Database"], + cmdclass={"build_ext": custom_build_ext, + "doc": ""})
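+
+# Build note (sketch): the optional C extension can be skipped entirely with
+#
+#     python setup.py install --no_ext
+#
+# Otherwise custom_build_ext above degrades gracefully: if compilation fails,
+# it prints a warning and the pure-Python BSON implementation is used.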