Skip to content

Commit cf8f8b0

Browse files
committed
a large number of type changes
1 parent 032271e commit cf8f8b0

File tree

11 files changed

+51
-20
lines changed

11 files changed

+51
-20
lines changed

doc/examples/type_hints.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,29 @@ These methods automatically add an "_id" field.
114114
>>> # This will not be type checked, despite being present, because it is added by PyMongo.
115115
>>> assert type(result["_id"]) == ObjectId
116116

117+
This same typing scheme works for all of the insert methods (`insert_one`, `insert_many`, and `bulk_write`). For `bulk_write`,
118+
both `InsertOne/Many` and `ReplaceOne/Many` operators are generic.
119+
120+
.. doctest::
121+
122+
>>> from typing import TypedDict
123+
>>> from pymongo import MongoClient
124+
>>> from pymongo.operations import InsertOne
125+
>>> from pymongo.collection import Collection
126+
>>> class Movie(TypedDict):
127+
... name: str
128+
... year: int
129+
...
130+
>>> client: MongoClient = MongoClient()
131+
>>> collection: Collection[Movie] = client.test.test
132+
>>> inserted = collection.bulk_write([InsertOne(Movie(name="Jurassic Park", year=1993))])
133+
>>> result = collection.find_one({"name": "Jurassic Park"})
134+
>>> assert result is not None
135+
>>> assert result["year"] == 1993
136+
>>> # This will not be type checked, despite being present, because it is added by PyMongo.
137+
>>> assert type(result["_id"]) == ObjectId
138+
139+
117140
Typed Database
118141
--------------
119142

test/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1090,7 +1090,7 @@ def print_thread_stacks(pid: int) -> None:
10901090
class IntegrationTest(PyMongoTestCase):
10911091
"""Base class for TestCases that need a connection to MongoDB to pass."""
10921092

1093-
client: MongoClient
1093+
client: MongoClient[dict]
10941094
db: Database
10951095
credentials: Dict[str, str]
10961096

test/mockupdb/test_cluster_time.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def callback(client):
6060
self.cluster_time_conversation(callback, [{"ok": 1}] * 2)
6161

6262
def test_bulk(self):
63-
def callback(client):
63+
def callback(client: MongoClient[dict]) -> None:
6464
client.db.collection.bulk_write(
6565
[InsertOne({}), InsertOne({}), UpdateOne({}, {"$inc": {"x": 1}}), DeleteMany({})]
6666
)

test/mockupdb/test_op_msg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,22 +137,22 @@
137137
# Legacy methods
138138
Operation(
139139
"bulk_write_insert",
140-
lambda coll: coll.bulk_write([InsertOne({}), InsertOne({})]),
140+
lambda coll: coll.bulk_write([InsertOne[dict]({}), InsertOne[dict]({})]),
141141
request=OpMsg({"insert": "coll"}, flags=0),
142142
reply={"ok": 1, "n": 2},
143143
),
144144
Operation(
145145
"bulk_write_insert-w0",
146146
lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).bulk_write(
147-
[InsertOne({}), InsertOne({})]
147+
[InsertOne[dict]({}), InsertOne[dict]({})]
148148
),
149149
request=OpMsg({"insert": "coll"}, flags=0),
150150
reply={"ok": 1, "n": 2},
151151
),
152152
Operation(
153153
"bulk_write_insert-w0-unordered",
154154
lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).bulk_write(
155-
[InsertOne({}), InsertOne({})], ordered=False
155+
[InsertOne[dict]({}), InsertOne[dict]({})], ordered=False
156156
),
157157
request=OpMsg({"insert": "coll"}, flags=OP_MSG_FLAGS["moreToCome"]),
158158
reply=None,

test/test_bulk.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def test_upsert(self):
296296
def test_numerous_inserts(self):
297297
# Ensure we don't exceed server's maxWriteBatchSize size limit.
298298
n_docs = client_context.max_write_batch_size + 100
299-
requests = [InsertOne({}) for _ in range(n_docs)]
299+
requests = [InsertOne[dict]({}) for _ in range(n_docs)]
300300
result = self.coll.bulk_write(requests, ordered=False)
301301
self.assertEqual(n_docs, result.inserted_count)
302302
self.assertEqual(n_docs, self.coll.count_documents({}))
@@ -347,7 +347,7 @@ def test_bulk_write_no_results(self):
347347

348348
def test_bulk_write_invalid_arguments(self):
349349
# The requests argument must be a list.
350-
generator = (InsertOne({}) for _ in range(10))
350+
generator = (InsertOne[dict]({}) for _ in range(10))
351351
with self.assertRaises(TypeError):
352352
self.coll.bulk_write(generator) # type: ignore[arg-type]
353353

test/test_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1652,6 +1652,7 @@ def test_network_error_message(self):
16521652
with self.fail_point(
16531653
{"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}}
16541654
):
1655+
assert client.address is not None
16551656
expected = "%s:%s: " % client.address
16561657
with self.assertRaisesRegex(AutoReconnect, expected):
16571658
client.pymongo_test.test.find_one({})

test/test_database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def test_list_collection_names_filter(self):
201201
db.capped.insert_one({})
202202
db.non_capped.insert_one({})
203203
self.addCleanup(client.drop_database, db.name)
204-
204+
filter: dict
205205
# Should not send nameOnly.
206206
for filter in ({"options.capped": True}, {"options.capped": True, "name": "capped"}):
207207
results.clear()

test/test_server_selection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,9 @@ def all_hosts_started():
8585
)
8686

8787
wait_until(all_hosts_started, "receive heartbeat from all hosts")
88-
expected_port = max([n.address[1] for n in client._topology._description.readable_servers])
88+
expected_port = max(
89+
[n.address[1] for n in client._topology._description.readable_servers]
90+
) # type:ignore[type-var]
8991

9092
# Insert 1 record and access it 10 times.
9193
coll.insert_one({"name": "John Doe"})

test/test_session.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -898,7 +898,9 @@ def _test_writes(self, op):
898898

899899
@client_context.require_no_standalone
900900
def test_writes(self):
901-
self._test_writes(lambda coll, session: coll.bulk_write([InsertOne({})], session=session))
901+
self._test_writes(
902+
lambda coll, session: coll.bulk_write([InsertOne[dict]({})], session=session)
903+
)
902904
self._test_writes(lambda coll, session: coll.insert_one({}, session=session))
903905
self._test_writes(lambda coll, session: coll.insert_many([{}], session=session))
904906
self._test_writes(
@@ -944,7 +946,7 @@ def _test_no_read_concern(self, op):
944946
@client_context.require_no_standalone
945947
def test_writes_do_not_include_read_concern(self):
946948
self._test_no_read_concern(
947-
lambda coll, session: coll.bulk_write([InsertOne({})], session=session)
949+
lambda coll, session: coll.bulk_write([InsertOne[dict]({})], session=session)
948950
)
949951
self._test_no_read_concern(lambda coll, session: coll.insert_one({}, session=session))
950952
self._test_no_read_concern(lambda coll, session: coll.insert_many([{}], session=session))
@@ -1077,7 +1079,9 @@ def setUp(self):
10771079
def test_cluster_time(self):
10781080
listener = SessionTestListener()
10791081
# Prevent heartbeats from updating $clusterTime between operations.
1080-
client = rs_or_single_client(event_listeners=[listener], heartbeatFrequencyMS=999999)
1082+
client: MongoClient[dict] = rs_or_single_client(
1083+
event_listeners=[listener], heartbeatFrequencyMS=999999
1084+
)
10811085
self.addCleanup(client.close)
10821086
collection = client.pymongo_test.collection
10831087
# Prepare for tests of find() and aggregate().

test/test_transactions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def test_transaction_direct_connection(self):
363363
coll.insert_one({})
364364
self.assertEqual(client.topology_description.topology_type_name, "Single")
365365
ops = [
366-
(coll.bulk_write, [[InsertOne({})]]),
366+
(coll.bulk_write, [[InsertOne[dict]({})]]),
367367
(coll.insert_one, [{}]),
368368
(coll.insert_many, [[{}, {}]]),
369369
(coll.replace_one, [{}, {}]),
@@ -385,7 +385,7 @@ def test_transaction_direct_connection(self):
385385
]
386386
for f, args in ops:
387387
with client.start_session() as s, s.start_transaction():
388-
res = f(*args, session=s)
388+
res = f(*args, session=s) # type:ignore[operator]
389389
if isinstance(res, (CommandCursor, Cursor)):
390390
list(res)
391391

test/utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from collections import abc, defaultdict
3030
from functools import partial
3131
from test import client_context, db_pwd, db_user
32+
from typing import Any
3233

3334
from bson import json_util
3435
from bson.objectid import ObjectId
@@ -557,35 +558,35 @@ def _mongo_client(host, port, authenticate=True, directConnection=None, **kwargs
557558
return MongoClient(uri, port, **client_options)
558559

559560

560-
def single_client_noauth(h=None, p=None, **kwargs):
561+
def single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
561562
"""Make a direct connection. Don't authenticate."""
562563
return _mongo_client(h, p, authenticate=False, directConnection=True, **kwargs)
563564

564565

565-
def single_client(h=None, p=None, **kwargs):
566+
def single_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
566567
"""Make a direct connection, and authenticate if necessary."""
567568
return _mongo_client(h, p, directConnection=True, **kwargs)
568569

569570

570-
def rs_client_noauth(h=None, p=None, **kwargs):
571+
def rs_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
571572
"""Connect to the replica set. Don't authenticate."""
572573
return _mongo_client(h, p, authenticate=False, **kwargs)
573574

574575

575-
def rs_client(h=None, p=None, **kwargs):
576+
def rs_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
576577
"""Connect to the replica set and authenticate if necessary."""
577578
return _mongo_client(h, p, **kwargs)
578579

579580

580-
def rs_or_single_client_noauth(h=None, p=None, **kwargs):
581+
def rs_or_single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
581582
"""Connect to the replica set if there is one, otherwise the standalone.
582583
583584
Like rs_or_single_client, but does not authenticate.
584585
"""
585586
return _mongo_client(h, p, authenticate=False, **kwargs)
586587

587588

588-
def rs_or_single_client(h=None, p=None, **kwargs):
589+
def rs_or_single_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
589590
"""Connect to the replica set if there is one, otherwise the standalone.
590591
591592
Authenticates if necessary.

0 commit comments

Comments
 (0)