Files

229 lines
8.3 KiB
Python

# coding: utf-8
# Copyright 2009 Alexandre Fiori
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import struct
import asyncio
from asyncio_mongo.database import Database
from asyncio_mongo.exceptions import ConnectionLostError
import asyncio_mongo._bson as bson
from asyncio_mongo.log import logger
_ONE = b"\x01\x00\x00\x00"
_ZERO = b"\x00\x00\x00\x00"
_QUERY_OPTIONS = {
"tailable_cursor": 2,
"oplog_replay": 8,
"await_data": 32}
"""Low level connection to Mongo."""
class _MongoQuery(object):
def __init__(self, id, collection, limit, tailable=False):
self.id = id
self.limit = limit
self.collection = collection
self.tailable = tailable
self.documents = []
self.future = asyncio.Future()
class MongoProtocol(asyncio.Protocol):
    """asyncio protocol that speaks MongoDB's opcode-based wire protocol.

    The protocol builds and sends OP_INSERT / OP_UPDATE / OP_DELETE /
    OP_QUERY / OP_GET_MORE / OP_KILL_CURSORS messages, reassembles the
    length-prefixed replies arriving on the transport, and resolves one
    ``asyncio.Future`` per query with the decoded documents.

    Args:
        connection_lost_callback: optional callable invoked with the
            exception (or ``None``) passed to :meth:`connection_lost`.
        authenticators: optional mapping stored on the instance; it is not
            used by any method of this class.
    """

    # Size in bytes of the standard message header (four little-endian int32:
    # messageLength, requestID, responseTo, opCode).
    _HEADER_SIZE = 16

    # Request ids are packed as signed int32, so they wrap at 31 bits.
    _MAX_REQUEST_ID = 0x7FFFFFFF

    def __init__(self, connection_lost_callback=None, authenticators=None):
        self.__id = 0
        # Bytes received but not yet consumed as complete messages.  Always
        # mutated in place so aliases held during parsing stay valid.
        self.__buffer = bytearray()
        # request id -> _MongoQuery awaiting a reply
        self.__queries = {}
        self.__connection_lost_callback = connection_lost_callback
        self._pipelined_calls = set()  # Set of all the pipelined calls.
        self.transport = None
        self._is_connected = False
        self.__authenticators = authenticators or {}

    def connection_made(self, transport):
        """Remember the transport and mark the protocol as connected."""
        self.transport = transport
        self._is_connected = True
        logger.info('Mongo connection made')
        # TODO: the authenticators given to the constructor are stored but not
        # applied here.  Running them needs a coroutine context, which this
        # synchronous callback does not provide.

    def connection_lost(self, exc):
        """Fail every pending query and notify the optional callback.

        Args:
            exc: exception that caused the disconnect, or ``None``.
        """
        self._is_connected = False
        self.transport = None

        # Detach the pending queries first so nothing can be failed twice,
        # then raise on every future that is still waiting.
        pending, self.__queries = self.__queries, {}
        for query in pending.values():
            if not query.future.done():
                query.future.set_exception(ConnectionLostError(exc))

        # Any partially received message is meaningless from now on.
        del self.__buffer[:]

        logger.info('Mongo connection lost')
        if self.__connection_lost_callback:
            self.__connection_lost_callback(exc)

    @property
    def is_connected(self):
        """ True when the underlying transport is connected. """
        return self._is_connected

    def data_received(self, data):
        """Buffer incoming bytes and dispatch every complete message.

        Each message starts with a 16-byte header whose first int32 is the
        total message length (header included) and whose third int32 is the
        id of the request being answered.  For every complete message,
        :meth:`message_received` is called with that id and the bytes that
        follow the header.

        The parser is iterative, so any number of messages may arrive in a
        single chunk, and a message may be split across chunks at any byte.
        """
        buf = self.__buffer
        buf += data
        header_size = self._HEADER_SIZE

        while len(buf) >= header_size:
            message_length, _request_id, response_to, _opcode = \
                struct.unpack_from("<iiii", buf)

            if message_length < header_size:
                # A declared length shorter than the header can never be
                # consumed, so the stream cannot be resynchronised.
                logger.error('Mongo connection received malformed message '
                             'length %d', message_length)
                del buf[:]
                if self.transport is not None:
                    self.transport.close()
                return

            if len(buf) < message_length:
                break  # wait for the rest of the body

            packet = bytes(buf[header_size:message_length])
            # Consume the frame before dispatching so that an exception in
            # the handler cannot leave the buffer out of sync.
            del buf[:message_length]
            self.message_received(response_to, packet)

    def message_received(self, request_id, packet):
        """Decode one reply body and route it to the success/failure handler.

        Args:
            request_id: id of the request this reply answers.
            packet: reply bytes following the 16-byte message header.

        Raises:
            struct.error: if ``packet`` is shorter than the 20-byte reply
                preamble.
        """
        # Response Flags:
        # bit 0: Cursor Not Found
        # bit 1: Query Failure
        # bit 2: Shard Config Stale
        # bit 3: Await Capable
        # bit 4-31: Reserved
        QUERY_FAILURE = 1 << 1

        response_flag, cursor_id, _start, _length = \
            struct.unpack_from("<iqii", packet)
        payload = packet[20:]

        # Test the bit rather than comparing the whole word: other flags
        # (e.g. Await Capable) may be set alongside Query Failure.
        if response_flag & QUERY_FAILURE:
            self.query_failure(request_id, cursor_id, response_flag,
                               bson.BSON(payload).decode())
            return
        self.query_success(request_id, cursor_id, bson.decode_all(payload))

    def send_message(self, operation, collection, message, query_opts=None):
        """Frame ``message`` with a header and write it to the transport.

        Args:
            operation: numeric opcode.
            collection: full collection name, or a falsy value to omit it.
            message: already encoded bytes following the collection name.
            query_opts: 4 bytes placed before the collection name; ``None``
                (the default) sends four zero bytes.
        """
        if query_opts is None:
            query_opts = _ZERO
        fullname = bson._make_c_string(collection) if collection else b""
        message = query_opts + fullname + message
        # 16 is the size of the header in bytes
        header = struct.pack("<iiii", self._HEADER_SIZE + len(message),
                             self.__id, 0, operation)
        self.transport.write(header + message)
        # Wrap instead of growing forever: the id is packed as signed int32.
        self.__id = (self.__id + 1) & self._MAX_REQUEST_ID

    def OP_INSERT(self, collection, docs):
        """Send an insert (opcode 2002) with every document in ``docs``."""
        payload = b"".join(bson.BSON.encode(doc) for doc in docs)
        self.send_message(2002, collection, payload)

    def OP_UPDATE(self, collection, spec, document, upsert=False, multi=False):
        """Send an update (opcode 2001).

        ``upsert`` sets flag bit 0 and ``multi`` sets flag bit 1.
        """
        options = 0
        if upsert:
            options |= 1
        if multi:
            options |= 2
        message = (struct.pack("<i", options)
                   + bson.BSON.encode(spec)
                   + bson.BSON.encode(document))
        self.send_message(2001, collection, message)

    def OP_DELETE(self, collection, spec):
        """Send a delete (opcode 2006) for documents matching ``spec``."""
        self.send_message(2006, collection, _ZERO + bson.BSON.encode(spec))

    def OP_KILL_CURSORS(self, cursors):
        """Send a kill-cursors message (opcode 2007) for the given cursor ids."""
        message = struct.pack("<i", len(cursors)) + b"".join(
            struct.pack("<q", cursor_id) for cursor_id in cursors)
        self.send_message(2007, None, message)

    def OP_GET_MORE(self, collection, limit, cursor_id):
        """Request up to ``limit`` more documents from ``cursor_id`` (opcode 2005)."""
        message = struct.pack("<iq", limit, cursor_id)
        self.send_message(2005, collection, message)

    def OP_QUERY(self, collection, spec, skip, limit, fields=None,
                 tailable=False, await_data=False):
        """Send a query (opcode 2004) and return a future for its result.

        Args:
            collection: full collection name.
            spec: query document.
            skip: number of documents to skip.
            limit: number of documents to return.
            fields: optional field selector; omitted when falsy.
            tailable: request a tailable cursor.
            await_data: only honoured together with ``tailable``.

        Returns:
            The ``asyncio.Future`` of the registered query.  See
            :meth:`query_success` for the shape of its result.
        """
        message = struct.pack("<ii", skip, limit) + bson.BSON.encode(spec)
        if fields:
            message += bson.BSON.encode(fields)

        flags = 0
        if tailable:
            flags |= _QUERY_OPTIONS["tailable_cursor"]
            if await_data:
                flags |= _QUERY_OPTIONS["await_data"]

        # Register under the id that send_message is about to use.
        query = _MongoQuery(self.__id, collection, limit, tailable)
        self.__queries[self.__id] = query
        self.send_message(2004, collection, message, struct.pack("<i", flags))
        return query.future

    def query_failure(self, request_id, cursor_id, response, raw_error):
        """Fail the future of ``request_id`` with a ``ValueError``.

        Replies for unknown requests, or for futures that are already done
        (e.g. cancelled by the caller), are ignored.
        """
        query = self.__queries.pop(request_id, None)
        if query is None:
            return
        if not query.future.done():
            query.future.set_exception(
                ValueError("mongo error=%s" % repr(raw_error)))

    def query_success(self, request_id, cursor_id, documents):
        """Handle a successful reply for ``request_id``.

        Non-tailable queries accumulate documents across batches and resolve
        their future with the full list once the cursor is exhausted
        (``cursor_id == 0``) or the limit is reached.  Tailable queries
        resolve their current future with ``(next_future, documents)`` for
        every non-empty batch and keep polling with OP_GET_MORE.

        Args:
            request_id: id of the request being answered.
            cursor_id: server cursor id; ``0`` when there is nothing more.
            documents: decoded documents of this batch.
        """
        query = self.__queries.pop(request_id, None)
        if query is None:
            return  # reply to a request that is not being tracked

        if query.future.done():
            # The caller gave up (e.g. cancelled the future).  Setting a
            # result would raise, so just release the server-side cursor.
            if cursor_id:
                self.OP_KILL_CURSORS([cursor_id])
            return

        if query.tailable:
            # Each tailable batch is delivered on its own.
            query.documents = []
        if isinstance(documents, list):
            query.documents += documents
        else:
            query.documents.append(documents)

        if not cursor_id:
            query.future.set_result(query.documents)
            return

        query.id = self.__id
        next_batch = 0
        if query.limit:
            next_batch = query.limit - len(query.documents)
            if next_batch < 0:
                # More documents than requested: surface it to the caller
                # rather than silently returning a wrong number of
                # documents (and rather than an assert that -O would strip).
                self.OP_KILL_CURSORS([cursor_id])
                query.future.set_exception(
                    ValueError("Unexpected number of documents received!"))
                return
            if not next_batch:
                self.OP_KILL_CURSORS([cursor_id])
                query.future.set_result(query.documents)
                return

        # Re-register under the id that the OP_GET_MORE below will use.
        self.__queries[self.__id] = query
        self.OP_GET_MORE(query.collection, next_batch, cursor_id)

        if query.tailable and query.documents:
            # Hand this batch over together with the future for the next one.
            next_future = asyncio.Future()
            current, query.future = query.future, next_future
            current.set_result((next_future, documents))