diff --git a/bson/__init__.py b/bson/__init__.py index 972b184015..c85c8a9a41 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -1106,9 +1106,21 @@ def _decode_all( _decode_all = _cbson._decode_all # noqa: F811 +@overload +def decode_all(data: "_ReadableBuffer", codec_options: None = None) -> "List[Dict[str, Any]]": + ... + + +@overload def decode_all( - data: "_ReadableBuffer", codec_options: "Optional[CodecOptions[_DocumentType]]" = None + data: "_ReadableBuffer", codec_options: "CodecOptions[_DocumentType]" ) -> "List[_DocumentType]": + ... + + +def decode_all( + data: "_ReadableBuffer", codec_options: "Optional[CodecOptions[_DocumentType]]" = None +) -> "Union[List[Dict[str, Any]], List[_DocumentType]]": """Decode BSON data to multiple documents. `data` must be a bytes-like object implementing the buffer protocol that @@ -1131,11 +1143,13 @@ def decode_all( Replaced `as_class`, `tz_aware`, and `uuid_subtype` options with `codec_options`. """ - opts = codec_options or DEFAULT_CODEC_OPTIONS - if not isinstance(opts, CodecOptions): + if codec_options is None: + return _decode_all(data, DEFAULT_CODEC_OPTIONS) + + if not isinstance(codec_options, CodecOptions): raise _CODEC_OPTIONS_TYPE_ERROR - return _decode_all(data, opts) # type:ignore[arg-type] + return _decode_all(data, codec_options) def _decode_selective(rawdoc: Any, fields: Any, codec_options: Any) -> Mapping[Any, Any]: @@ -1242,9 +1256,21 @@ def _decode_all_selective(data: Any, codec_options: CodecOptions, fields: Any) - ] +@overload +def decode_iter(data: bytes, codec_options: None = None) -> "Iterator[Dict[str, Any]]": + ... + + +@overload def decode_iter( - data: bytes, codec_options: "Optional[CodecOptions[_DocumentType]]" = None + data: bytes, codec_options: "CodecOptions[_DocumentType]" ) -> "Iterator[_DocumentType]": + ... + + +def decode_iter( + data: bytes, codec_options: "Optional[CodecOptions[_DocumentType]]" = None +) -> "Union[Iterator[Dict[str, Any]], Iterator[_DocumentType]]": """Decode BSON data to multiple documents as a generator. Works similarly to the decode_all function, but yields one document at a @@ -1278,9 +1304,23 @@ def decode_iter( yield _bson_to_dict(elements, opts) +@overload def decode_file_iter( - file_obj: Union[BinaryIO, IO], codec_options: "Optional[CodecOptions[_DocumentType]]" = None + file_obj: Union[BinaryIO, IO], codec_options: None = None +) -> "Iterator[Dict[str, Any]]": + ... + + +@overload +def decode_file_iter( + file_obj: Union[BinaryIO, IO], codec_options: "CodecOptions[_DocumentType]" ) -> "Iterator[_DocumentType]": + ... + + +def decode_file_iter( + file_obj: Union[BinaryIO, IO], codec_options: "Optional[CodecOptions[_DocumentType]]" = None +) -> "Union[Iterator[Dict[str, Any]], Iterator[_DocumentType]]": """Decode bson data from a file to multiple documents as a generator. Works similarly to the decode_all function, but reads from the file object diff --git a/pymongo/collection.py b/pymongo/collection.py index fbbe7fb593..5db2f33777 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -427,24 +427,6 @@ def database(self) -> Database[_DocumentType]: """ return self.__database - # @overload - # def with_options( - # self, - # codec_options: None = None, - # read_preference: Optional[_ServerMode] = None, - # write_concern: Optional[WriteConcern] = None, - # read_concern: Optional[ReadConcern] = None, - # ) -> Collection[Dict[str, Any]]: ... - - # @overload - # def with_options( - # self, - # codec_options: bson.CodecOptions[_DocumentType], - # read_preference: Optional[_ServerMode] = None, - # write_concern: Optional[WriteConcern] = None, - # read_concern: Optional[ReadConcern] = None, - # ) -> Collection[_DocumentType]: ... - def with_options( self, codec_options: Optional[bson.CodecOptions[_DocumentTypeArg]] = None, diff --git a/test/test_typing.py b/test/test_typing.py index 27597bb2c8..b2db4b93b9 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -242,6 +242,11 @@ def foo(self): rt_document3 = decode(bsonbytes2, codec_options=codec_options2) assert rt_document3.raw + def test_bson_decode_no_codec_option(self) -> None: + doc = decode_all(encode({"a": 1})) + assert doc + doc[0]["a"] = 2 + def test_bson_decode_all(self) -> None: doc = {"_id": 1} bsonbytes = encode(doc) @@ -266,6 +271,15 @@ def foo(self): rt_documents3 = decode_all(bsonbytes3, codec_options3) assert rt_documents3[0].raw + def test_bson_decode_all_no_codec_option(self) -> None: + docs = decode_all(b"") + docs.append({"new": 1}) + + docs = decode_all(encode({"a": 1})) + assert docs + docs[0]["a"] = 2 + docs.append({"new": 1}) + def test_bson_decode_iter(self) -> None: doc = {"_id": 1} bsonbytes = encode(doc) @@ -290,6 +304,11 @@ def foo(self): rt_documents3 = decode_iter(bsonbytes3, codec_options3) assert next(rt_documents3).raw + def test_bson_decode_iter_no_codec_option(self) -> None: + doc = next(decode_iter(encode({"a": 1}))) + assert doc + doc["a"] = 2 + def make_tempfile(self, content: bytes) -> Any: fileobj = tempfile.TemporaryFile() fileobj.write(content) @@ -324,6 +343,12 @@ def foo(self): rt_documents3 = decode_file_iter(fileobj3, codec_options3) assert next(rt_documents3).raw + def test_bson_decode_file_iter_none_codec_option(self) -> None: + fileobj = self.make_tempfile(encode({"new": 1})) + doc = next(decode_file_iter(fileobj)) + assert doc + doc["a"] = 2 + class TestDocumentType(unittest.TestCase): @only_type_check