From 30d2432628bf738d817d470779f2f84b4d2fc73d Mon Sep 17 00:00:00 2001 From: Dan Peachey Date: Fri, 2 Aug 2024 12:05:18 +0200 Subject: [PATCH 1/2] Remove allOf JSON schema workarounds --- pydantic/_internal/_generate_schema.py | 5 --- pydantic/json_schema.py | 55 ++++++-------------------- tests/test_computed_fields.py | 4 +- tests/test_edge_cases.py | 4 +- tests/test_forward_ref.py | 8 ++-- tests/test_generics.py | 2 +- tests/test_json.py | 2 +- tests/test_json_schema.py | 54 +++++++++++-------------- tests/test_serialize.py | 2 +- tests/test_type_alias_type.py | 6 +-- tests/test_types_self.py | 2 +- 11 files changed, 49 insertions(+), 95 deletions(-) diff --git a/pydantic/_internal/_generate_schema.py b/pydantic/_internal/_generate_schema.py index f566d2c88ef..c7408eb8fd2 100644 --- a/pydantic/_internal/_generate_schema.py +++ b/pydantic/_internal/_generate_schema.py @@ -265,11 +265,6 @@ def modify_model_json_schema( json_schema = handler(schema_or_field) original_schema = handler.resolve_ref_schema(json_schema) - # Preserve the fact that definitions schemas should never have sibling keys: - if '$ref' in original_schema: - ref = original_schema['$ref'] - original_schema.clear() - original_schema['allOf'] = [{'$ref': ref}] if title is not None: original_schema['title'] = title elif 'title' not in original_schema: diff --git a/pydantic/json_schema.py b/pydantic/json_schema.py index 5bf80a8fb05..4ca42feece0 100644 --- a/pydantic/json_schema.py +++ b/pydantic/json_schema.py @@ -414,18 +414,15 @@ def generate(self, schema: CoreSchema, mode: JsonSchemaMode = 'validation') -> J json_schema: JsonSchemaValue = self.generate_inner(schema) json_ref_counts = self.get_json_ref_counts(json_schema) - # Remove the top-level $ref if present; note that the _generate method already ensures there are no sibling keys ref = cast(JsonRef, json_schema.get('$ref')) while ref is not None: # may need to unpack multiple levels ref_json_schema = self.get_schema_from_definitions(ref) - if 
json_ref_counts[ref] > 1 or ref_json_schema is None: - # Keep the ref, but use an allOf to remove the top level $ref - json_schema = {'allOf': [{'$ref': ref}]} - else: - # "Unpack" the ref since this is the only reference + if json_ref_counts[ref] == 1 and ref_json_schema is not None and len(json_schema) == 1: + # "Unpack" the ref since this is the only reference and there are no sibling keys json_schema = ref_json_schema.copy() # copy to prevent recursive dict reference json_ref_counts[ref] -= 1 - ref = cast(JsonRef, json_schema.get('$ref')) + ref = cast(JsonRef, json_schema.get('$ref')) + ref = None self._garbage_collect_definitions(json_schema) definitions_remapping = self._build_definitions_remapping() @@ -478,15 +475,6 @@ def populate_defs(core_schema: CoreSchema, json_schema: JsonSchemaValue) -> Json json_schema = ref_json_schema return json_schema - def convert_to_all_of(json_schema: JsonSchemaValue) -> JsonSchemaValue: - if '$ref' in json_schema and len(json_schema.keys()) > 1: - # technically you can't have any other keys next to a "$ref" - # but it's an easy mistake to make and not hard to correct automatically here - json_schema = json_schema.copy() - ref = json_schema.pop('$ref') - json_schema = {'allOf': [{'$ref': ref}], **json_schema} - return json_schema - def handler_func(schema_or_field: CoreSchemaOrField) -> JsonSchemaValue: """Generate a JSON schema based on the input schema. 
@@ -512,7 +500,6 @@ def handler_func(schema_or_field: CoreSchemaOrField) -> JsonSchemaValue: raise TypeError(f'Unexpected schema type: schema={schema_or_field}') if _core_utils.is_core_schema(schema_or_field): json_schema = populate_defs(schema_or_field, json_schema) - json_schema = convert_to_all_of(json_schema) return json_schema current_handler = _schema_generation_shared.GenerateJsonSchemaHandler(self, handler_func) @@ -545,7 +532,6 @@ def new_handler_func( json_schema = js_modify_function(schema_or_field, current_handler) if _core_utils.is_core_schema(schema_or_field): json_schema = populate_defs(schema_or_field, json_schema) - json_schema = convert_to_all_of(json_schema) return json_schema current_handler = _schema_generation_shared.GenerateJsonSchemaHandler(self, new_handler_func) @@ -553,7 +539,6 @@ def new_handler_func( json_schema = current_handler(schema) if _core_utils.is_core_schema(schema): json_schema = populate_defs(schema, json_schema) - json_schema = convert_to_all_of(json_schema) return json_schema # ### Schema generation methods @@ -1080,12 +1065,8 @@ def default_schema(self, schema: core_schema.WithDefaultSchema) -> JsonSchemaVal # Return the inner schema, as though there was no default return json_schema - if '$ref' in json_schema: - # Since reference schemas do not support child keys, we wrap the reference schema in a single-case allOf: - return {'allOf': [json_schema], 'default': encoded_default} - else: - json_schema['default'] = encoded_default - return json_schema + json_schema['default'] = encoded_default + return json_schema def nullable_schema(self, schema: core_schema.NullableSchema) -> JsonSchemaValue: """Generates a JSON schema that matches a schema that allows null values. 
@@ -2001,14 +1982,13 @@ def get_cache_defs_ref_schema(self, core_ref: CoreRef) -> tuple[DefsRef, JsonSch return defs_ref, ref_json_schema def handle_ref_overrides(self, json_schema: JsonSchemaValue) -> JsonSchemaValue: - """It is not valid for a schema with a top-level $ref to have sibling keys. + """Remove any sibling keys that are redundant with the referenced schema. - During our own schema generation, we treat sibling keys as overrides to the referenced schema, - but this is not how the official JSON schema spec works. + Args: + json_schema: The schema to remove redundant sibling keys from. - Because of this, we first remove any sibling keys that are redundant with the referenced schema, then if - any remain, we transform the schema from a top-level '$ref' to use allOf to move the $ref out of the top level. - (See bottom of https://swagger.io/docs/specification/using-ref/ for a reference about this behavior) + Returns: + The schema with redundant sibling keys removed. """ if '$ref' in json_schema: # prevent modifications to the input; this copy may be safe to drop if there is significant overhead @@ -2019,25 +1999,12 @@ def handle_ref_overrides(self, json_schema: JsonSchemaValue) -> JsonSchemaValue: # This can happen when building schemas for models with not-yet-defined references. # It may be a good idea to do a recursive pass at the end of the generation to remove # any redundant override keys. 
- if len(json_schema) > 1: - # Make it an allOf to at least resolve the sibling keys issue - json_schema = json_schema.copy() - json_schema.setdefault('allOf', []) - json_schema['allOf'].append({'$ref': json_schema['$ref']}) - del json_schema['$ref'] - return json_schema for k, v in list(json_schema.items()): if k == '$ref': continue if k in referenced_json_schema and referenced_json_schema[k] == v: del json_schema[k] # redundant key - if len(json_schema) > 1: - # There is a remaining "override" key, so we need to move $ref out of the top level - json_ref = JsonRef(json_schema['$ref']) - del json_schema['$ref'] - assert 'allOf' not in json_schema # this should never happen, but just in case - json_schema['allOf'] = [{'$ref': json_ref}] return json_schema diff --git a/tests/test_computed_fields.py b/tests/test_computed_fields.py index 43664e194fc..12bb85f7931 100644 --- a/tests/test_computed_fields.py +++ b/tests/test_computed_fields.py @@ -731,8 +731,8 @@ def test_multiple_references_to_schema(model_factory: Callable[[], Any]) -> None assert ta.json_schema(mode='serialization') == { '$defs': {'CompModel': {'properties': {}, 'title': 'CompModel', 'type': 'object'}}, 'properties': { - 'comp_1': {'allOf': [{'$ref': '#/$defs/CompModel'}], 'readOnly': True}, - 'comp_2': {'allOf': [{'$ref': '#/$defs/CompModel'}], 'readOnly': True}, + 'comp_1': {'$ref': '#/$defs/CompModel', 'readOnly': True}, + 'comp_2': {'$ref': '#/$defs/CompModel', 'readOnly': True}, }, 'required': ['comp_1', 'comp_2'], 'title': 'Model', diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py index a5b84cf70b2..31c1db7d274 100644 --- a/tests/test_edge_cases.py +++ b/tests/test_edge_cases.py @@ -2706,8 +2706,8 @@ class Outer(BaseModel): 'title': 'Model2', 'type': 'object', }, - 'Root1': {'allOf': [{'$ref': '#/$defs/Model1'}], 'title': 'Root1'}, - 'Root2': {'allOf': [{'$ref': '#/$defs/Model2'}], 'title': 'Root2'}, + 'Root1': {'$ref': '#/$defs/Model1', 'title': 'Root1'}, + 'Root2': {'$ref': 
'#/$defs/Model2', 'title': 'Root2'}, }, 'properties': { 'a': { diff --git a/tests/test_forward_ref.py b/tests/test_forward_ref.py index 2062e33454d..b35cebfdae8 100644 --- a/tests/test_forward_ref.py +++ b/tests/test_forward_ref.py @@ -274,7 +274,7 @@ class Account(BaseModel): Account = module.Account assert Account.model_json_schema() == { - 'allOf': [{'$ref': '#/$defs/Account'}], + '$ref': '#/$defs/Account', '$defs': { 'Account': { 'title': 'Account', @@ -308,7 +308,7 @@ class Account(BaseModel): ) Account = module.Account assert Account.model_json_schema() == { - 'allOf': [{'$ref': '#/$defs/Account'}], + '$ref': '#/$defs/Account', '$defs': { 'Account': { 'title': 'Account', @@ -345,7 +345,7 @@ class Account(BaseModel): Account = module.Account assert Account.model_json_schema() == { - 'allOf': [{'$ref': '#/$defs/Account'}], + '$ref': '#/$defs/Account', '$defs': { 'Account': { 'title': 'Account', @@ -391,7 +391,7 @@ class Account(BaseModel): ) Account = module.Account assert Account.model_json_schema() == { - 'allOf': [{'$ref': '#/$defs/Account'}], + '$ref': '#/$defs/Account', '$defs': { 'Account': { 'title': 'Account', diff --git a/tests/test_generics.py b/tests/test_generics.py index 9c86370146a..2980af5d6c3 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -2163,7 +2163,7 @@ class Payload(BaseModel): 'type': 'object', } }, - 'properties': {'message': {'allOf': [{'$ref': '#/$defs/Payload'}], 'title': 'Message'}}, + 'properties': {'message': {'$ref': '#/$defs/Payload', 'title': 'Message'}}, 'required': ['message'], 'title': 'MessageWrapper[test_parse_generic_json.<locals>.Payload]', 'type': 'object', diff --git a/tests/test_json.py b/tests/test_json.py index 7fbea70158b..f3b858434ed 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -392,7 +392,7 @@ def __get_pydantic_json_schema__( 'type': 'object', } }, - 'allOf': [{'$ref': '#/$defs/Model'}], + '$ref': '#/$defs/Model', } diff --git a/tests/test_json_schema.py b/tests/test_json_schema.py
index f52d04a28d4..fe36b17ec24 100644 --- a/tests/test_json_schema.py +++ b/tests/test_json_schema.py @@ -271,7 +271,7 @@ class Model(BaseModel): 'properties': { 'foo': {'$ref': '#/$defs/FooEnum'}, 'bar': {'$ref': '#/$defs/BarEnum'}, - 'spam': {'allOf': [{'$ref': '#/$defs/SpamEnum'}], 'default': None}, + 'spam': {'$ref': '#/$defs/SpamEnum', 'default': None}, }, 'required': ['foo', 'bar'], 'title': 'Model', @@ -337,10 +337,10 @@ class Model(BaseModel): 'pikalias': { 'title': 'Pikapika!', 'description': 'Pika is definitely the best!', - 'allOf': [{'$ref': '#/$defs/FooBarEnum'}], + '$ref': '#/$defs/FooBarEnum', }, 'bulbialias': { - 'allOf': [{'$ref': '#/$defs/FooBarEnum'}], + '$ref': '#/$defs/FooBarEnum', 'default': 'foo', 'title': 'Bulbibulbi!', 'description': 'Bulbi is not...', @@ -384,13 +384,13 @@ class Foo(BaseModel): 'titled_enum': { 'title': 'Title of enum', 'description': 'Description of enum', - 'allOf': [{'$ref': '#/$defs/Names'}], + '$ref': '#/$defs/Names', }, 'model': {'$ref': '#/$defs/Pika'}, 'titled_model': { 'title': 'Title of model', 'description': 'Description of model', - 'allOf': [{'$ref': '#/$defs/Pika'}], + '$ref': '#/$defs/Pika', }, }, 'required': ['enum', 'titled_enum', 'model', 'titled_model'], @@ -427,7 +427,7 @@ class Foo(BaseModel): }, 'properties': { 'enum': {'$ref': '#/$defs/Names'}, - 'extra_enum': {'allOf': [{'$ref': '#/$defs/Names'}], 'extra': 'Extra field'}, + 'extra_enum': {'$ref': '#/$defs/Names', 'extra': 'Extra field'}, }, 'required': ['enum', 'extra_enum'], 'title': 'Foo', @@ -1432,7 +1432,7 @@ class Model(BaseModel): 'Bar': { 'title': 'Bar', 'type': 'object', - 'properties': {'b': {'allOf': [{'$ref': '#/$defs/Foo'}], 'default': {'a': 'foo'}}}, + 'properties': {'b': {'$ref': '#/$defs/Foo', 'default': {'a': 'foo'}}}, }, 'Baz': { 'title': 'Baz', @@ -1724,7 +1724,7 @@ class Outer(BaseModel): 'type': 'object', } }, - 'properties': {'inner': {'allOf': [{'$ref': '#/$defs/Inner'}], 'default': {'a': {'.': ''}}}}, + 'properties': {'inner': 
{'$ref': '#/$defs/Inner', 'default': {'a': {'.': ''}}}}, 'title': 'Outer', 'type': 'object', } @@ -3290,7 +3290,7 @@ class LocationBase(BaseModel): 'type': 'array', } }, - 'properties': {'coords': {'allOf': [{'$ref': '#/$defs/Coordinates'}], 'default': [34, 42]}}, + 'properties': {'coords': {'$ref': '#/$defs/Coordinates', 'default': [34, 42]}}, 'title': 'LocationBase', 'type': 'object', } @@ -3320,7 +3320,7 @@ class Location(BaseModel): 'type': 'object', } }, - 'properties': {'coords': {'allOf': [{'$ref': '#/$defs/CustomCoordinates'}], 'default': [34, 42]}}, + 'properties': {'coords': {'$ref': '#/$defs/CustomCoordinates', 'default': [34, 42]}}, 'title': 'Location', 'type': 'object', } @@ -3575,7 +3575,7 @@ def resolve(self) -> 'Model': ... 'required': ['uuid'], }, }, - 'allOf': [{'$ref': '#/$defs/Model'}], + '$ref': '#/$defs/Model', } @@ -4488,9 +4488,7 @@ class OuterModel(BaseModel): 'type': 'object', } }, - 'properties': { - 'nested_field': {'allOf': [{'$ref': '#/$defs/InnerModel'}], 'default': {'my_alias': 'foobar', 'foo': 'bar'}} - }, + 'properties': {'nested_field': {'$ref': '#/$defs/InnerModel', 'default': {'my_alias': 'foobar', 'foo': 'bar'}}}, 'title': 'OuterModel', 'type': 'object', } @@ -5074,7 +5072,7 @@ class B(RootModel[A]): assert B.model_json_schema() == { '$defs': {'A': {'description': 'A Model docstring', 'title': 'A', 'type': 'integer'}}, - 'allOf': [{'$ref': '#/$defs/A'}], + '$ref': '#/$defs/A', 'title': 'B', } @@ -5083,7 +5081,7 @@ class C(RootModel[A]): assert C.model_json_schema() == { '$defs': {'A': {'description': 'A Model docstring', 'title': 'A', 'type': 'integer'}}, - 'allOf': [{'$ref': '#/$defs/A'}], + '$ref': '#/$defs/A', 'title': 'C', 'description': 'C Model docstring', } @@ -5245,7 +5243,7 @@ def __get_pydantic_json_schema__( self, core_schema: CoreSchema, handler: GetJsonSchemaHandler ) -> JsonSchemaValue: json_schema = handler(core_schema) - assert json_schema.keys() == {'allOf', 'examples'} + assert json_schema.keys() == {'$ref', 
'examples'} json_schema['title'] = self.title return json_schema @@ -5267,7 +5265,7 @@ class Model(BaseModel): 'type': 'object', } }, - 'allOf': [{'$ref': '#/$defs/Model'}], + '$ref': '#/$defs/Model', 'examples': b'{"foo":{"name":"John","age":28}}', 'title': 'ModelTitle', } @@ -5312,7 +5310,7 @@ class Outer(BaseModel): 'type': 'object', } }, - 'properties': {'inner': {'allOf': [{'$ref': '#/$defs/Inner'}], 'title': 'Foo'}}, + 'properties': {'inner': {'$ref': '#/$defs/Inner', 'title': 'Foo'}}, 'required': ['inner'], 'title': 'Outer', 'type': 'object', @@ -5501,7 +5499,7 @@ def __get_pydantic_json_schema__( 'type': 'object', } }, - 'allOf': [{'$ref': '#/$defs/Model'}], + '$ref': '#/$defs/Model', 'title': 'Set from annotation', } @@ -5878,7 +5876,7 @@ class Bar(BaseModel): 'type': 'object', }, }, - 'allOf': [{'$ref': '#/$defs/Bar'}], + '$ref': '#/$defs/Bar', } @@ -6168,9 +6166,7 @@ class ModelParent(BaseModel): 'type': 'object', } }, - 'properties': { - 'parent': {'allOf': [{'$ref': '#/$defs/BuiltinDataclassParent'}], 'default': {'name': 'Jon Doe'}} - }, + 'properties': {'parent': {'$ref': '#/$defs/BuiltinDataclassParent', 'default': {'name': 'Jon Doe'}}}, 'title': 'child', 'type': 'object', }, @@ -6187,9 +6183,7 @@ class ModelParent(BaseModel): 'type': 'object', } }, - 'properties': { - 'parent': {'allOf': [{'$ref': '#/$defs/PydanticDataclassParent'}], 'default': {'name': 'Jon Doe'}} - }, + 'properties': {'parent': {'$ref': '#/$defs/PydanticDataclassParent', 'default': {'name': 'Jon Doe'}}}, 'title': 'child', 'type': 'object', }, @@ -6206,9 +6200,7 @@ class ModelParent(BaseModel): 'type': 'object', } }, - 'properties': { - 'parent': {'allOf': [{'$ref': '#/$defs/TypedDictParent'}], 'default': {'name': 'Jon Doe'}} - }, + 'properties': {'parent': {'$ref': '#/$defs/TypedDictParent', 'default': {'name': 'Jon Doe'}}}, 'title': 'child', 'type': 'object', }, @@ -6225,7 +6217,7 @@ class ModelParent(BaseModel): 'type': 'object', } }, - 'properties': {'parent': {'allOf': 
[{'$ref': '#/$defs/ModelParent'}], 'default': {'name': 'Jon Doe'}}}, + 'properties': {'parent': {'$ref': '#/$defs/ModelParent', 'default': {'name': 'Jon Doe'}}}, 'title': 'child', 'type': 'object', }, diff --git a/tests/test_serialize.py b/tests/test_serialize.py index 7edf3be8e66..76d95d89c6b 100644 --- a/tests/test_serialize.py +++ b/tests/test_serialize.py @@ -964,7 +964,7 @@ class OtherModel(BaseModel): 'type': 'object', } }, - 'properties': {'x': {'allOf': [{'$ref': '#/$defs/OtherModel'}], 'title': 'X'}}, + 'properties': {'x': {'$ref': '#/$defs/OtherModel', 'title': 'X'}}, 'required': ['x'], 'title': 'Model', 'type': 'object', diff --git a/tests/test_type_alias_type.py b/tests/test_type_alias_type.py index 73e4b478f46..c4a1086b00d 100644 --- a/tests/test_type_alias_type.py +++ b/tests/test_type_alias_type.py @@ -117,7 +117,7 @@ def test_recursive_type_alias() -> None: ] assert t.json_schema() == { - 'allOf': [{'$ref': '#/$defs/JsonType'}], + '$ref': '#/$defs/JsonType', '$defs': { 'JsonType': { 'anyOf': [ @@ -233,7 +233,7 @@ def test_recursive_generic_type_alias() -> None: ] assert t.json_schema() == { - 'allOf': [{'$ref': '#/$defs/RecursiveGenericAlias_int_'}], + '$ref': '#/$defs/RecursiveGenericAlias_int_', '$defs': { 'RecursiveGenericAlias_int_': { 'type': 'array', @@ -324,7 +324,7 @@ def test_field() -> None: # insert_assert(ta.json_schema()) assert ta.json_schema() == { '$defs': {'SomeAlias': {'type': 'integer', 'description': 'number'}}, - 'allOf': [{'$ref': '#/$defs/SomeAlias'}], + '$ref': '#/$defs/SomeAlias', 'title': 'abc', } diff --git a/tests/test_types_self.py b/tests/test_types_self.py index e28ba64f9f4..7a052f0b050 100644 --- a/tests/test_types_self.py +++ b/tests/test_types_self.py @@ -135,7 +135,7 @@ class SelfRef(BaseModel): 'type': 'object', } }, - 'allOf': [{'$ref': '#/$defs/SelfRef'}], + '$ref': '#/$defs/SelfRef', } From 708f1ca786d4c8075ceeeffa33b45ee5552004a1 Mon Sep 17 00:00:00 2001 From: Dan Peachey Date: Wed, 7 Aug 2024 21:31:06 +0200 
Subject: [PATCH 2/2] Deselect failing FastAPI test_openapi_schema test --- tests/test_fastapi.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_fastapi.sh b/tests/test_fastapi.sh index 20754ac9d3f..6e8e7b050b3 100755 --- a/tests/test_fastapi.sh +++ b/tests/test_fastapi.sh @@ -21,5 +21,5 @@ cd .. && pip install . && cd fastapi # To skip a specific test, add '--deselect path/to/test.py::test_name' to the end of this command # # To update the list of deselected tests, remove all deselections, run the tests, and re-add any remaining failures -# TODO remove this once that test is fixed, see https://github.com/pydantic/pydantic/pull/9064 -./scripts/test.sh -vv --deselect tests/test_tutorial/test_path_params/test_tutorial005.py::test_get_enums_invalid +# TODO remove this once that test is fixed, see https://github.com/pydantic/pydantic/pull/10029 +./scripts/test.sh -vv --deselect tests/test_openapi_examples.py::test_openapi_schema