Skip to content

Commit 984acf4

Browse files
blink1073juliusgeo
authored andcommitted
PYTHON-3064 Add typings to test package (mongodb#844)
1 parent 7fa22fa commit 984acf4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+542
-261
lines changed

.github/workflows/test-python.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,5 @@ jobs:
4646
pip install -e ".[zstd, srv]"
4747
- name: Run mypy
4848
run: |
49-
mypy --install-types --non-interactive bson gridfs tools
49+
mypy --install-types --non-interactive bson gridfs tools pymongo
50+
mypy --install-types --non-interactive --disable-error-code var-annotated --disable-error-code attr-defined --disable-error-code union-attr --disable-error-code assignment --disable-error-code no-redef --disable-error-code index test

bson/son.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
# This is essentially the same as re._pattern_type
2929
RE_TYPE: Type[Pattern[Any]] = type(re.compile(""))
3030

31-
_Key = TypeVar("_Key", bound=str)
31+
_Key = TypeVar("_Key")
3232
_Value = TypeVar("_Value")
3333
_T = TypeVar("_T")
3434

mypy.ini

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ warn_unused_configs = true
1111
warn_unused_ignores = true
1212
warn_redundant_casts = true
1313

14+
[mypy-gevent.*]
15+
ignore_missing_imports = True
16+
1417
[mypy-kerberos.*]
1518
ignore_missing_imports = True
1619

@@ -29,5 +32,12 @@ ignore_missing_imports = True
2932
[mypy-snappy.*]
3033
ignore_missing_imports = True
3134

35+
[mypy-test.*]
36+
allow_redefinition = true
37+
allow_untyped_globals = true
38+
3239
[mypy-winkerberos.*]
3340
ignore_missing_imports = True
41+
42+
[mypy-xmlrunner.*]
43+
ignore_missing_imports = True

pymongo/socket_checker.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@
1616

1717
import errno
1818
import select
19-
import socket
2019
import sys
21-
from typing import Any, Optional
20+
from typing import Any, Optional, Union
2221

2322
# PYTHON-2320: Jython does not fully support poll on SSL sockets,
2423
# https://bugs.jython.org/issue2900
@@ -43,7 +42,7 @@ def __init__(self) -> None:
4342
else:
4443
self._poller = None
4544

46-
def select(self, sock: Any, read: bool = False, write: bool = False, timeout: int = 0) -> bool:
45+
def select(self, sock: Any, read: bool = False, write: bool = False, timeout: Optional[float] = 0) -> bool:
4746
"""Select for reads or writes with a timeout in seconds (or None).
4847
4948
Returns True if the socket is readable/writable, False on timeout.

pymongo/srv_resolver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def maybe_decode(text):
3939
def _resolve(*args, **kwargs):
4040
if hasattr(resolver, 'resolve'):
4141
# dnspython >= 2
42-
return resolver.resolve(*args, **kwargs) # type: ignore
42+
return resolver.resolve(*args, **kwargs)
4343
# dnspython 1.X
4444
return resolver.query(*args, **kwargs)
4545

pymongo/typings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Type aliases used by PyMongo"""
1616
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, MutableMapping, Optional,
17-
Tuple, Type, TypeVar, Union)
17+
Sequence, Tuple, Type, TypeVar, Union)
1818

1919
if TYPE_CHECKING:
2020
from bson.raw_bson import RawBSONDocument
@@ -25,5 +25,5 @@
2525
_Address = Tuple[str, Optional[int]]
2626
_CollationIn = Union[Mapping[str, Any], "Collation"]
2727
_DocumentIn = Union[MutableMapping[str, Any], "RawBSONDocument"]
28-
_Pipeline = List[Mapping[str, Any]]
28+
_Pipeline = Sequence[Mapping[str, Any]]
2929
_DocumentType = TypeVar('_DocumentType', Mapping[str, Any], MutableMapping[str, Any], Dict[str, Any])

test/__init__.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040

4141
from contextlib import contextmanager
4242
from functools import wraps
43+
from typing import Dict, no_type_check
4344
from unittest import SkipTest
4445

4546
import pymongo
@@ -48,7 +49,9 @@
4849
from bson.son import SON
4950
from pymongo import common, message
5051
from pymongo.common import partition_node
52+
from pymongo.database import Database
5153
from pymongo.hello import HelloCompat
54+
from pymongo.mongo_client import MongoClient
5255
from pymongo.server_api import ServerApi
5356
from pymongo.ssl_support import HAVE_SSL, _ssl
5457
from pymongo.uri_parser import parse_uri
@@ -86,7 +89,7 @@
8689
os.path.join(CERT_PATH, 'client.pem'))
8790
CA_PEM = os.environ.get('CA_PEM', os.path.join(CERT_PATH, 'ca.pem'))
8891

89-
TLS_OPTIONS = dict(tls=True)
92+
TLS_OPTIONS: Dict = dict(tls=True)
9093
if CLIENT_PEM:
9194
TLS_OPTIONS['tlsCertificateKeyFile'] = CLIENT_PEM
9295
if CA_PEM:
@@ -102,13 +105,13 @@
102105
# Remove after PYTHON-2712
103106
from pymongo import pool
104107
pool._MOCK_SERVICE_ID = True
105-
res = parse_uri(SINGLE_MONGOS_LB_URI)
108+
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
106109
host, port = res['nodelist'][0]
107110
db_user = res['username'] or db_user
108111
db_pwd = res['password'] or db_pwd
109112
elif TEST_SERVERLESS:
110113
TEST_LOADBALANCER = True
111-
res = parse_uri(SINGLE_MONGOS_LB_URI)
114+
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
112115
host, port = res['nodelist'][0]
113116
db_user = res['username'] or db_user
114117
db_pwd = res['password'] or db_pwd
@@ -184,6 +187,7 @@ def enable(self):
184187
def __enter__(self):
185188
self.enable()
186189

190+
@no_type_check
187191
def disable(self):
188192
common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
189193
common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval
@@ -224,6 +228,8 @@ def _all_users(db):
224228

225229

226230
class ClientContext(object):
231+
client: MongoClient
232+
227233
MULTI_MONGOS_LB_URI = MULTI_MONGOS_LB_URI
228234

229235
def __init__(self):
@@ -247,9 +253,9 @@ def __init__(self):
247253
self.tls = False
248254
self.tlsCertificateKeyFile = False
249255
self.server_is_resolvable = is_server_resolvable()
250-
self.default_client_options = {}
256+
self.default_client_options: Dict = {}
251257
self.sessions_enabled = False
252-
self.client = None
258+
self.client = None # type: ignore
253259
self.conn_lock = threading.Lock()
254260
self.is_data_lake = False
255261
self.load_balancer = TEST_LOADBALANCER
@@ -340,6 +346,7 @@ def _init_client(self):
340346
try:
341347
self.cmd_line = self.client.admin.command('getCmdLineOpts')
342348
except pymongo.errors.OperationFailure as e:
349+
assert e.details is not None
343350
msg = e.details.get('errmsg', '')
344351
if e.code == 13 or 'unauthorized' in msg or 'login' in msg:
345352
# Unauthorized.
@@ -418,6 +425,7 @@ def _init_client(self):
418425
else:
419426
self.server_parameters = self.client.admin.command(
420427
'getParameter', '*')
428+
assert self.cmd_line is not None
421429
if 'enableTestCommands=1' in self.cmd_line['argv']:
422430
self.test_commands_enabled = True
423431
elif 'parsed' in self.cmd_line:
@@ -436,7 +444,8 @@ def _init_client(self):
436444
self.mongoses.append(address)
437445
if not self.serverless:
438446
# Check for another mongos on the next port.
439-
next_address = address[0], address[1] + 1
447+
assert address is not None
448+
next_address = address[0], address[1] + 1
440449
mongos_client = self._connect(
441450
*next_address, **self.default_client_options)
442451
if mongos_client:
@@ -496,6 +505,7 @@ def _check_user_provided(self):
496505
try:
497506
return db_user in _all_users(client.admin)
498507
except pymongo.errors.OperationFailure as e:
508+
assert e.details is not None
499509
msg = e.details.get('errmsg', '')
500510
if e.code == 18 or 'auth fails' in msg:
501511
# Auth failed.
@@ -505,6 +515,7 @@ def _check_user_provided(self):
505515

506516
def _server_started_with_auth(self):
507517
# MongoDB >= 2.0
518+
assert self.cmd_line is not None
508519
if 'parsed' in self.cmd_line:
509520
parsed = self.cmd_line['parsed']
510521
# MongoDB >= 2.6
@@ -525,6 +536,7 @@ def _server_started_with_ipv6(self):
525536
if not socket.has_ipv6:
526537
return False
527538

539+
assert self.cmd_line is not None
528540
if 'parsed' in self.cmd_line:
529541
if not self.cmd_line['parsed'].get('net', {}).get('ipv6'):
530542
return False
@@ -932,6 +944,9 @@ def fail_point(self, command_args):
932944

933945
class IntegrationTest(PyMongoTestCase):
934946
"""Base class for TestCases that need a connection to MongoDB to pass."""
947+
client: MongoClient
948+
db: Database
949+
credentials: Dict[str, str]
935950

936951
@classmethod
937952
@client_context.require_connection
@@ -1073,7 +1088,7 @@ def run(self, test):
10731088

10741089

10751090
if HAVE_XML:
1076-
class PymongoXMLTestRunner(XMLTestRunner):
1091+
class PymongoXMLTestRunner(XMLTestRunner): # type: ignore[misc]
10771092
def run(self, test):
10781093
setup()
10791094
result = super(PymongoXMLTestRunner, self).run(test)

test/auth_aws/test_auth_aws.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727

2828
class TestAuthAWS(unittest.TestCase):
29+
uri: str
2930

3031
@classmethod
3132
def setUpClass(cls):

test/mockupdb/test_cursor_namespace.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121

2222

2323
class TestCursorNamespace(unittest.TestCase):
24+
server: MockupDB
25+
client: MongoClient
26+
2427
@classmethod
2528
def setUpClass(cls):
2629
cls.server = MockupDB(auto_ismaster={'maxWireVersion': 6})
@@ -69,6 +72,9 @@ def op():
6972

7073

7174
class TestKillCursorsNamespace(unittest.TestCase):
75+
server: MockupDB
76+
client: MongoClient
77+
7278
@classmethod
7379
def setUpClass(cls):
7480
cls.server = MockupDB(auto_ismaster={'maxWireVersion': 6})

test/mockupdb/test_getmore_sharded.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_getmore_sharded(self):
2727
servers = [MockupDB(), MockupDB()]
2828

2929
# Collect queries to either server in one queue.
30-
q = Queue()
30+
q: Queue = Queue()
3131
for server in servers:
3232
server.subscribe(q.put)
3333
server.autoresponds('ismaster', ismaster=True, msg='isdbgrid',

test/mockupdb/test_handshake.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def respond(r):
4848
ServerApiVersion.V1))}
4949
client = MongoClient("mongodb://"+primary.address_string,
5050
appname='my app', # For _check_handshake_data()
51-
**dict([k_map.get((k, v), (k, v)) for k, v
51+
**dict([k_map.get((k, v), (k, v)) for k, v # type: ignore[arg-type]
5252
in kwargs.items()]))
5353

5454
self.addCleanup(client.close)
@@ -58,7 +58,7 @@ def respond(r):
5858

5959
# We do this checking here rather than in the autoresponder `respond()`
6060
# because it runs in another Python thread so there are some funky things
61-
# with error handling within that thread, and we want to be able to use
61+
# with error handling within that thread, and we want to be able to use
6262
# self.assertRaises().
6363
self.handshake_req.assert_matches(protocol(hello, **kwargs))
6464
_check_handshake_data(self.handshake_req)

test/mockupdb/test_mixed_version_sharded.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def setup_server(self, upgrade):
3030
self.mongos_old, self.mongos_new = MockupDB(), MockupDB()
3131

3232
# Collect queries to either server in one queue.
33-
self.q = Queue()
33+
self.q: Queue = Queue()
3434
for server in self.mongos_old, self.mongos_new:
3535
server.subscribe(self.q.put)
3636
server.autoresponds('getlasterror')
@@ -59,7 +59,7 @@ def create_mixed_version_sharded_test(upgrade):
5959
def test(self):
6060
self.setup_server(upgrade)
6161
start = time.time()
62-
servers_used = set()
62+
servers_used: set = set()
6363
while len(servers_used) < 2:
6464
go(upgrade.function, self.client)
6565
request = self.q.get(timeout=1)

test/mockupdb/test_op_msg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,8 @@
233233

234234

235235
class TestOpMsg(unittest.TestCase):
236+
server: MockupDB
237+
client: MongoClient
236238

237239
@classmethod
238240
def setUpClass(cls):

test/mockupdb/test_op_msg_read_preference.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import copy
1616
import itertools
17+
from typing import Any
1718

1819
from mockupdb import MockupDB, going, CommandBase
1920
from pymongo import MongoClient, ReadPreference
@@ -27,6 +28,8 @@
2728

2829
class OpMsgReadPrefBase(unittest.TestCase):
2930
single_mongod = False
31+
primary: MockupDB
32+
secondary: MockupDB
3033

3134
@classmethod
3235
def setUpClass(cls):
@@ -142,7 +145,7 @@ def test(self):
142145
tag_sets=None)
143146

144147
client = self.setup_client(read_preference=pref)
145-
148+
expected_pref: Any
146149
if operation.op_type == 'always-use-secondary':
147150
expected_server = self.secondary
148151
expected_pref = ReadPreference.SECONDARY

test/mockupdb/test_reset_and_request_check.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
class TestResetAndRequestCheck(unittest.TestCase):
2828
def __init__(self, *args, **kwargs):
2929
super(TestResetAndRequestCheck, self).__init__(*args, **kwargs)
30-
self.ismaster_time = 0
30+
self.ismaster_time = 0.0
3131
self.client = None
3232
self.server = None
3333

@@ -45,7 +45,7 @@ def responder(request):
4545
kwargs = {'socketTimeoutMS': 100}
4646
# Disable retryable reads when pymongo supports it.
4747
kwargs['retryReads'] = False
48-
self.client = MongoClient(self.server.uri, **kwargs)
48+
self.client = MongoClient(self.server.uri, **kwargs) # type: ignore
4949
wait_until(lambda: self.client.nodes, 'connect to standalone')
5050

5151
def tearDown(self):
@@ -56,6 +56,8 @@ def _test_disconnect(self, operation):
5656
# Application operation fails. Test that client resets server
5757
# description and does *not* schedule immediate check.
5858
self.setup_server()
59+
assert self.server is not None
60+
assert self.client is not None
5961

6062
# Network error on application operation.
6163
with self.assertRaises(ConnectionFailure):
@@ -81,6 +83,8 @@ def _test_timeout(self, operation):
8183
# Application operation times out. Test that client does *not* reset
8284
# server description and does *not* schedule immediate check.
8385
self.setup_server()
86+
assert self.server is not None
87+
assert self.client is not None
8488

8589
with self.assertRaises(ConnectionFailure):
8690
with going(operation.function, self.client):
@@ -91,6 +95,7 @@ def _test_timeout(self, operation):
9195
# Server is *not* Unknown.
9296
topology = self.client._topology
9397
server = topology.select_server_by_address(self.server.address, 0)
98+
assert server is not None
9499
self.assertEqual(SERVER_TYPE.Standalone, server.description.server_type)
95100

96101
after = self.ismaster_time
@@ -99,6 +104,8 @@ def _test_timeout(self, operation):
99104
def _test_not_master(self, operation):
100105
# Application operation gets a "not master" error.
101106
self.setup_server()
107+
assert self.server is not None
108+
assert self.client is not None
102109

103110
with self.assertRaises(ConnectionFailure):
104111
with going(operation.function, self.client):
@@ -110,6 +117,7 @@ def _test_not_master(self, operation):
110117
# Server is rediscovered.
111118
topology = self.client._topology
112119
server = topology.select_server_by_address(self.server.address, 0)
120+
assert server is not None
113121
self.assertEqual(SERVER_TYPE.Standalone, server.description.server_type)
114122

115123
after = self.ismaster_time

0 commit comments

Comments
 (0)