Skip to content

Remove allOf JSON schema workarounds #10029

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions pydantic/_internal/_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,11 +265,6 @@ def modify_model_json_schema(

json_schema = handler(schema_or_field)
original_schema = handler.resolve_ref_schema(json_schema)
# Preserve the fact that definitions schemas should never have sibling keys:
if '$ref' in original_schema:
ref = original_schema['$ref']
original_schema.clear()
original_schema['allOf'] = [{'$ref': ref}]
if title is not None:
original_schema['title'] = title
elif 'title' not in original_schema:
Expand Down
55 changes: 11 additions & 44 deletions pydantic/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,18 +414,15 @@ def generate(self, schema: CoreSchema, mode: JsonSchemaMode = 'validation') -> J
json_schema: JsonSchemaValue = self.generate_inner(schema)
json_ref_counts = self.get_json_ref_counts(json_schema)

# Remove the top-level $ref if present; note that the _generate method already ensures there are no sibling keys
ref = cast(JsonRef, json_schema.get('$ref'))
while ref is not None: # may need to unpack multiple levels
ref_json_schema = self.get_schema_from_definitions(ref)
if json_ref_counts[ref] > 1 or ref_json_schema is None:
# Keep the ref, but use an allOf to remove the top level $ref
json_schema = {'allOf': [{'$ref': ref}]}
else:
# "Unpack" the ref since this is the only reference
if json_ref_counts[ref] == 1 and ref_json_schema is not None and len(json_schema) == 1:
# "Unpack" the ref since this is the only reference and there are no sibling keys
json_schema = ref_json_schema.copy() # copy to prevent recursive dict reference
json_ref_counts[ref] -= 1
ref = cast(JsonRef, json_schema.get('$ref'))
ref = cast(JsonRef, json_schema.get('$ref'))
ref = None

self._garbage_collect_definitions(json_schema)
definitions_remapping = self._build_definitions_remapping()
Expand Down Expand Up @@ -478,15 +475,6 @@ def populate_defs(core_schema: CoreSchema, json_schema: JsonSchemaValue) -> Json
json_schema = ref_json_schema
return json_schema

def convert_to_all_of(json_schema: JsonSchemaValue) -> JsonSchemaValue:
if '$ref' in json_schema and len(json_schema.keys()) > 1:
# technically you can't have any other keys next to a "$ref"
# but it's an easy mistake to make and not hard to correct automatically here
json_schema = json_schema.copy()
ref = json_schema.pop('$ref')
json_schema = {'allOf': [{'$ref': ref}], **json_schema}
return json_schema

def handler_func(schema_or_field: CoreSchemaOrField) -> JsonSchemaValue:
"""Generate a JSON schema based on the input schema.

Expand All @@ -512,7 +500,6 @@ def handler_func(schema_or_field: CoreSchemaOrField) -> JsonSchemaValue:
raise TypeError(f'Unexpected schema type: schema={schema_or_field}')
if _core_utils.is_core_schema(schema_or_field):
json_schema = populate_defs(schema_or_field, json_schema)
json_schema = convert_to_all_of(json_schema)
return json_schema

current_handler = _schema_generation_shared.GenerateJsonSchemaHandler(self, handler_func)
Expand Down Expand Up @@ -545,15 +532,13 @@ def new_handler_func(
json_schema = js_modify_function(schema_or_field, current_handler)
if _core_utils.is_core_schema(schema_or_field):
json_schema = populate_defs(schema_or_field, json_schema)
json_schema = convert_to_all_of(json_schema)
return json_schema

current_handler = _schema_generation_shared.GenerateJsonSchemaHandler(self, new_handler_func)

json_schema = current_handler(schema)
if _core_utils.is_core_schema(schema):
json_schema = populate_defs(schema, json_schema)
json_schema = convert_to_all_of(json_schema)
return json_schema

# ### Schema generation methods
Expand Down Expand Up @@ -1080,12 +1065,8 @@ def default_schema(self, schema: core_schema.WithDefaultSchema) -> JsonSchemaVal
# Return the inner schema, as though there was no default
return json_schema

if '$ref' in json_schema:
# Since reference schemas do not support child keys, we wrap the reference schema in a single-case allOf:
return {'allOf': [json_schema], 'default': encoded_default}
else:
json_schema['default'] = encoded_default
return json_schema
json_schema['default'] = encoded_default
return json_schema

def nullable_schema(self, schema: core_schema.NullableSchema) -> JsonSchemaValue:
"""Generates a JSON schema that matches a schema that allows null values.
Expand Down Expand Up @@ -2001,14 +1982,13 @@ def get_cache_defs_ref_schema(self, core_ref: CoreRef) -> tuple[DefsRef, JsonSch
return defs_ref, ref_json_schema

def handle_ref_overrides(self, json_schema: JsonSchemaValue) -> JsonSchemaValue:
"""It is not valid for a schema with a top-level $ref to have sibling keys.
"""Remove any sibling keys that are redundant with the referenced schema.

During our own schema generation, we treat sibling keys as overrides to the referenced schema,
but this is not how the official JSON schema spec works.
Args:
json_schema: The schema to remove redundant sibling keys from.

Because of this, we first remove any sibling keys that are redundant with the referenced schema, then if
any remain, we transform the schema from a top-level '$ref' to use allOf to move the $ref out of the top level.
(See bottom of https://swagger.io/docs/specification/using-ref/ for a reference about this behavior)
Returns:
The schema with redundant sibling keys removed.
"""
if '$ref' in json_schema:
# prevent modifications to the input; this copy may be safe to drop if there is significant overhead
Expand All @@ -2019,25 +1999,12 @@ def handle_ref_overrides(self, json_schema: JsonSchemaValue) -> JsonSchemaValue:
# This can happen when building schemas for models with not-yet-defined references.
# It may be a good idea to do a recursive pass at the end of the generation to remove
# any redundant override keys.
if len(json_schema) > 1:
# Make it an allOf to at least resolve the sibling keys issue
json_schema = json_schema.copy()
json_schema.setdefault('allOf', [])
json_schema['allOf'].append({'$ref': json_schema['$ref']})
del json_schema['$ref']

return json_schema
for k, v in list(json_schema.items()):
if k == '$ref':
continue
if k in referenced_json_schema and referenced_json_schema[k] == v:
del json_schema[k] # redundant key
if len(json_schema) > 1:
# There is a remaining "override" key, so we need to move $ref out of the top level
json_ref = JsonRef(json_schema['$ref'])
del json_schema['$ref']
assert 'allOf' not in json_schema # this should never happen, but just in case
json_schema['allOf'] = [{'$ref': json_ref}]

return json_schema

Expand Down
4 changes: 2 additions & 2 deletions tests/test_computed_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,8 +731,8 @@ def test_multiple_references_to_schema(model_factory: Callable[[], Any]) -> None
assert ta.json_schema(mode='serialization') == {
'$defs': {'CompModel': {'properties': {}, 'title': 'CompModel', 'type': 'object'}},
'properties': {
'comp_1': {'allOf': [{'$ref': '#/$defs/CompModel'}], 'readOnly': True},
'comp_2': {'allOf': [{'$ref': '#/$defs/CompModel'}], 'readOnly': True},
'comp_1': {'$ref': '#/$defs/CompModel', 'readOnly': True},
'comp_2': {'$ref': '#/$defs/CompModel', 'readOnly': True},
},
'required': ['comp_1', 'comp_2'],
'title': 'Model',
Expand Down
4 changes: 2 additions & 2 deletions tests/test_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2706,8 +2706,8 @@ class Outer(BaseModel):
'title': 'Model2',
'type': 'object',
},
'Root1': {'allOf': [{'$ref': '#/$defs/Model1'}], 'title': 'Root1'},
'Root2': {'allOf': [{'$ref': '#/$defs/Model2'}], 'title': 'Root2'},
'Root1': {'$ref': '#/$defs/Model1', 'title': 'Root1'},
'Root2': {'$ref': '#/$defs/Model2', 'title': 'Root2'},
},
'properties': {
'a': {
Expand Down
4 changes: 2 additions & 2 deletions tests/test_fastapi.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ cd .. && pip install . && cd fastapi
# To skip a specific test, add '--deselect path/to/test.py::test_name' to the end of this command
#
# To update the list of deselected tests, remove all deselections, run the tests, and re-add any remaining failures
# TODO remove this once that test is fixed, see https://github.com/pydantic/pydantic/pull/9064
./scripts/test.sh -vv --deselect tests/test_tutorial/test_path_params/test_tutorial005.py::test_get_enums_invalid
# TODO remove this once that test is fixed, see https://github.com/pydantic/pydantic/pull/10029
./scripts/test.sh -vv --deselect tests/test_openapi_examples.py::test_openapi_schema
8 changes: 4 additions & 4 deletions tests/test_forward_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ class Account(BaseModel):

Account = module.Account
assert Account.model_json_schema() == {
'allOf': [{'$ref': '#/$defs/Account'}],
'$ref': '#/$defs/Account',
'$defs': {
'Account': {
'title': 'Account',
Expand Down Expand Up @@ -308,7 +308,7 @@ class Account(BaseModel):
)
Account = module.Account
assert Account.model_json_schema() == {
'allOf': [{'$ref': '#/$defs/Account'}],
'$ref': '#/$defs/Account',
'$defs': {
'Account': {
'title': 'Account',
Expand Down Expand Up @@ -345,7 +345,7 @@ class Account(BaseModel):

Account = module.Account
assert Account.model_json_schema() == {
'allOf': [{'$ref': '#/$defs/Account'}],
'$ref': '#/$defs/Account',
'$defs': {
'Account': {
'title': 'Account',
Expand Down Expand Up @@ -391,7 +391,7 @@ class Account(BaseModel):
)
Account = module.Account
assert Account.model_json_schema() == {
'allOf': [{'$ref': '#/$defs/Account'}],
'$ref': '#/$defs/Account',
'$defs': {
'Account': {
'title': 'Account',
Expand Down
2 changes: 1 addition & 1 deletion tests/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2163,7 +2163,7 @@ class Payload(BaseModel):
'type': 'object',
}
},
'properties': {'message': {'allOf': [{'$ref': '#/$defs/Payload'}], 'title': 'Message'}},
'properties': {'message': {'$ref': '#/$defs/Payload', 'title': 'Message'}},
'required': ['message'],
'title': 'MessageWrapper[test_parse_generic_json.<locals>.Payload]',
'type': 'object',
Expand Down
2 changes: 1 addition & 1 deletion tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def __get_pydantic_json_schema__(
'type': 'object',
}
},
'allOf': [{'$ref': '#/$defs/Model'}],
'$ref': '#/$defs/Model',
}


Expand Down
Loading
Loading