diff --git a/openapi_python_client/parser/properties/__init__.py b/openapi_python_client/parser/properties/__init__.py index 4b71df5a2..499b52aec 100644 --- a/openapi_python_client/parser/properties/__init__.py +++ b/openapi_python_client/parser/properties/__init__.py @@ -236,14 +236,27 @@ def build_model_property( required_properties: List[Property] = [] optional_properties: List[Property] = [] relative_imports: Set[str] = set() + references: List[oai.Reference] = [] class_name = data.title or name if parent_name: class_name = f"{utils.pascal_case(parent_name)}{utils.pascal_case(class_name)}" ref = Reference.from_ref(class_name) - for key, value in (data.properties or {}).items(): + all_props = data.properties or {} + if not isinstance(data, oai.Reference) and data.allOf: + for sub_prop in data.allOf: + if isinstance(sub_prop, oai.Reference): + references += [sub_prop] + else: + all_props.update(sub_prop.properties or {}) + required_set.update(sub_prop.required or []) + + for key, value in all_props.items(): prop_required = key in required_set + if not isinstance(value, oai.Reference) and value.allOf: + # resolved later + continue prop, schemas = property_from_data( name=key, required=prop_required, data=value, schemas=schemas, parent_name=class_name ) @@ -257,6 +270,7 @@ def build_model_property( prop = ModelProperty( reference=ref, + references=references, required_properties=required_properties, optional_properties=optional_properties, relative_imports=relative_imports, @@ -508,6 +522,16 @@ def build_schemas(*, components: Dict[str, Union[oai.Reference, oai.Schema]]) -> schemas = schemas_or_err processing = True # We made some progress this round, do another after it's done to_process = next_round - schemas.errors.extend(errors) + resolve_errors: List[PropertyError] = [] + models = list(schemas.models.values()) + for model in models: + schemas_or_err = model.resolve_references(components=components, schemas=schemas) + if isinstance(schemas_or_err, PropertyError): + resolve_errors.append(schemas_or_err) + else: + schemas = schemas_or_err + + schemas.errors.extend(errors) + schemas.errors.extend(resolve_errors) return schemas diff --git a/openapi_python_client/parser/properties/model_property.py b/openapi_python_client/parser/properties/model_property.py index ca36171af..fbe257111 100644 --- a/openapi_python_client/parser/properties/model_property.py +++ b/openapi_python_client/parser/properties/model_property.py @@ -1,17 +1,25 @@ -from typing import ClassVar, List, Set +from __future__ import annotations + +from collections.abc import Iterable +from typing import TYPE_CHECKING, ClassVar, Dict, List, Set, Union import attr +from ... import schema as oai +from ..errors import PropertyError from ..reference import Reference from .property import Property +if TYPE_CHECKING: + from .schemas import Schemas + @attr.s(auto_attribs=True, frozen=True) class ModelProperty(Property): """ A property which refers to another Schema """ reference: Reference - + references: List[oai.Reference] required_properties: List[Property] optional_properties: List[Property] description: str @@ -19,6 +27,49 @@ class ModelProperty(Property): template: ClassVar[str] = "model_property.pyi" + def resolve_references( + self, components: Dict[str, Union[oai.Reference, oai.Schema]], schemas: Schemas + ) -> Union[Schemas, PropertyError]: + from ..properties import property_from_data + + required_set = set() + props = {} + while self.references: + reference = self.references.pop() + source_name = Reference.from_ref(reference.ref).class_name + referenced_prop = components[source_name] + assert isinstance(referenced_prop, oai.Schema) + for p, val in (referenced_prop.properties or {}).items(): + props[p] = (val, source_name) + for sub_prop in referenced_prop.allOf or []: + if isinstance(sub_prop, oai.Reference): + self.references.append(sub_prop) + else: + for p, val in (sub_prop.properties or {}).items(): + props[p] = (val, source_name) + if isinstance(referenced_prop.required, Iterable): + for sub_prop_name in referenced_prop.required: + required_set.add(sub_prop_name) + + for key, (value, source_name) in (props or {}).items(): + required = key in required_set + prop, schemas = property_from_data( + name=key, required=required, data=value, schemas=schemas, parent_name=source_name + ) + if isinstance(prop, PropertyError): + return prop + if required: + self.required_properties.append(prop) + # Remove the optional version + new_optional_props = [op for op in self.optional_properties if op.name != prop.name] + self.optional_properties.clear() + self.optional_properties.extend(new_optional_props) + elif not any(ep for ep in (self.optional_properties + self.required_properties) if ep.name == prop.name): + self.optional_properties.append(prop) + self.relative_imports.update(prop.get_imports(prefix="..")) + + return schemas + def get_type_string(self, no_optional: bool = False) -> str: """ Get a string representation of type that should be used when declaring this property """ type_string = self.reference.class_name diff --git a/tests/test_parser/test_properties/test_init.py b/tests/test_parser/test_properties/test_init.py index 96047c60f..39af21b35 100644 --- a/tests/test_parser/test_properties/test_init.py +++ b/tests/test_parser/test_properties/test_init.py @@ -584,6 +584,7 @@ def test_property_from_data_ref_model(self): nullable=False, default=None, reference=Reference(class_name=class_name, module_name="my_model"), + references=[], required_properties=[], optional_properties=[], description="", @@ -599,6 +600,7 @@ def test_property_from_data_ref_model(self): nullable=False, default=None, reference=Reference(class_name=class_name, module_name="my_model"), + references=[], required_properties=[], optional_properties=[], description="", @@ -984,19 +986,25 @@ def test__string_based_property_unsupported_format(self, mocker): def test_build_schemas(mocker): build_model_property = mocker.patch(f"{MODULE_NAME}.build_model_property") in_data = {"1": mocker.MagicMock(enum=None), "2": mocker.MagicMock(enum=None), "3": mocker.MagicMock(enum=None)} + model_1 = mocker.MagicMock() schemas_1 = mocker.MagicMock() model_2 = mocker.MagicMock() schemas_2 = mocker.MagicMock(errors=[]) - error = PropertyError() + schemas_2.models = {"1": model_1, "2": model_2} + error_1 = PropertyError() schemas_3 = mocker.MagicMock() + schemas_4 = mocker.MagicMock(errors=[]) + model_1.resolve_references.return_value = schemas_4 + error_2 = PropertyError() + model_2.resolve_references.return_value = error_2 # This loops through one for each, then again to retry the error build_model_property.side_effect = [ (model_1, schemas_1), (model_2, schemas_2), - (error, schemas_3), - (error, schemas_3), + (error_1, schemas_3), + (error_1, schemas_3), ] from openapi_python_client.parser.properties import Schemas, build_schemas @@ -1012,8 +1020,12 @@ def test_build_schemas(mocker): ] ) # schemas_3 was the last to come back from build_model_property, but it should be ignored because it's an error - assert result == schemas_2 - assert result.errors == [error] + model_1.resolve_references.assert_called_once_with(components=in_data, schemas=schemas_2) + # schemas_4 came from resolving model_1 + model_2.resolve_references.assert_called_once_with(components=in_data, schemas=schemas_4) + # resolving model_2 resulted in err, so no schemas_5 + assert result == schemas_4 + assert result.errors == [error_1, error_2] def test_build_parse_error_on_reference(): @@ -1073,6 +1085,7 @@ def test_build_model_property(): nullable=False, default=None, reference=Reference(class_name="ParentMyModel", module_name="parent_my_model"), + references=[], required_properties=[StringProperty(name="req", required=True, nullable=False, default=None)], optional_properties=[DateTimeProperty(name="opt", required=False, nullable=False, default=None)], description=data.description, diff --git a/tests/test_parser/test_properties/test_model_property.py b/tests/test_parser/test_properties/test_model_property.py index 72c8f27f1..90395876a 100644 --- a/tests/test_parser/test_properties/test_model_property.py +++ b/tests/test_parser/test_properties/test_model_property.py @@ -23,6 +23,7 @@ def test_get_type_string(no_optional, nullable, required, expected): nullable=nullable, default=None, reference=Reference(class_name="MyClass", module_name="my_module"), + references=[], description="", optional_properties=[], required_properties=[], @@ -41,6 +42,7 @@ def test_get_imports(): nullable=True, default=None, reference=Reference(class_name="MyClass", module_name="my_module"), + references=[], description="", optional_properties=[], required_properties=[], @@ -55,3 +57,67 @@ def test_get_imports(): "from typing import Dict", "from typing import cast", } + + +def test_resolve_references(mocker): + import openapi_python_client.schema as oai + from openapi_python_client.parser.properties import build_model_property + + schemas = { + "RefA": oai.Schema.construct( + title=mocker.MagicMock(), + description=mocker.MagicMock(), + required=["String"], + properties={ + "String": oai.Schema.construct(type="string"), + "Enum": oai.Schema.construct(type="string", enum=["aValue"]), + "DateTime": oai.Schema.construct(type="string", format="date-time"), + }, + ), + "RefB": oai.Schema.construct( + title=mocker.MagicMock(), + description=mocker.MagicMock(), + required=["DateTime"], + properties={ + "Int": oai.Schema.construct(type="integer"), + "DateTime": oai.Schema.construct(type="string", format="date-time"), + "Float": oai.Schema.construct(type="number", format="float"), + }, + ), + # Intentionally no properties defined + "RefC": oai.Schema.construct( + title=mocker.MagicMock(), + description=mocker.MagicMock(), + ), + } + + model_schema = oai.Schema.construct( + allOf=[ + oai.Reference.construct(ref="#/components/schemas/RefA"), + oai.Reference.construct(ref="#/components/schemas/RefB"), + oai.Reference.construct(ref="#/components/schemas/RefC"), + oai.Schema.construct( + title=mocker.MagicMock(), + description=mocker.MagicMock(), + required=["Float"], + properties={ + "String": oai.Schema.construct(type="string"), + "Float": oai.Schema.construct(type="number", format="float"), + }, + ), + ] + ) + + components = {**schemas, "Model": model_schema} + + from openapi_python_client.parser.properties import Schemas + + schemas_holder = Schemas() + model, schemas_holder = build_model_property( + data=model_schema, name="Model", required=True, schemas=schemas_holder, parent_name=None + ) + model.resolve_references(components, schemas_holder) + assert sorted(p.name for p in model.required_properties) == ["DateTime", "Float", "String"] + assert all(p.required for p in model.required_properties) + assert sorted(p.name for p in model.optional_properties) == ["Enum", "Int"] + assert all(not p.required for p in model.optional_properties)