Skip to content

Commit

Permalink
Merge pull request #300 from IlyaSkriblovsky/master
Browse files Browse the repository at this point in the history
Builder pattern methods on Cursor
  • Loading branch information
psi29a authored Nov 5, 2024
2 parents 7e614fb + faa91f6 commit e977fb9
Show file tree
Hide file tree
Showing 8 changed files with 579 additions and 332 deletions.
4 changes: 4 additions & 0 deletions docs/source/NEWS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ API Changes
- `find()` method now returns `Cursor()` instance that can be used as async generator to
asynchronously iterate over results. It can still be used as Deferred too, so this change
is backward-compatible.
- `Cursor()` options can be by chaining its methods, for example:
::
async for doc in collection.find({"size": "L"}).sort({"price": 1}).limit(10).skip(5):
print(doc)
- `find_with_cursor()` is deprecated and will be removed in the next release.


Expand Down
30 changes: 27 additions & 3 deletions tests/basic/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def cmp(a, b):
return (a > b) - (a < b)


class TestIndexInfo(unittest.TestCase):
class TestCollectionMethods(unittest.TestCase):

timeout = 5

Expand All @@ -54,7 +54,7 @@ def tearDown(self):
yield self.conn.disconnect()

@defer.inlineCallbacks
def test_collection(self):
def test_type_checking(self):
self.assertRaises(TypeError, Collection, self.db, 5)

def make_col(base, name):
Expand All @@ -76,6 +76,7 @@ def make_col(base, name):
self.assertRaises(TypeError, self.db.test.find, projection="test")
self.assertRaises(TypeError, self.db.test.find, skip="test")
self.assertRaises(TypeError, self.db.test.find, limit="test")
self.assertRaises(TypeError, self.db.test.find, batch_size="test")
self.assertRaises(TypeError, self.db.test.find, sort="test")
self.assertRaises(TypeError, self.db.test.find, skip="test")
self.assertRaises(TypeError, self.db.test.insert_many, [1])
Expand Down Expand Up @@ -105,9 +106,32 @@ def make_col(base, name):
options = yield self.db.test.options()
self.assertTrue(isinstance(options, dict))

@defer.inlineCallbacks
def test_collection_names(self):
coll_names = [f"coll_{i}" for i in range(10)]
yield defer.gatherResults(
self.db[name].insert_one({"x": 1}) for name in coll_names
)

try:
names = yield self.db.collection_names()
self.assertEqual(set(coll_names), set(names))
names = yield self.db.collection_names(batch_size=10)
self.assertEqual(set(coll_names), set(names))
finally:
yield defer.gatherResults(self.db[name].drop() for name in coll_names)

test_collection_names.timeout = 1500

@defer.inlineCallbacks
def test_drop_collection(self):
yield self.db.test.insert_one({"x": 1})
collection_names = yield self.db.collection_names()
self.assertIn("test", collection_names)

yield self.db.drop_collection("test")
collection_names = yield self.db.collection_names()
self.assertFalse("test" in collection_names)
self.assertNotIn("test", collection_names)

@defer.inlineCallbacks
def test_create_index(self):
Expand Down
103 changes: 58 additions & 45 deletions tests/basic/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import asynccontextmanager, contextmanager

from pymongo.errors import OperationFailure
from twisted.internet import defer
Expand Down Expand Up @@ -39,19 +40,43 @@ def tearDown(self):
yield self.db.system.profile.drop()
yield self.conn.disconnect()

@defer.inlineCallbacks
def test_Hint(self):
@asynccontextmanager
async def _assert_single_command_with_option(self, optionname, optionvalue):
# Checking that `optionname` appears in profiler log with specified value

await self.db.command("profile", 2)
yield
await self.db.command("profile", 0)

profile_filter = {"command." + optionname: optionvalue}
cnt = await self.db.system.profile.count(profile_filter)
await self.db.system.profile.drop()
self.assertEqual(cnt, 1)

async def test_Hint(self):
# find() should fail with 'bad hint' if hint specifier works correctly
self.assertFailure(
self.coll.find({}, sort=qf.hint([("x", 1)])), OperationFailure
)
self.assertFailure(self.coll.find().hint({"x": 1}), OperationFailure)

# create index and test it is honoured
yield self.coll.create_index(qf.sort(qf.ASCENDING("x")), name="test_index")
found_1 = yield self.coll.find({}, sort=qf.hint([("x", 1)]))
found_2 = yield self.coll.find({}, sort=qf.hint(qf.ASCENDING("x")))
found_3 = yield self.coll.find({}, sort=qf.hint("test_index"))
self.assertTrue(found_1 == found_2 == found_3)
await self.coll.create_index(qf.sort(qf.ASCENDING("x")), name="test_index")
forms = [
[("x", 1)],
{"x": 1},
qf.ASCENDING("x"),
]
for form in forms:
async with self._assert_single_command_with_option("hint", {"x": 1}):
await self.coll.find({}, sort=qf.hint(form))
async with self._assert_single_command_with_option("hint", {"x": 1}):
await self.coll.find().hint(form)

async with self._assert_single_command_with_option("hint", "test_index"):
await self.coll.find({}, sort=qf.hint("test_index"))
async with self._assert_single_command_with_option("hint", "test_index"):
await self.coll.find().hint("test_index")

# find() should fail with 'bad hint' if hint specifier works correctly
self.assertFailure(
Expand All @@ -67,13 +92,18 @@ def test_SortAscendingMultipleFields(self):
qf.sort(qf.ASCENDING(["x", "y"])),
qf.sort(qf.ASCENDING("x") + qf.ASCENDING("y")),
)
self.assertEqual(
qf.sort(qf.ASCENDING(["x", "y"])),
qf.sort({"x": 1, "y": 1}),
)

def test_SortOneLevelList(self):
self.assertEqual(qf.sort([("x", 1)]), qf.sort(("x", 1)))

def test_SortInvalidKey(self):
self.assertRaises(TypeError, qf.sort, [(1, 2)])
self.assertRaises(TypeError, qf.sort, [("x", 3)])
self.assertRaises(TypeError, qf.sort, {"x": 3})

def test_SortGeoIndexes(self):
self.assertEqual(qf.sort(qf.GEO2D("x")), qf.sort([("x", "2d")]))
Expand All @@ -83,45 +113,33 @@ def test_SortGeoIndexes(self):
def test_TextIndex(self):
self.assertEqual(qf.sort(qf.TEXT("title")), qf.sort([("title", "text")]))

def __3_2_or_higher(self):
return self.db.command("buildInfo").addCallback(
lambda info: info["versionArray"] >= [3, 2]
)

def __3_6_or_higher(self):
return self.db.command("buildInfo").addCallback(
lambda info: info["versionArray"] >= [3, 6]
)

@defer.inlineCallbacks
def __test_simple_filter(self, filter, optionname, optionvalue):
# Checking that `optionname` appears in profiler log with specified value

yield self.db.command("profile", 2)
yield self.coll.find({}, sort=filter)
yield self.db.command("profile", 0)

if (yield self.__3_6_or_higher()):
profile_filter = {"command." + optionname: optionvalue}
elif (yield self.__3_2_or_higher()):
# query options format in system.profile have changed in MongoDB 3.2
profile_filter = {"query." + optionname: optionvalue}
else:
profile_filter = {"query.$" + optionname: optionvalue}

cnt = yield self.db.system.profile.count(profile_filter)
self.assertEqual(cnt, 1)

@defer.inlineCallbacks
def test_Comment(self):
async def test_SortProfile(self):
forms = [
qf.DESCENDING("x"),
{"x": -1},
[("x", -1)],
("x", -1),
]
for form in forms:
async with self._assert_single_command_with_option("sort.x", -1):
await self.coll.find({}, sort=qf.sort(form))
async with self._assert_single_command_with_option("sort.x", -1):
await self.coll.find().sort(form)

async def test_Comment(self):
comment = "hello world"

yield self.__test_simple_filter(qf.comment(comment), "comment", comment)
async with self._assert_single_command_with_option("comment", comment):
await self.coll.find({}, sort=qf.comment(comment))
async with self._assert_single_command_with_option("comment", comment):
await self.coll.find().comment(comment)

@defer.inlineCallbacks
def test_Explain(self):
result = yield self.coll.find({}, sort=qf.explain())
self.assertTrue("executionStats" in result[0] or "nscanned" in result[0])
result = yield self.coll.find().explain()
self.assertTrue("executionStats" in result[0] or "nscanned" in result[0])

@defer.inlineCallbacks
def test_FilterMerge(self):
Expand All @@ -136,12 +154,7 @@ def test_FilterMerge(self):
yield self.coll.find({}, sort=qf.sort(qf.ASCENDING("x")) + qf.comment(comment))
yield self.db.command("profile", 0)

if (yield self.__3_6_or_higher()):
profile_filter = {"command.sort.x": 1, "command.comment": comment}
elif (yield self.__3_2_or_higher()):
profile_filter = {"query.sort.x": 1, "query.comment": comment}
else:
profile_filter = {"query.$orderby.x": 1, "query.$comment": comment}
profile_filter = {"command.sort.x": 1, "command.comment": comment}
cnt = yield self.db.system.profile.count(profile_filter)
self.assertEqual(cnt, 1)

Expand Down
Loading

0 comments on commit e977fb9

Please sign in to comment.