Skip to content

Commit e882277

Browse files
authored
fix: support generic models with discriminated union (#3551)
1 parent edad0db commit e882277

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

pydantic/fields.py

+4
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,10 @@ def prepare_discriminated_union_sub_fields(self) -> None:
733733
Note that this process can be aborted if a `ForwardRef` is encountered
734734
"""
735735
assert self.discriminator_key is not None
736+
737+
if self.type_.__class__ is DeferredType:
738+
return
739+
736740
assert self.sub_fields is not None
737741
sub_fields_mapping: Dict[str, 'ModelField'] = {}
738742
all_aliases: Set[str] = set()

tests/test_discrimated_union.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import re
2+
import sys
23
from enum import Enum
3-
from typing import Union
4+
from typing import Generic, TypeVar, Union
45

56
import pytest
67
from typing_extensions import Annotated, Literal
78

89
from pydantic import BaseModel, Field, ValidationError
910
from pydantic.errors import ConfigError
11+
from pydantic.generics import GenericModel
1012

1113

1214
def test_discriminated_union_only_union():
@@ -361,3 +363,36 @@ class Model(BaseModel):
361363
n: int
362364

363365
assert isinstance(Model(**{'pet': {'pet_type': 'dog', 'name': 'Milou'}, 'n': 5}).pet, Dog)
366+
367+
368+
@pytest.mark.skipif(sys.version_info < (3, 7), reason='generics only supported for python 3.7 and above')
369+
def test_generic():
370+
T = TypeVar('T')
371+
372+
class Success(GenericModel, Generic[T]):
373+
type: Literal['Success'] = 'Success'
374+
data: T
375+
376+
class Failure(BaseModel):
377+
type: Literal['Failure'] = 'Failure'
378+
error_message: str
379+
380+
class Container(GenericModel, Generic[T]):
381+
result: Union[Success[T], Failure] = Field(discriminator='type')
382+
383+
with pytest.raises(ValidationError, match="Discriminator 'type' is missing in value"):
384+
Container[str].parse_obj({'result': {}})
385+
386+
with pytest.raises(
387+
ValidationError,
388+
match=re.escape("No match for discriminator 'type' and value 'Other' (allowed values: 'Success', 'Failure')"),
389+
):
390+
Container[str].parse_obj({'result': {'type': 'Other'}})
391+
392+
with pytest.raises(
393+
ValidationError, match=re.escape('Container[str]\nresult -> Success[str] -> data\n field required')
394+
):
395+
Container[str].parse_obj({'result': {'type': 'Success'}})
396+
397+
# coercion is done properly
398+
assert Container[str].parse_obj({'result': {'type': 'Success', 'data': 1}}).result.data == '1'

0 commit comments

Comments
 (0)