diff --git a/asyncio_mongo/connection.py b/asyncio_mongo/connection.py index e3c1cfa..5999512 100644 --- a/asyncio_mongo/connection.py +++ b/asyncio_mongo/connection.py @@ -48,6 +48,11 @@ class Connection: return connection + @asyncio.coroutine + def disconnect(self): + if self.transport: + return self.transport.close() + @property def transport(self): """ The transport instance that the protocol is currently using. """ diff --git a/asyncio_mongo/pool.py b/asyncio_mongo/pool.py index 8ac11de..115b079 100644 --- a/asyncio_mongo/pool.py +++ b/asyncio_mongo/pool.py @@ -12,12 +12,12 @@ class Pool: Pool of connections. Each Takes care of setting up the connection and connection pooling. - When poolsize > 1 and some connections are in use because of transactions + When pool_size > 1 and some connections are in use because of transactions or blocking requests, the other are preferred. :: - pool = yield from Pool.create(host='localhost', port=6379, poolsize=10) + pool = yield from Pool.create(host='localhost', port=6379, pool_size=10) result = yield from connection.set('key', 'value') """ @@ -38,19 +38,19 @@ class Pool: @classmethod @asyncio.coroutine - def create(cls, host='localhost', port=6379, loop=None, password=None, db=0, poolsize=1, auto_reconnect=True): + def create(cls, host='localhost', port=6379, loop=None, password=None, db=0, pool_size=1, auto_reconnect=True): """ Create a new connection instance. """ self = cls() self._host = host self._port = port - self._poolsize = poolsize + self._pool_size = pool_size # Create connections self._connections = [] - for i in range(poolsize): + for i in range(pool_size): connection_class = cls.get_connection_class() connection = yield from connection_class.create(host=host, port=port, loop=loop, password=password, db=db, auto_reconnect=auto_reconnect) @@ -59,10 +59,10 @@ class Pool: return self def __repr__(self): - return 'Pool(host=%r, port=%r, poolsize=%r)' % (self._host, self._port, self._poolsize) + return 'Pool(host=%r, port=%r, pool_size=%r)' % (self._host, self._port, self._poolsize) @property - def poolsize(self): + def pool_size(self): """ Number of parallel connections in the pool.""" return self._poolsize @@ -80,6 +80,10 @@ class Pool: """ return sum([ 1 for c in self._connections if c.protocol.is_connected ]) + def close(self): + for conn in self._connections: + conn.disconnect() + def _get_free_connection(self): """ Return the next protocol instance that's not in use. @@ -94,7 +98,7 @@ class Pool: def _shuffle_connections(self): """ - 'shuffle' protocols. Make sure that we devide the load equally among the protocols. + 'shuffle' protocols. Make sure that we divide the load equally among the protocols. """ self._connections = self._connections[1:] + self._connections[:1] @@ -103,10 +107,14 @@ class Pool: Proxy to a protocol. (This will choose a protocol instance that's not busy in a blocking request or transaction.) """ + + if 'close' == name: + return self.close + connection = self._get_free_connection() if connection: return getattr(connection, name) else: raise NoAvailableConnectionsInPoolError('No available connections in the pool: size=%s, in_use=%s, connected=%s' % ( - self.poolsize, self.connections_in_use, self.connections_connected)) + self.pool_size, self.connections_in_use, self.connections_connected)) diff --git a/asyncio_mongo/protocol.py b/asyncio_mongo/protocol.py index 5f8e90f..c42fbf7 100644 --- a/asyncio_mongo/protocol.py +++ b/asyncio_mongo/protocol.py @@ -27,7 +27,7 @@ _ZERO = b"\x00\x00\x00\x00" class _MongoQuery(object): - def __init__(self, id, collection, limit): + def __init__(self, id, collection, limit): self.id = id self.limit = limit self.collection = collection diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/base.py b/tests/base.py new file mode 100644 index 0000000..d6e9a4c --- /dev/null +++ b/tests/base.py @@ -0,0 +1,48 @@ +import inspect +import unittest +from asyncio import coroutine +import asyncio +import asyncio_mongo + +mongo_host = "localhost" +mongo_port = 27017 + + +def yields(value): + return isinstance(value, asyncio.futures.Future) or inspect.isgenerator(value) + + +@coroutine +def call_maybe_yield(func, *args, **kwargs): + rv = func(*args, **kwargs) + if yields(rv): + rv = yield from rv + return rv + + +def run_now(func, *args, **kwargs): + loop = asyncio.get_event_loop() + return loop.run_until_complete( + asyncio.Task(call_maybe_yield(func, *args, **kwargs)) + ) + + +def async(func): + def inner(*args, **kwargs): + run_now(func, *args, **kwargs) + return inner + + +class MongoTest(unittest.TestCase): + + @async + def setUp(self): + self.conn = yield from asyncio_mongo.Connection.create(mongo_host, mongo_port) + self.db = self.conn.mydb + self.coll = self.db.mycol + yield from self.coll.drop() + + @async + def tearDown(self): + yield from self.coll.drop() + self.conn.disconnect() \ No newline at end of file diff --git a/tests/test_aggregate.py b/tests/test_aggregate.py new file mode 100644 index 0000000..ef256e1 --- /dev/null +++ b/tests/test_aggregate.py @@ -0,0 +1,46 @@ +# coding: utf-8 +# Copyright 2010 Tryggvi Bjorgvinsson +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from tests.base import MongoTest, async + + +class TestAggregate(MongoTest): + + timeout = 5 + + @async + def test_aggregate(self): + yield from self.coll.insert([{'oh':'hai', 'lulz':123}, + {'oh':'kthxbye', 'lulz':456}, + {'oh':'hai', 'lulz':789},], safe=True) + + res = yield from self.coll.aggregate([ + {'$project': {'oh':1, 'lolz':'$lulz'}}, + {'$group': {'_id':'$oh', 'many_lolz': {'$sum':'$lolz'}}}, + {'$sort': {'_id':1}} + ]) + + self.assertEqual(len(res), 2) + self.assertEqual(res[0]['_id'], 'hai') + self.assertEqual(res[0]['many_lolz'], 912) + self.assertEqual(res[1]['_id'], 'kthxbye') + self.assertEqual(res[1]['many_lolz'], 456) + + res = yield from self.coll.aggregate([ + {'$match': {'oh':'hai'}} + ], full_response=True) + + self.assertIn('ok', res) + self.assertIn('result', res) + self.assertEqual(len(res['result']), 2) diff --git a/tests/test_collection.py b/tests/test_collection.py new file mode 100644 index 0000000..4342340 --- /dev/null +++ b/tests/test_collection.py @@ -0,0 +1,202 @@ +# -*- coding: utf-8 -*- + +# Copyright 2012 Renzo S. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the collection module. +Based on pymongo driver's test_collection.py +""" +from asyncio_mongo._pymongo import errors +from asyncio_mongo.collection import Collection +from asyncio_mongo._bson.son import SON +from asyncio_mongo import filter + +from tests.base import MongoTest, async + + +class TestCollection(MongoTest): + + @async + def test_collection(self): + self.assertRaises(TypeError, Collection, self.db, 5) + + def make_col(base, name): + return base[name] + + self.assertRaises(errors.InvalidName, make_col, self.db, "") + self.assertRaises(errors.InvalidName, make_col, self.db, "te$t") + self.assertRaises(errors.InvalidName, make_col, self.db, ".test") + self.assertRaises(errors.InvalidName, make_col, self.db, "test.") + self.assertRaises(errors.InvalidName, make_col, self.db, "tes..t") + self.assertRaises(errors.InvalidName, make_col, self.db.test, "") + self.assertRaises(errors.InvalidName, make_col, self.db.test, "te$t") + self.assertRaises(errors.InvalidName, make_col, self.db.test, ".test") + self.assertRaises(errors.InvalidName, make_col, self.db.test, "test.") + self.assertRaises(errors.InvalidName, make_col, self.db.test, "tes..t") + self.assertRaises(errors.InvalidName, make_col, self.db.test, "tes\x00t") + + self.assert_(isinstance(self.db.test, Collection)) + self.assertEqual(self.db.test, Collection(self.db, "test")) + self.assertEqual(self.db.test.mike, self.db["test.mike"]) + self.assertEqual(self.db.test["mike"], self.db["test.mike"]) + + yield from self.db.drop_collection('test') + collection_names = yield from self.db.collection_names() + self.assertFalse('test' in collection_names) + + + @async + def test_create_index(self): + db = self.db + coll = self.coll + + self.assertRaises(TypeError, coll.create_index, 5) + self.assertRaises(TypeError, coll.create_index, {"hello": 1}) + + yield from coll.drop_indexes() + count = yield from db.system.indexes.count({"ns": u"mydb.mycol"}) + self.assertEqual(count, 1) + + result1 = yield from coll.create_index(filter.sort(filter.ASCENDING("hello"))) + result2 = yield from coll.create_index(filter.sort(filter.ASCENDING("hello") + \ + filter.DESCENDING("world"))) + + count = yield from db.system.indexes.count({"ns": u"mydb.mycol"}) + self.assertEqual(count, 3) + + yield from coll.drop_indexes() + ix = yield from coll.create_index(filter.sort(filter.ASCENDING("hello") + \ + filter.DESCENDING("world")), name="hello_world") + self.assertEquals(ix, "hello_world") + + yield from coll.drop_indexes() + count = yield from db.system.indexes.count({"ns": u"mydb.mycol"}) + self.assertEqual(count, 1) + + yield from coll.create_index(filter.sort(filter.ASCENDING("hello"))) + indices = yield from db.system.indexes.find({"ns": u"mydb.mycol"}) + self.assert_(u"hello_1" in [a["name"] for a in indices]) + + yield from coll.drop_indexes() + count = yield from db.system.indexes.count({"ns": u"mydb.mycol"}) + self.assertEqual(count, 1) + + ix = yield from coll.create_index(filter.sort(filter.ASCENDING("hello") + \ + filter.DESCENDING("world"))) + self.assertEquals(ix, "hello_1_world_-1") + + @async + def test_create_index_nodup(self): + coll = self.coll + + yield from coll.drop() + yield from coll.insert({'b': 1}) + yield from coll.insert({'b': 1}) + + self.assertRaises(errors.DuplicateKeyError, coll.create_index, filter.sort(filter.ASCENDING("b")), unique=True) + + + @async + def test_ensure_index(self): + db = self.db + coll = self.coll + + yield from coll.ensure_index(filter.sort(filter.ASCENDING("hello"))) + indices = yield from db.system.indexes.find({"ns": u"mydb.mycol"}) + self.assert_(u"hello_1" in [a["name"] for a in indices]) + + yield from coll.drop_indexes() + + @async + def test_index_info(self): + db = self.db + + yield from db.test.drop_indexes() + yield from db.test.remove({}) + + yield from db.test.save({}) # create collection + ix_info = yield from db.test.index_information() + self.assertEqual(len(ix_info), 1) + + self.assert_("_id_" in ix_info) + + yield from db.test.create_index(filter.sort(filter.ASCENDING("hello"))) + ix_info = yield from db.test.index_information() + self.assertEqual(len(ix_info), 2) + + self.assertEqual(ix_info["hello_1"], [("hello", 1)]) + + yield from db.test.create_index(filter.sort(filter.DESCENDING("hello") + filter.ASCENDING("world")), unique=True) + ix_info = yield from db.test.index_information() + + self.assertEqual(ix_info["hello_1"], [("hello", 1)]) + self.assertEqual(len(ix_info), 3) + self.assertEqual([("world", 1), ("hello", -1)], ix_info["hello_-1_world_1"]) + # Unique key will not show until index_information is updated with changes introduced in version 1.7 + #self.assertEqual(True, ix_info["hello_-1_world_1"]["unique"]) + + yield from db.test.drop_indexes() + yield from db.test.remove({}) + + + @async + def test_index_geo2d(self): + db = self.db + coll = self.coll + yield from coll.drop_indexes() + geo_ix = yield from coll.create_index(filter.sort(filter.GEO2D("loc"))) + + self.assertEqual('loc_2d', geo_ix) + + index_info = yield from coll.index_information() + self.assertEqual([('loc', '2d')], index_info['loc_2d']) + + @async + def test_index_haystack(self): + db = self.db + coll = self.coll + yield from coll.drop_indexes() + + _id = yield from coll.insert({ + "pos": {"long": 34.2, "lat": 33.3}, + "type": "restaurant" + }) + yield from coll.insert({ + "pos": {"long": 34.2, "lat": 37.3}, "type": "restaurant" + }) + yield from coll.insert({ + "pos": {"long": 59.1, "lat": 87.2}, "type": "office" + }) + + yield from coll.create_index(filter.sort(filter.GEOHAYSTACK("pos") + filter.ASCENDING("type")), **{'bucket_size': 1}) + + # TODO: A db.command method has not been implemented yet. + # Sending command directly + command = SON([ + ("geoSearch", "mycol"), + ("near", [33, 33]), + ("maxDistance", 6), + ("search", {"type": "restaurant"}), + ("limit", 30), + ]) + + results = yield from db["$cmd"].find_one(command) + self.assertEqual(2, len(results['results'])) + self.assertEqual({ + "_id": _id, + "pos": {"long": 34.2, "lat": 33.3}, + "type": "restaurant" + }, results["results"][0]) + + diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 0000000..7836624 --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,44 @@ +# coding: utf-8 +# Copyright 2009 Alexandre Fiori +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect + +import asyncio_mongo +from tests.base import MongoTest, async + + +mongo_host = "localhost" +mongo_port = 27017 + + +class TestMongoConnectionMethods(MongoTest): + + @async + def test_connection(self): + # MongoConnection returns deferred, which gets MongoAPI + conn = asyncio_mongo.Connection.create(mongo_host, mongo_port) + self.assertTrue(inspect.isgenerator(conn)) + rapi = yield from conn + self.assertEqual(isinstance(rapi, asyncio_mongo.Connection), True) + rapi.disconnect() + + @async + def test_pool(self): + # MongoConnectionPool returns deferred, which gets MongoAPI + pool = asyncio_mongo.Pool.create(mongo_host, mongo_port, pool_size=2) + self.assertTrue(inspect.isgenerator(pool)) + rapi = yield from pool + print('rapi %s' % rapi.__class__) + self.assertEqual(isinstance(rapi, asyncio_mongo.Pool), True) + rapi.close() \ No newline at end of file diff --git a/tests/test_find_and_modify.py b/tests/test_find_and_modify.py new file mode 100644 index 0000000..75d2177 --- /dev/null +++ b/tests/test_find_and_modify.py @@ -0,0 +1,37 @@ +# coding: utf-8 +# Copyright 2010 Mark L. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from tests.base import MongoTest, async + + +class TestFindAndModify(MongoTest): + @async + def test_update(self): + yield from self.coll.insert([{'oh': 'hai', 'lulz': 123}, + {'oh': 'kthxbye', 'lulz': 456}], safe=True) + + res = yield from self.coll.find_one({'oh': 'hai'}) + self.assertEqual(res['lulz'], 123) + + res = yield from self.coll.find_and_modify({'o2h': 'hai'}, {'$inc': {'lulz': 1}}) + self.assertEqual(res, None) + + res = yield from self.coll.find_and_modify({'oh': 'hai'}, {'$inc': {'lulz': 1}}) + print(res) + self.assertEqual(res['lulz'], 123) + res = yield from self.coll.find_and_modify({'oh': 'hai'}, {'$inc': {'lulz': 1}}, new=True) + self.assertEqual(res['lulz'], 125) + + res = yield from self.coll.find_one({'oh': 'kthxbye'}) + self.assertEqual(res['lulz'], 456) diff --git a/tests/test_objects.py b/tests/test_objects.py new file mode 100644 index 0000000..cad96d2 --- /dev/null +++ b/tests/test_objects.py @@ -0,0 +1,171 @@ +# coding: utf-8 +# Copyright 2009 Alexandre Fiori +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time + +from asyncio_mongo import database +from asyncio_mongo import collection +from asyncio_mongo import filter as qf +from asyncio_mongo._bson import objectid, timestamp +from tests.base import MongoTest, async + + +class TestMongoObjects(MongoTest): + @async + def test_MongoObjects(self): + """ Tests creating mongo objects """ + self.assertEqual(isinstance(self.db, database.Database), True) + self.assertEqual(isinstance(self.coll, collection.Collection), True) + + @async + def test_MongoOperations(self): + """ Tests mongo operations """ + test = self.coll + + # insert + doc = {"foo": "bar", "items": [1, 2, 3]} + yield from test.insert(doc, safe=True) + result = yield from test.find_one(doc) + self.assertEqual("_id" in result, True) + self.assertEqual(result["foo"], "bar") + self.assertEqual(result["items"], [1, 2, 3]) + + # insert preserves object id + doc.update({'_id': objectid.ObjectId()}) + yield from test.insert(doc, safe=True) + result = yield from test.find_one(doc) + self.assertEqual(result.get('_id'), doc.get('_id')) + self.assertEqual(result["foo"], "bar") + self.assertEqual(result["items"], [1, 2, 3]) + + # update + yield from test.update({"_id": result["_id"]}, {"$set": {"one": "two"}}, safe=True) + result = yield from test.find_one({"_id": result["_id"]}) + self.assertEqual(result["one"], "two") + + # delete + yield from test.remove(result["_id"], safe=True) + + @async + def test_Timestamps(self): + """Tests mongo operations with Timestamps""" + test = self.coll + + # insert with specific timestamp + doc1 = {'_id': objectid.ObjectId(), + 'ts': timestamp.Timestamp(1, 2)} + yield from test.insert(doc1, safe=True) + + result = yield from test.find_one(doc1) + self.assertEqual(result.get('ts').time, 1) + self.assertEqual(result.get('ts').inc, 2) + + # insert with specific timestamp + doc2 = {'_id': objectid.ObjectId(), + 'ts': timestamp.Timestamp(2, 1)} + yield from test.insert(doc2, safe=True) + + # the objects come back sorted by ts correctly. + # (test that we stored inc/time in the right fields) + result = yield from test.find(filter=qf.sort(qf.ASCENDING('ts'))) + self.assertEqual(result[0]['_id'], doc1['_id']) + self.assertEqual(result[1]['_id'], doc2['_id']) + + # insert with null timestamp + doc3 = {'_id': objectid.ObjectId(), + 'ts': timestamp.Timestamp(0, 0)} + yield from test.insert(doc3, safe=True) + + # time field loaded correctly + result = yield from test.find_one(doc3['_id']) + now = time.time() + self.assertTrue(now - 2 <= result['ts'].time <= now) + + # delete + yield from test.remove(doc1["_id"], safe=True) + yield from test.remove(doc2["_id"], safe=True) + yield from test.remove(doc3["_id"], safe=True) + + +# class TestGridFsObjects(unittest.TestCase): +# """ Test the GridFS operations from asyncio_mongo._gridfs """ +# @async +# def _disconnect(self, conn): +# """ Disconnect the connection """ +# yield from conn.disconnect() +# +# @async +# def test_GridFsObjects(self): +# """ Tests gridfs objects """ +# conn = yield from asyncio_mongo.MongoConnection(mongo_host, mongo_port) +# db = conn.test +# collection = db.fs +# +# gfs = gridfs.GridFS(db) # Default collection +# +# gridin = GridIn(collection, filename='test', contentType="text/plain", +# chunk_size=2**2**2**2) +# new_file = gfs.new_file(filename='test2', contentType="text/plain", +# chunk_size=2**2**2**2) +# +# # disconnect +# yield from conn.disconnect() +# +# @async +# def test_GridFsOperations(self): +# """ Tests gridfs operations """ +# conn = yield from asyncio_mongo.MongoConnection(mongo_host, mongo_port) +# db = conn.test +# collection = db.fs +# +# # Don't forget to disconnect +# self.addCleanup(self._disconnect, conn) +# try: +# in_file = StringIO("Test input string") +# out_file = StringIO() +# except Exception, e: +# self.fail("Failed to create memory files for testing: %s" % e) +# +# try: +# # Tests writing to a new gridfs file +# gfs = gridfs.GridFS(db) # Default collection +# g_in = gfs.new_file(filename='optest', contentType="text/plain", +# chunk_size=2**2**2**2) # non-default chunk size used +# # yielding to ensure writes complete before we close and close before we try to read +# yield from g_in.write(in_file.read()) +# yield from g_in.close() +# +# # Tests reading from an existing gridfs file +# g_out = yield from gfs.get_last_version('optest') +# data = yield from g_out.read() +# out_file.write(data) +# _id = g_out._id +# except Exception,e: +# self.fail("Failed to communicate with the GridFS. " + +# "Is MongoDB running? %s" % e) +# else: +# self.assertEqual(in_file.getvalue(), out_file.getvalue(), +# "Could not read the value from writing an input") +# finally: +# in_file.close() +# out_file.close() +# g_out.close() +# +# +# listed_files = yield from gfs.list() +# self.assertEqual(['optest'], listed_files, +# "'optest' is the only expected file and we received %s" % listed_files) +# +# yield from gfs.delete(_id) + diff --git a/tests/test_queries.py b/tests/test_queries.py new file mode 100644 index 0000000..8ddb9bd --- /dev/null +++ b/tests/test_queries.py @@ -0,0 +1,153 @@ +# coding: utf-8 +# Copyright 2010 Mark L. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from tests.base import async, MongoTest + + +class TestMongoQueries(MongoTest): + + @async + def test_SingleCursorIteration(self): + yield from self.coll.insert([{'v':i} for i in range(10)], safe=True) + res = yield from self.coll.find() + self.assertEqual(len(res), 10) + + @async + def test_MultipleCursorIterations(self): + yield from self.coll.insert([{'v':i} for i in range(450)], safe=True) + res = yield from self.coll.find() + self.assertEqual(len(res), 450) + + @async + def test_LargeData(self): + yield from self.coll.insert([{'v':' '*(2**19)} for i in range(4)], safe=True) + res = yield from self.coll.find() + self.assertEqual(len(res), 4) + + +class TestMongoQueriesEdgeCases(MongoTest): + + @async + def test_BelowBatchThreshold(self): + yield from self.coll.insert([{'v':i} for i in range(100)], safe=True) + res = yield from self.coll.find() + self.assertEqual(len(res), 100) + + @async + def test_EqualToBatchThreshold(self): + yield from self.coll.insert([{'v':i} for i in range(101)], safe=True) + res = yield from self.coll.find() + self.assertEqual(len(res), 101) + + @async + def test_AboveBatchThreshold(self): + yield from self.coll.insert([{'v':i} for i in range(102)], safe=True) + res = yield from self.coll.find() + self.assertEqual(len(res), 102) + + +class TestLimit(MongoTest): + + @async + def test_LimitBelowBatchThreshold(self): + yield from self.coll.insert([{'v':i} for i in range(50)], safe=True) + res = yield from self.coll.find(limit=20) + self.assertEqual(len(res), 20) + + @async + def test_LimitAboveBatchThreshold(self): + yield from self.coll.insert([{'v':i} for i in range(200)], safe=True) + res = yield from self.coll.find(limit=150) + self.assertEqual(len(res), 150) + + @async + def test_LimitAtBatchThresholdEdge(self): + yield from self.coll.insert([{'v':i} for i in range(200)], safe=True) + res = yield from self.coll.find(limit=100) + self.assertEqual(len(res), 100) + + yield from self.coll.drop(safe=True) + + yield from self.coll.insert([{'v':i} for i in range(200)], safe=True) + res = yield from self.coll.find(limit=101) + self.assertEqual(len(res), 101) + + yield from self.coll.drop(safe=True) + + yield from self.coll.insert([{'v':i} for i in range(200)], safe=True) + res = yield from self.coll.find(limit=102) + self.assertEqual(len(res), 102) + + @async + def test_LimitAboveMessageSizeThreshold(self): + yield from self.coll.insert([{'v':' '*(2**20)} for i in range(8)], safe=True) + res = yield from self.coll.find(limit=5) + self.assertEqual(len(res), 5) + + @async + def test_HardLimit(self): + yield from self.coll.insert([{'v':i} for i in range(200)], safe=True) + res = yield from self.coll.find(limit=-150) + self.assertEqual(len(res), 150) + + @async + def test_HardLimitAboveMessageSizeThreshold(self): + yield from self.coll.insert([{'v':' '*(2**20)} for i in range(8)], safe=True) + res = yield from self.coll.find(limit=-6) + self.assertEqual(len(res), 4) + + +class TestSkip(MongoTest): + + @async + def test_Skip(self): + yield from self.coll.insert([{'v':i} for i in range(5)], safe=True) + res = yield from self.coll.find(skip=3) + self.assertEqual(len(res), 2) + + yield from self.coll.drop(safe=True) + + yield from self.coll.insert([{'v':i} for i in range(5)], safe=True) + res = yield from self.coll.find(skip=5) + self.assertEqual(len(res), 0) + + yield from self.coll.drop(safe=True) + + yield from self.coll.insert([{'v':i} for i in range(5)], safe=True) + res = yield from self.coll.find(skip=6) + self.assertEqual(len(res), 0) + + @async + def test_SkipWithLimit(self): + yield from self.coll.insert([{'v':i} for i in range(5)], safe=True) + res = yield from self.coll.find(skip=3, limit=1) + self.assertEqual(len(res), 1) + + yield from self.coll.drop(safe=True) + + yield from self.coll.insert([{'v':i} for i in range(5)], safe=True) + res = yield from self.coll.find(skip=4, limit=2) + self.assertEqual(len(res), 1) + + yield from self.coll.drop(safe=True) + + yield from self.coll.insert([{'v':i} for i in range(5)], safe=True) + res = yield from self.coll.find(skip=4, limit=1) + self.assertEqual(len(res), 1) + + yield from self.coll.drop(safe=True) + + yield from self.coll.insert([{'v':i} for i in range(5)], safe=True) + res = yield from self.coll.find(skip=5, limit=1) + self.assertEqual(len(res), 0) \ No newline at end of file