diff --git a/bson/codec_options.py b/bson/codec_options.py index 833908fad0..0db900f59b 100644 --- a/bson/codec_options.py +++ b/bson/codec_options.py @@ -310,15 +310,9 @@ def with_options(self, **kwargs): .. versionadded:: 3.5 """ - return CodecOptions( - kwargs.get('document_class', self.document_class), - kwargs.get('tz_aware', self.tz_aware), - kwargs.get('uuid_representation', self.uuid_representation), - kwargs.get('unicode_decode_error_handler', - self.unicode_decode_error_handler), - kwargs.get('tzinfo', self.tzinfo), - kwargs.get('type_registry', self.type_registry) - ) + opts = self._asdict() + opts.update(kwargs) + return CodecOptions(**opts) DEFAULT_CODEC_OPTIONS = CodecOptions( diff --git a/bson/json_util.py b/bson/json_util.py index 7b789b0f30..f4c1b498f6 100644 --- a/bson/json_util.py +++ b/bson/json_util.py @@ -311,6 +311,26 @@ def _arguments_repr(self): self.json_mode, super(JSONOptions, self)._arguments_repr())) + def with_options(self, **kwargs): + """ + Make a copy of this JSONOptions, overriding some options:: + + >>> from bson.json_util import CANONICAL_JSON_OPTIONS + >>> CANONICAL_JSON_OPTIONS.tz_aware + True + >>> json_options = CANONICAL_JSON_OPTIONS.with_options(tz_aware=False) + >>> json_options.tz_aware + False + + .. versionadded:: 3.12 + """ + opts = self._asdict() + for opt in ('strict_number_long', 'datetime_representation', + 'strict_uuid', 'json_mode'): + opts[opt] = kwargs.get(opt, getattr(self, opt)) + opts.update(kwargs) + return JSONOptions(**opts) + LEGACY_JSON_OPTIONS = JSONOptions(json_mode=JSONMode.LEGACY) """:class:`JSONOptions` for encoding to PyMongo's legacy JSON format. diff --git a/test/test_json_util.py b/test/test_json_util.py index e8b64a16d1..7906b276f5 100644 --- a/test/test_json_util.py +++ b/test/test_json_util.py @@ -52,6 +52,34 @@ def round_trip(self, doc, **kwargs): def test_basic(self): self.round_trip({"hello": "world"}) + def test_json_options_with_options(self): + opts = json_util.JSONOptions( + datetime_representation=DatetimeRepresentation.NUMBERLONG) + self.assertEqual( + opts.datetime_representation, DatetimeRepresentation.NUMBERLONG) + opts2 = opts.with_options( + datetime_representation=DatetimeRepresentation.ISO8601) + self.assertEqual( + opts2.datetime_representation, DatetimeRepresentation.ISO8601) + + opts = json_util.JSONOptions(strict_number_long=True) + self.assertEqual(opts.strict_number_long, True) + opts2 = opts.with_options(strict_number_long=False) + self.assertEqual(opts2.strict_number_long, False) + + opts = json_util.CANONICAL_JSON_OPTIONS + self.assertNotEqual( + opts.uuid_representation, UuidRepresentation.JAVA_LEGACY) + opts2 = opts.with_options( + uuid_representation=UuidRepresentation.JAVA_LEGACY) + self.assertEqual( + opts2.uuid_representation, UuidRepresentation.JAVA_LEGACY) + self.assertEqual(opts2.document_class, dict) + opts3 = opts2.with_options(document_class=SON) + self.assertEqual( + opts3.uuid_representation, UuidRepresentation.JAVA_LEGACY) + self.assertEqual(opts3.document_class, SON) + def test_objectid(self): self.round_trip({"id": ObjectId()})