Skip to content

Commit dd6c140

Browse files
authored
PYTHON-3060 Add typings to pymongo package (#831)
1 parent abfa0d3 commit dd6c140

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+1579
-1100
lines changed

bson/__init__.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@
6161
import struct
6262
import sys
6363
import uuid
64-
from codecs import utf_8_decode as _utf_8_decode # type: ignore
65-
from codecs import utf_8_encode as _utf_8_encode # type: ignore
64+
from codecs import utf_8_decode as _utf_8_decode # type: ignore[attr-defined]
65+
from codecs import utf_8_encode as _utf_8_encode # type: ignore[attr-defined]
6666
from collections import abc as _abc
67-
from typing import (TYPE_CHECKING, Any, BinaryIO, Callable, Dict, Generator,
67+
from typing import (IO, TYPE_CHECKING, Any, BinaryIO, Callable, Dict, Generator,
6868
Iterator, List, Mapping, MutableMapping, NoReturn,
6969
Sequence, Tuple, Type, TypeVar, Union, cast)
7070

@@ -88,11 +88,13 @@
8888

8989
# Import RawBSONDocument for type-checking only to avoid circular dependency.
9090
if TYPE_CHECKING:
91+
from array import array
92+
from mmap import mmap
9193
from bson.raw_bson import RawBSONDocument
9294

9395

9496
try:
95-
from bson import _cbson # type: ignore
97+
from bson import _cbson # type: ignore[attr-defined]
9698
_USE_C = True
9799
except ImportError:
98100
_USE_C = False
@@ -851,6 +853,7 @@ def _datetime_to_millis(dtm: datetime.datetime) -> int:
851853

852854
_DocumentIn = Mapping[str, Any]
853855
_DocumentOut = Union[MutableMapping[str, Any], "RawBSONDocument"]
856+
_ReadableBuffer = Union[bytes, memoryview, "mmap", "array"]
854857

855858

856859
def encode(document: _DocumentIn, check_keys: bool = False, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> bytes:
@@ -880,7 +883,7 @@ def encode(document: _DocumentIn, check_keys: bool = False, codec_options: Codec
880883
return _dict_to_bson(document, check_keys, codec_options)
881884

882885

883-
def decode(data: bytes, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> _DocumentOut:
886+
def decode(data: _ReadableBuffer, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> Dict[str, Any]:
884887
"""Decode BSON to a document.
885888
886889
By default, returns a BSON document represented as a Python
@@ -912,7 +915,7 @@ def decode(data: bytes, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) ->
912915
return _bson_to_dict(data, codec_options)
913916

914917

915-
def decode_all(data: bytes, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> List[_DocumentOut]:
918+
def decode_all(data: _ReadableBuffer, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> List[Dict[str, Any]]:
916919
"""Decode BSON data to multiple documents.
917920
918921
`data` must be a bytes-like object implementing the buffer protocol that
@@ -1075,7 +1078,7 @@ def decode_iter(data: bytes, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS
10751078
yield _bson_to_dict(elements, codec_options)
10761079

10771080

1078-
def decode_file_iter(file_obj: BinaryIO, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> Iterator[_DocumentOut]:
1081+
def decode_file_iter(file_obj: Union[BinaryIO, IO], codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> Iterator[_DocumentOut]:
10791082
"""Decode bson data from a file to multiple documents as a generator.
10801083
10811084
Works similarly to the decode_all function, but reads from the file object
@@ -1158,7 +1161,7 @@ def encode(cls: Type["BSON"], document: _DocumentIn, check_keys: bool = False,
11581161
"""
11591162
return cls(encode(document, check_keys, codec_options))
11601163

1161-
def decode(self, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> _DocumentOut: # type: ignore[override]
1164+
def decode(self, codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS) -> Dict[str, Any]: # type: ignore[override]
11621165
"""Decode this BSON data.
11631166
11641167
By default, returns a BSON document represented as a Python

bson/binary.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Tuple, Type
15+
from typing import Any, Tuple, Type, Union, TYPE_CHECKING
1616
from uuid import UUID
1717

1818
"""Tools for representing BSON binary data.
@@ -57,6 +57,11 @@
5757
"""
5858

5959

60+
if TYPE_CHECKING:
61+
from array import array as _array
62+
from mmap import mmap as _mmap
63+
64+
6065
class UuidRepresentation:
6166
UNSPECIFIED = 0
6267
"""An unspecified UUID representation.
@@ -211,7 +216,7 @@ class Binary(bytes):
211216
_type_marker = 5
212217
__subtype: int
213218

214-
def __new__(cls: Type["Binary"], data: bytes, subtype: int = BINARY_SUBTYPE) -> "Binary":
219+
def __new__(cls: Type["Binary"], data: Union[memoryview, bytes, "_mmap", "_array"], subtype: int = BINARY_SUBTYPE) -> "Binary":
215220
if not isinstance(subtype, int):
216221
raise TypeError("subtype must be an instance of int")
217222
if subtype >= 256 or subtype < 0:

gridfs/grid_file.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -874,10 +874,10 @@ def next(self) -> GridOut:
874874

875875
__next__ = next
876876

877-
def add_option(self, *args: Any, **kwargs: Any) -> None:
877+
def add_option(self, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
878878
raise NotImplementedError("Method does not exist for GridOutCursor")
879879

880-
def remove_option(self, *args: Any, **kwargs: Any) -> None:
880+
def remove_option(self, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
881881
raise NotImplementedError("Method does not exist for GridOutCursor")
882882

883883
def _clone_base(self, session: ClientSession) -> "GridOutCursor":

mypy.ini

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,33 @@
11
[mypy]
2+
check_untyped_defs = true
23
disallow_subclassing_any = true
34
disallow_incomplete_defs = true
45
no_implicit_optional = true
6+
pretty = true
7+
show_error_context = true
8+
show_error_codes = true
59
strict_equality = true
610
warn_unused_configs = true
711
warn_unused_ignores = true
812
warn_redundant_casts = true
913

14+
[mypy-kerberos.*]
15+
ignore_missing_imports = True
16+
1017
[mypy-mockupdb]
1118
ignore_missing_imports = True
19+
20+
[mypy-pymongo_auth_aws.*]
21+
ignore_missing_imports = True
22+
23+
[mypy-pymongocrypt.*]
24+
ignore_missing_imports = True
25+
26+
[mypy-service_identity.*]
27+
ignore_missing_imports = True
28+
29+
[mypy-snappy.*]
30+
ignore_missing_imports = True
31+
32+
[mypy-winkerberos.*]
33+
ignore_missing_imports = True

pymongo/__init__.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
"""Python driver for MongoDB."""
1616

17+
from typing import Tuple, Union
18+
1719
ASCENDING = 1
1820
"""Ascending sort order."""
1921
DESCENDING = -1
@@ -53,35 +55,33 @@
5355
.. _text index: http://docs.mongodb.org/manual/core/index-text/
5456
"""
5557

56-
version_tuple = (4, 1, 0, '.dev0')
58+
version_tuple: Tuple[Union[int, str], ...] = (4, 1, 0, '.dev0')
5759

58-
def get_version_string():
60+
def get_version_string() -> str:
5961
if isinstance(version_tuple[-1], str):
6062
return '.'.join(map(str, version_tuple[:-1])) + version_tuple[-1]
6163
return '.'.join(map(str, version_tuple))
6264

63-
__version__ = version = get_version_string()
65+
__version__: str = get_version_string()
66+
version = __version__
67+
6468
"""Current version of PyMongo."""
6569

6670
from pymongo.collection import ReturnDocument
67-
from pymongo.common import (MIN_SUPPORTED_WIRE_VERSION,
68-
MAX_SUPPORTED_WIRE_VERSION)
71+
from pymongo.common import (MAX_SUPPORTED_WIRE_VERSION,
72+
MIN_SUPPORTED_WIRE_VERSION)
6973
from pymongo.cursor import CursorType
7074
from pymongo.mongo_client import MongoClient
71-
from pymongo.operations import (IndexModel,
72-
InsertOne,
73-
DeleteOne,
74-
DeleteMany,
75-
UpdateOne,
76-
UpdateMany,
77-
ReplaceOne)
75+
from pymongo.operations import (DeleteMany, DeleteOne, IndexModel, InsertOne,
76+
ReplaceOne, UpdateMany, UpdateOne)
7877
from pymongo.read_preferences import ReadPreference
7978
from pymongo.write_concern import WriteConcern
8079

81-
def has_c():
80+
81+
def has_c() -> bool:
8282
"""Is the C extension installed?"""
8383
try:
84-
from pymongo import _cmessage
84+
from pymongo import _cmessage # type: ignore[attr-defined]
8585
return True
8686
except ImportError:
8787
return False

pymongo/aggregation.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@
1515
"""Perform aggregation operations on a collection or database."""
1616

1717
from bson.son import SON
18-
1918
from pymongo import common
2019
from pymongo.collation import validate_collation_or_none
2120
from pymongo.errors import ConfigurationError
22-
from pymongo.read_preferences import _AggWritePref, ReadPreference
21+
from pymongo.read_preferences import ReadPreference, _AggWritePref
2322

2423

2524
class _AggregationCommand(object):
@@ -37,7 +36,7 @@ def __init__(self, target, cursor_class, pipeline, options,
3736

3837
self._target = target
3938

40-
common.validate_list('pipeline', pipeline)
39+
pipeline = common.validate_list('pipeline', pipeline)
4140
self._pipeline = pipeline
4241
self._performs_write = False
4342
if pipeline and ("$out" in pipeline[-1] or "$merge" in pipeline[-1]):
@@ -82,7 +81,6 @@ def _cursor_namespace(self):
8281
"""The namespace in which the aggregate command is run."""
8382
raise NotImplementedError
8483

85-
@property
8684
def _cursor_collection(self, cursor_doc):
8785
"""The Collection used for the aggregate command cursor."""
8886
raise NotImplementedError

pymongo/auth.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
import hmac
2020
import os
2121
import socket
22-
2322
from base64 import standard_b64decode, standard_b64encode
2423
from collections import namedtuple
24+
from typing import Callable, Mapping
2525
from urllib.parse import quote
2626

2727
from bson.binary import Binary
@@ -97,7 +97,7 @@ def __hash__(self):
9797
"""Mechanism properties for GSSAPI authentication."""
9898

9999

100-
_AWSProperties = namedtuple('AWSProperties', ['aws_session_token'])
100+
_AWSProperties = namedtuple('_AWSProperties', ['aws_session_token'])
101101
"""Mechanism properties for MONGODB-AWS authentication."""
102102

103103

@@ -140,9 +140,9 @@ def _build_credentials_tuple(mech, source, user, passwd, extra, database):
140140

141141
properties = extra.get('authmechanismproperties', {})
142142
aws_session_token = properties.get('AWS_SESSION_TOKEN')
143-
props = _AWSProperties(aws_session_token=aws_session_token)
143+
aws_props = _AWSProperties(aws_session_token=aws_session_token)
144144
# user can be None for temporary link-local EC2 credentials.
145-
return MongoCredential(mech, '$external', user, passwd, props, None)
145+
return MongoCredential(mech, '$external', user, passwd, aws_props, None)
146146
elif mech == 'PLAIN':
147147
source_database = source or database or '$external'
148148
return MongoCredential(mech, source_database, user, passwd, None, None)
@@ -471,7 +471,7 @@ def _authenticate_default(credentials, sock_info):
471471
return _authenticate_scram(credentials, sock_info, 'SCRAM-SHA-1')
472472

473473

474-
_AUTH_MAP = {
474+
_AUTH_MAP: Mapping[str, Callable] = {
475475
'GSSAPI': _authenticate_gssapi,
476476
'MONGODB-CR': _authenticate_mongo_cr,
477477
'MONGODB-X509': _authenticate_x509,
@@ -532,7 +532,7 @@ def speculate_command(self):
532532
return cmd
533533

534534

535-
_SPECULATIVE_AUTH_MAP = {
535+
_SPECULATIVE_AUTH_MAP: Mapping[str, Callable] = {
536536
'MONGODB-X509': _X509Context,
537537
'SCRAM-SHA-1': functools.partial(_ScramContext, mechanism='SCRAM-SHA-1'),
538538
'SCRAM-SHA-256': functools.partial(_ScramContext,
@@ -544,6 +544,6 @@ def speculate_command(self):
544544
def authenticate(credentials, sock_info):
545545
"""Authenticate sock_info."""
546546
mechanism = credentials.mechanism
547-
auth_func = _AUTH_MAP.get(mechanism)
547+
auth_func = _AUTH_MAP[mechanism]
548548
auth_func(credentials, sock_info)
549549

pymongo/auth_aws.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,11 @@
1616

1717
try:
1818
import pymongo_auth_aws
19-
from pymongo_auth_aws import (AwsCredential,
20-
AwsSaslContext,
19+
from pymongo_auth_aws import (AwsCredential, AwsSaslContext,
2120
PyMongoAuthAwsError)
2221
_HAVE_MONGODB_AWS = True
2322
except ImportError:
24-
class AwsSaslContext(object):
23+
class AwsSaslContext(object): # type: ignore
2524
def __init__(self, credentials):
2625
pass
2726
_HAVE_MONGODB_AWS = False
@@ -32,7 +31,7 @@ def __init__(self, credentials):
3231
from pymongo.errors import ConfigurationError, OperationFailure
3332

3433

35-
class _AwsSaslContext(AwsSaslContext):
34+
class _AwsSaslContext(AwsSaslContext): # type: ignore
3635
# Dependency injection:
3736
def binary_type(self):
3837
"""Return the bson.binary.Binary type."""

pymongo/bulk.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,23 @@
1717
.. versionadded:: 2.7
1818
"""
1919
import copy
20-
2120
from itertools import islice
2221

2322
from bson.objectid import ObjectId
2423
from bson.raw_bson import RawBSONDocument
2524
from bson.son import SON
2625
from pymongo.client_session import _validate_session_write_concern
27-
from pymongo.common import (validate_is_mapping,
28-
validate_is_document_type,
29-
validate_ok_for_replace,
30-
validate_ok_for_update)
31-
from pymongo.helpers import _RETRYABLE_ERROR_CODES, _get_wce_doc
3226
from pymongo.collation import validate_collation_or_none
33-
from pymongo.errors import (BulkWriteError,
34-
ConfigurationError,
35-
InvalidOperation,
36-
OperationFailure)
37-
from pymongo.message import (_INSERT, _UPDATE, _DELETE,
38-
_randint,
39-
_BulkWriteContext,
40-
_EncryptedBulkWriteContext)
27+
from pymongo.common import (validate_is_document_type, validate_is_mapping,
28+
validate_ok_for_replace, validate_ok_for_update)
29+
from pymongo.errors import (BulkWriteError, ConfigurationError,
30+
InvalidOperation, OperationFailure)
31+
from pymongo.helpers import _RETRYABLE_ERROR_CODES, _get_wce_doc
32+
from pymongo.message import (_DELETE, _INSERT, _UPDATE, _BulkWriteContext,
33+
_EncryptedBulkWriteContext, _randint)
4134
from pymongo.read_preferences import ReadPreference
4235
from pymongo.write_concern import WriteConcern
4336

44-
4537
_DELETE_ALL = 0
4638
_DELETE_ONE = 1
4739

0 commit comments

Comments
 (0)