Skip to content

Commit 7eb4093

Browse files
codebutlerdbanty
andauthored
bad code generated for nested unions (#959)
fixes #958 --------- Co-authored-by: Dylan Anthony <[email protected]> Co-authored-by: Dylan Anthony <[email protected]>
1 parent f1d76ab commit 7eb4093

File tree

8 files changed

+342
-1
lines changed

8 files changed

+342
-1
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
---
2+
default: patch
3+
---
4+
5+
# Fix invalid type check for nested unions
6+
7+
Nested union types (unions of unions) were generating `isinstance()` checks that were not valid (at least for Python 3.9).
8+
9+
Thanks to @codebutler for PR #959 which fixes #958 and #967.

end_to_end_tests/baseline_openapi_3.0.json

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -801,7 +801,7 @@
801801
}
802802
}
803803
}
804-
},
804+
},
805805
"/enum/int": {
806806
"post": {
807807
"tags": [
@@ -2531,6 +2531,53 @@
25312531
"ModelWithBackslashInDescription": {
25322532
"type": "object",
25332533
"description": "Description with special character: \\"
2534+
},
2535+
"ModelWithDiscriminatedUnion": {
2536+
"type": "object",
2537+
"properties": {
2538+
"discriminated_union": {
2539+
"allOf": [
2540+
{
2541+
"$ref": "#/components/schemas/ADiscriminatedUnion"
2542+
}
2543+
],
2544+
"nullable": true
2545+
}
2546+
}
2547+
},
2548+
"ADiscriminatedUnion": {
2549+
"type": "object",
2550+
"discriminator": {
2551+
"propertyName": "modelType",
2552+
"mapping": {
2553+
"type1": "#/components/schemas/ADiscriminatedUnionType1",
2554+
"type2": "#/components/schemas/ADiscriminatedUnionType2"
2555+
}
2556+
},
2557+
"oneOf": [
2558+
{
2559+
"$ref": "#/components/schemas/ADiscriminatedUnionType1"
2560+
},
2561+
{
2562+
"$ref": "#/components/schemas/ADiscriminatedUnionType2"
2563+
}
2564+
]
2565+
},
2566+
"ADiscriminatedUnionType1": {
2567+
"type": "object",
2568+
"properties": {
2569+
"modelType": {
2570+
"type": "string"
2571+
}
2572+
}
2573+
},
2574+
"ADiscriminatedUnionType2": {
2575+
"type": "object",
2576+
"properties": {
2577+
"modelType": {
2578+
"type": "string"
2579+
}
2580+
}
25342581
}
25352582
},
25362583
"parameters": {

end_to_end_tests/baseline_openapi_3.1.yaml

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2543,6 +2543,55 @@ info:
25432543
"ModelWithBackslashInDescription": {
25442544
"type": "object",
25452545
"description": "Description with special character: \\"
2546+
},
2547+
"ModelWithDiscriminatedUnion": {
2548+
"type": "object",
2549+
"properties": {
2550+
"discriminated_union": {
2551+
"oneOf": [
2552+
{
2553+
"$ref": "#/components/schemas/ADiscriminatedUnion"
2554+
},
2555+
{
2556+
"type": "null"
2557+
}
2558+
],
2559+
}
2560+
}
2561+
},
2562+
"ADiscriminatedUnion": {
2563+
"type": "object",
2564+
"discriminator": {
2565+
"propertyName": "modelType",
2566+
"mapping": {
2567+
"type1": "#/components/schemas/ADiscriminatedUnionType1",
2568+
"type2": "#/components/schemas/ADiscriminatedUnionType2"
2569+
}
2570+
},
2571+
"oneOf": [
2572+
{
2573+
"$ref": "#/components/schemas/ADiscriminatedUnionType1"
2574+
},
2575+
{
2576+
"$ref": "#/components/schemas/ADiscriminatedUnionType2"
2577+
}
2578+
]
2579+
},
2580+
"ADiscriminatedUnionType1": {
2581+
"type": "object",
2582+
"properties": {
2583+
"modelType": {
2584+
"type": "string"
2585+
}
2586+
}
2587+
},
2588+
"ADiscriminatedUnionType2": {
2589+
"type": "object",
2590+
"properties": {
2591+
"modelType": {
2592+
"type": "string"
2593+
}
2594+
}
25462595
}
25472596
},
25482597
"parameters": {

end_to_end_tests/golden-record/my_test_api_client/models/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
""" Contains all the data models used in inputs/outputs """
22

3+
from .a_discriminated_union_type_1 import ADiscriminatedUnionType1
4+
from .a_discriminated_union_type_2 import ADiscriminatedUnionType2
35
from .a_form_data import AFormData
46
from .a_model import AModel
57
from .a_model_with_properties_reference_that_are_not_object import AModelWithPropertiesReferenceThatAreNotObject
@@ -54,6 +56,7 @@
5456
from .model_with_circular_ref_in_additional_properties_a import ModelWithCircularRefInAdditionalPropertiesA
5557
from .model_with_circular_ref_in_additional_properties_b import ModelWithCircularRefInAdditionalPropertiesB
5658
from .model_with_date_time_property import ModelWithDateTimeProperty
59+
from .model_with_discriminated_union import ModelWithDiscriminatedUnion
5760
from .model_with_primitive_additional_properties import ModelWithPrimitiveAdditionalProperties
5861
from .model_with_primitive_additional_properties_a_date_holder import ModelWithPrimitiveAdditionalPropertiesADateHolder
5962
from .model_with_property_ref import ModelWithPropertyRef
@@ -79,6 +82,8 @@
7982
from .validation_error import ValidationError
8083

8184
__all__ = (
85+
"ADiscriminatedUnionType1",
86+
"ADiscriminatedUnionType2",
8287
"AFormData",
8388
"AllOfHasPropertiesButNoType",
8489
"AllOfHasPropertiesButNoTypeTypeEnum",
@@ -125,6 +130,7 @@
125130
"ModelWithCircularRefInAdditionalPropertiesA",
126131
"ModelWithCircularRefInAdditionalPropertiesB",
127132
"ModelWithDateTimeProperty",
133+
"ModelWithDiscriminatedUnion",
128134
"ModelWithPrimitiveAdditionalProperties",
129135
"ModelWithPrimitiveAdditionalPropertiesADateHolder",
130136
"ModelWithPropertyRef",
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from typing import Any, Dict, List, Type, TypeVar, Union
2+
3+
from attrs import define as _attrs_define
4+
from attrs import field as _attrs_field
5+
6+
from ..types import UNSET, Unset
7+
8+
T = TypeVar("T", bound="ADiscriminatedUnionType1")
9+
10+
11+
@_attrs_define
12+
class ADiscriminatedUnionType1:
13+
"""
14+
Attributes:
15+
model_type (Union[Unset, str]):
16+
"""
17+
18+
model_type: Union[Unset, str] = UNSET
19+
additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict)
20+
21+
def to_dict(self) -> Dict[str, Any]:
22+
model_type = self.model_type
23+
24+
field_dict: Dict[str, Any] = {}
25+
field_dict.update(self.additional_properties)
26+
field_dict.update({})
27+
if model_type is not UNSET:
28+
field_dict["modelType"] = model_type
29+
30+
return field_dict
31+
32+
@classmethod
33+
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
34+
d = src_dict.copy()
35+
model_type = d.pop("modelType", UNSET)
36+
37+
a_discriminated_union_type_1 = cls(
38+
model_type=model_type,
39+
)
40+
41+
a_discriminated_union_type_1.additional_properties = d
42+
return a_discriminated_union_type_1
43+
44+
@property
45+
def additional_keys(self) -> List[str]:
46+
return list(self.additional_properties.keys())
47+
48+
def __getitem__(self, key: str) -> Any:
49+
return self.additional_properties[key]
50+
51+
def __setitem__(self, key: str, value: Any) -> None:
52+
self.additional_properties[key] = value
53+
54+
def __delitem__(self, key: str) -> None:
55+
del self.additional_properties[key]
56+
57+
def __contains__(self, key: str) -> bool:
58+
return key in self.additional_properties
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from typing import Any, Dict, List, Type, TypeVar, Union
2+
3+
from attrs import define as _attrs_define
4+
from attrs import field as _attrs_field
5+
6+
from ..types import UNSET, Unset
7+
8+
T = TypeVar("T", bound="ADiscriminatedUnionType2")
9+
10+
11+
@_attrs_define
12+
class ADiscriminatedUnionType2:
13+
"""
14+
Attributes:
15+
model_type (Union[Unset, str]):
16+
"""
17+
18+
model_type: Union[Unset, str] = UNSET
19+
additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict)
20+
21+
def to_dict(self) -> Dict[str, Any]:
22+
model_type = self.model_type
23+
24+
field_dict: Dict[str, Any] = {}
25+
field_dict.update(self.additional_properties)
26+
field_dict.update({})
27+
if model_type is not UNSET:
28+
field_dict["modelType"] = model_type
29+
30+
return field_dict
31+
32+
@classmethod
33+
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
34+
d = src_dict.copy()
35+
model_type = d.pop("modelType", UNSET)
36+
37+
a_discriminated_union_type_2 = cls(
38+
model_type=model_type,
39+
)
40+
41+
a_discriminated_union_type_2.additional_properties = d
42+
return a_discriminated_union_type_2
43+
44+
@property
45+
def additional_keys(self) -> List[str]:
46+
return list(self.additional_properties.keys())
47+
48+
def __getitem__(self, key: str) -> Any:
49+
return self.additional_properties[key]
50+
51+
def __setitem__(self, key: str, value: Any) -> None:
52+
self.additional_properties[key] = value
53+
54+
def __delitem__(self, key: str) -> None:
55+
del self.additional_properties[key]
56+
57+
def __contains__(self, key: str) -> bool:
58+
return key in self.additional_properties
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union, cast
2+
3+
from attrs import define as _attrs_define
4+
from attrs import field as _attrs_field
5+
6+
from ..types import UNSET, Unset
7+
8+
if TYPE_CHECKING:
9+
from ..models.a_discriminated_union_type_1 import ADiscriminatedUnionType1
10+
from ..models.a_discriminated_union_type_2 import ADiscriminatedUnionType2
11+
12+
13+
T = TypeVar("T", bound="ModelWithDiscriminatedUnion")
14+
15+
16+
@_attrs_define
17+
class ModelWithDiscriminatedUnion:
18+
"""
19+
Attributes:
20+
discriminated_union (Union['ADiscriminatedUnionType1', 'ADiscriminatedUnionType2', None, Unset]):
21+
"""
22+
23+
discriminated_union: Union["ADiscriminatedUnionType1", "ADiscriminatedUnionType2", None, Unset] = UNSET
24+
additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict)
25+
26+
def to_dict(self) -> Dict[str, Any]:
27+
from ..models.a_discriminated_union_type_1 import ADiscriminatedUnionType1
28+
from ..models.a_discriminated_union_type_2 import ADiscriminatedUnionType2
29+
30+
discriminated_union: Union[Dict[str, Any], None, Unset]
31+
if isinstance(self.discriminated_union, Unset):
32+
discriminated_union = UNSET
33+
elif isinstance(self.discriminated_union, ADiscriminatedUnionType1):
34+
discriminated_union = self.discriminated_union.to_dict()
35+
elif isinstance(self.discriminated_union, ADiscriminatedUnionType2):
36+
discriminated_union = self.discriminated_union.to_dict()
37+
else:
38+
discriminated_union = self.discriminated_union
39+
40+
field_dict: Dict[str, Any] = {}
41+
field_dict.update(self.additional_properties)
42+
field_dict.update({})
43+
if discriminated_union is not UNSET:
44+
field_dict["discriminated_union"] = discriminated_union
45+
46+
return field_dict
47+
48+
@classmethod
49+
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
50+
from ..models.a_discriminated_union_type_1 import ADiscriminatedUnionType1
51+
from ..models.a_discriminated_union_type_2 import ADiscriminatedUnionType2
52+
53+
d = src_dict.copy()
54+
55+
def _parse_discriminated_union(
56+
data: object,
57+
) -> Union["ADiscriminatedUnionType1", "ADiscriminatedUnionType2", None, Unset]:
58+
if data is None:
59+
return data
60+
if isinstance(data, Unset):
61+
return data
62+
try:
63+
if not isinstance(data, dict):
64+
raise TypeError()
65+
componentsschemas_a_discriminated_union_type_0 = ADiscriminatedUnionType1.from_dict(data)
66+
67+
return componentsschemas_a_discriminated_union_type_0
68+
except: # noqa: E722
69+
pass
70+
try:
71+
if not isinstance(data, dict):
72+
raise TypeError()
73+
componentsschemas_a_discriminated_union_type_1 = ADiscriminatedUnionType2.from_dict(data)
74+
75+
return componentsschemas_a_discriminated_union_type_1
76+
except: # noqa: E722
77+
pass
78+
return cast(Union["ADiscriminatedUnionType1", "ADiscriminatedUnionType2", None, Unset], data)
79+
80+
discriminated_union = _parse_discriminated_union(d.pop("discriminated_union", UNSET))
81+
82+
model_with_discriminated_union = cls(
83+
discriminated_union=discriminated_union,
84+
)
85+
86+
model_with_discriminated_union.additional_properties = d
87+
return model_with_discriminated_union
88+
89+
@property
90+
def additional_keys(self) -> List[str]:
91+
return list(self.additional_properties.keys())
92+
93+
def __getitem__(self, key: str) -> Any:
94+
return self.additional_properties[key]
95+
96+
def __setitem__(self, key: str, value: Any) -> None:
97+
self.additional_properties[key] = value
98+
99+
def __delitem__(self, key: str) -> None:
100+
del self.additional_properties[key]
101+
102+
def __contains__(self, key: str) -> bool:
103+
return key in self.additional_properties

0 commit comments

Comments
 (0)