From 4ce002eef91aaf3a3693d6fea229cbce0b1555ae Mon Sep 17 00:00:00 2001 From: Ben Weintraub Date: Sun, 27 Oct 2024 20:17:10 -0700 Subject: [PATCH] fix: import cast for required const properties, since it's used in the template Resolves #1150 --- end_to_end_tests/baseline_openapi_3.0.json | 41 +++++ end_to_end_tests/baseline_openapi_3.1.yaml | 41 +++++ .../api/default/__init__.py | 12 +- .../get_models_oneof_with_required_const.py | 159 ++++++++++++++++++ .../my_test_api_client/models/__init__.py | 4 + ...with_required_const_response_200_type_0.py | 71 ++++++++ ...with_required_const_response_200_type_1.py | 71 ++++++++ .../parser/properties/const.py | 2 +- 8 files changed, 399 insertions(+), 2 deletions(-) create mode 100644 end_to_end_tests/golden-record/my_test_api_client/api/default/get_models_oneof_with_required_const.py create mode 100644 end_to_end_tests/golden-record/my_test_api_client/models/get_models_oneof_with_required_const_response_200_type_0.py create mode 100644 end_to_end_tests/golden-record/my_test_api_client/models/get_models_oneof_with_required_const_response_200_type_1.py diff --git a/end_to_end_tests/baseline_openapi_3.0.json b/end_to_end_tests/baseline_openapi_3.0.json index e5bbaf6fc..22a786a4f 100644 --- a/end_to_end_tests/baseline_openapi_3.0.json +++ b/end_to_end_tests/baseline_openapi_3.0.json @@ -1666,6 +1666,47 @@ } } } + }, + "/models/oneof-with-required-const": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "oneOf": [ + { + "type": "object", + "properties": { + "type": { + "const": "alpha" + }, + "color": { + "type": "string" + } + }, + "required": ["type"] + }, + { + "type": "object", + "properties": { + "type": { + "const": "beta" + }, + "texture": { + "type": "string" + } + }, + "required": ["type"] + } + ] + } + } + } + } + } + } } }, "components": { diff --git a/end_to_end_tests/baseline_openapi_3.1.yaml b/end_to_end_tests/baseline_openapi_3.1.yaml index b6a6941e2..a19e46ce3 100644 --- a/end_to_end_tests/baseline_openapi_3.1.yaml +++ b/end_to_end_tests/baseline_openapi_3.1.yaml @@ -1657,6 +1657,47 @@ info: } } }, + "/models/oneof-with-required-const": { + "get": { + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "oneOf": [ + { + "type": "object", + "properties": { + "type": { + "const": "alpha" + }, + "color": { + "type": "string" + } + }, + "required": ["type"] + }, + { + "type": "object", + "properties": { + "type": { + "const": "beta" + }, + "texture": { + "type": "string" + } + }, + "required": ["type"] + } + ] + } + } + } + } + } + } + } } "components": "schemas": { diff --git a/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/default/__init__.py b/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/default/__init__.py index 04d1162e8..0d7798e15 100644 --- a/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/default/__init__.py +++ b/end_to_end_tests/custom-templates-golden-record/my_test_api_client/api/default/__init__.py @@ -2,7 +2,13 @@ import types -from . import get_common_parameters, get_models_allof, post_common_parameters, reserved_parameters +from . import ( + get_common_parameters, + get_models_allof, + get_models_oneof_with_required_const, + post_common_parameters, + reserved_parameters, +) class DefaultEndpoints: @@ -21,3 +27,7 @@ def reserved_parameters(cls) -> types.ModuleType: @classmethod def get_models_allof(cls) -> types.ModuleType: return get_models_allof + + @classmethod + def get_models_oneof_with_required_const(cls) -> types.ModuleType: + return get_models_oneof_with_required_const diff --git a/end_to_end_tests/golden-record/my_test_api_client/api/default/get_models_oneof_with_required_const.py b/end_to_end_tests/golden-record/my_test_api_client/api/default/get_models_oneof_with_required_const.py new file mode 100644 index 000000000..85be98c28 --- /dev/null +++ b/end_to_end_tests/golden-record/my_test_api_client/api/default/get_models_oneof_with_required_const.py @@ -0,0 +1,159 @@ +from http import HTTPStatus +from typing import Any, Dict, Optional, Union + +import httpx + +from ... import errors +from ...client import AuthenticatedClient, Client +from ...models.get_models_oneof_with_required_const_response_200_type_0 import ( + GetModelsOneofWithRequiredConstResponse200Type0, +) +from ...models.get_models_oneof_with_required_const_response_200_type_1 import ( + GetModelsOneofWithRequiredConstResponse200Type1, +) +from ...types import Response + + +def _get_kwargs() -> Dict[str, Any]: + _kwargs: Dict[str, Any] = { + "method": "get", + "url": "/models/oneof-with-required-const", + } + + return _kwargs + + +def _parse_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Optional[ + Union["GetModelsOneofWithRequiredConstResponse200Type0", "GetModelsOneofWithRequiredConstResponse200Type1"] +]: + if response.status_code == 200: + + def _parse_response_200( + data: object, + ) -> Union[ + "GetModelsOneofWithRequiredConstResponse200Type0", "GetModelsOneofWithRequiredConstResponse200Type1" + ]: + try: + if not isinstance(data, dict): + raise TypeError() + response_200_type_0 = GetModelsOneofWithRequiredConstResponse200Type0.from_dict(data) + + return response_200_type_0 + except: # noqa: E722 + pass + if not isinstance(data, dict): + raise TypeError() + response_200_type_1 = GetModelsOneofWithRequiredConstResponse200Type1.from_dict(data) + + return response_200_type_1 + + response_200 = _parse_response_200(response.json()) + + return response_200 + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Response[ + Union["GetModelsOneofWithRequiredConstResponse200Type0", "GetModelsOneofWithRequiredConstResponse200Type1"] +]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + *, + client: Union[AuthenticatedClient, Client], +) -> Response[ + Union["GetModelsOneofWithRequiredConstResponse200Type0", "GetModelsOneofWithRequiredConstResponse200Type1"] +]: + """ + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Union['GetModelsOneofWithRequiredConstResponse200Type0', 'GetModelsOneofWithRequiredConstResponse200Type1']] + """ + + kwargs = _get_kwargs() + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + *, + client: Union[AuthenticatedClient, Client], +) -> Optional[ + Union["GetModelsOneofWithRequiredConstResponse200Type0", "GetModelsOneofWithRequiredConstResponse200Type1"] +]: + """ + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union['GetModelsOneofWithRequiredConstResponse200Type0', 'GetModelsOneofWithRequiredConstResponse200Type1'] + """ + + return sync_detailed( + client=client, + ).parsed + + +async def asyncio_detailed( + *, + client: Union[AuthenticatedClient, Client], +) -> Response[ + Union["GetModelsOneofWithRequiredConstResponse200Type0", "GetModelsOneofWithRequiredConstResponse200Type1"] +]: + """ + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Union['GetModelsOneofWithRequiredConstResponse200Type0', 'GetModelsOneofWithRequiredConstResponse200Type1']] + """ + + kwargs = _get_kwargs() + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + *, + client: Union[AuthenticatedClient, Client], +) -> Optional[ + Union["GetModelsOneofWithRequiredConstResponse200Type0", "GetModelsOneofWithRequiredConstResponse200Type1"] +]: + """ + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union['GetModelsOneofWithRequiredConstResponse200Type0', 'GetModelsOneofWithRequiredConstResponse200Type1'] + """ + + return ( + await asyncio_detailed( + client=client, + ) + ).parsed diff --git a/end_to_end_tests/golden-record/my_test_api_client/models/__init__.py b/end_to_end_tests/golden-record/my_test_api_client/models/__init__.py index 7a9c2ad32..f354c31c7 100644 --- a/end_to_end_tests/golden-record/my_test_api_client/models/__init__.py +++ b/end_to_end_tests/golden-record/my_test_api_client/models/__init__.py @@ -39,6 +39,8 @@ from .get_location_header_types_int_enum_header import GetLocationHeaderTypesIntEnumHeader from .get_location_header_types_string_enum_header import GetLocationHeaderTypesStringEnumHeader from .get_models_allof_response_200 import GetModelsAllofResponse200 +from .get_models_oneof_with_required_const_response_200_type_0 import GetModelsOneofWithRequiredConstResponse200Type0 +from .get_models_oneof_with_required_const_response_200_type_1 import GetModelsOneofWithRequiredConstResponse200Type1 from .http_validation_error import HTTPValidationError from .import_ import Import from .json_like_body import JsonLikeBody @@ -121,6 +123,8 @@ "GetLocationHeaderTypesIntEnumHeader", "GetLocationHeaderTypesStringEnumHeader", "GetModelsAllofResponse200", + "GetModelsOneofWithRequiredConstResponse200Type0", + "GetModelsOneofWithRequiredConstResponse200Type1", "HTTPValidationError", "Import", "JsonLikeBody", diff --git a/end_to_end_tests/golden-record/my_test_api_client/models/get_models_oneof_with_required_const_response_200_type_0.py b/end_to_end_tests/golden-record/my_test_api_client/models/get_models_oneof_with_required_const_response_200_type_0.py new file mode 100644 index 000000000..972e1c765 --- /dev/null +++ b/end_to_end_tests/golden-record/my_test_api_client/models/get_models_oneof_with_required_const_response_200_type_0.py @@ -0,0 +1,71 @@ +from typing import Any, Dict, List, Literal, Type, TypeVar, Union, cast + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..types import UNSET, Unset + +T = TypeVar("T", bound="GetModelsOneofWithRequiredConstResponse200Type0") + + +@_attrs_define +class GetModelsOneofWithRequiredConstResponse200Type0: + """ + Attributes: + type (Literal['alpha']): + color (Union[Unset, str]): + """ + + type: Literal["alpha"] + color: Union[Unset, str] = UNSET + additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> Dict[str, Any]: + type = self.type + + color = self.color + + field_dict: Dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "type": type, + } + ) + if color is not UNSET: + field_dict["color"] = color + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + type = cast(Literal["alpha"], d.pop("type")) + if type != "alpha": + raise ValueError(f"type must match const 'alpha', got '{type}'") + + color = d.pop("color", UNSET) + + get_models_oneof_with_required_const_response_200_type_0 = cls( + type=type, + color=color, + ) + + get_models_oneof_with_required_const_response_200_type_0.additional_properties = d + return get_models_oneof_with_required_const_response_200_type_0 + + @property + def additional_keys(self) -> List[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/end_to_end_tests/golden-record/my_test_api_client/models/get_models_oneof_with_required_const_response_200_type_1.py b/end_to_end_tests/golden-record/my_test_api_client/models/get_models_oneof_with_required_const_response_200_type_1.py new file mode 100644 index 000000000..4596c3cc4 --- /dev/null +++ b/end_to_end_tests/golden-record/my_test_api_client/models/get_models_oneof_with_required_const_response_200_type_1.py @@ -0,0 +1,71 @@ +from typing import Any, Dict, List, Literal, Type, TypeVar, Union, cast + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..types import UNSET, Unset + +T = TypeVar("T", bound="GetModelsOneofWithRequiredConstResponse200Type1") + + +@_attrs_define +class GetModelsOneofWithRequiredConstResponse200Type1: + """ + Attributes: + type (Literal['beta']): + texture (Union[Unset, str]): + """ + + type: Literal["beta"] + texture: Union[Unset, str] = UNSET + additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> Dict[str, Any]: + type = self.type + + texture = self.texture + + field_dict: Dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "type": type, + } + ) + if texture is not UNSET: + field_dict["texture"] = texture + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + type = cast(Literal["beta"], d.pop("type")) + if type != "beta": + raise ValueError(f"type must match const 'beta', got '{type}'") + + texture = d.pop("texture", UNSET) + + get_models_oneof_with_required_const_response_200_type_1 = cls( + type=type, + texture=texture, + ) + + get_models_oneof_with_required_const_response_200_type_1.additional_properties = d + return get_models_oneof_with_required_const_response_200_type_1 + + @property + def additional_keys(self) -> List[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/openapi_python_client/parser/properties/const.py b/openapi_python_client/parser/properties/const.py index 9da3f8f1e..8b9967f19 100644 --- a/openapi_python_client/parser/properties/const.py +++ b/openapi_python_client/parser/properties/const.py @@ -111,7 +111,7 @@ def get_imports(self, *, prefix: str) -> set[str]: back to the root of the generated client. """ if self.required: - return {"from typing import Literal"} + return {"from typing import Literal, cast"} return { "from typing import Literal, Union, cast", f"from {prefix}types import UNSET, Unset",