# mypy: ignore-errors
# MAINTENANCE: remove when deprecating Pydantic v1. Mypy doesn't handle two code paths that import
# different versions of a module, so we need to ignore errors here.
from __future__ import annotations

from collections import deque
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, is_dataclass
from typing import TYPE_CHECKING, Any, Deque, FrozenSet, List, Set, Tuple, Union

from pydantic import BaseModel, TypeAdapter, ValidationError, create_model

# Importing from Pydantic's internal modules is risky: they are not part of the public API
# and may change without notice in future releases. We use eval_type_lenient to evaluate
# forward references in type annotations.
from pydantic._internal._typing_extra import eval_type_lenient
from pydantic._internal._utils import lenient_issubclass
from pydantic_core import PydanticUndefined, PydanticUndefinedType
from typing_extensions import Annotated, Literal, get_args, get_origin

from aws_lambda_powertools.event_handler.openapi.types import UnionType

if TYPE_CHECKING:
    from pydantic.fields import FieldInfo
    from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue

    from aws_lambda_powertools.event_handler.openapi.types import IncEx, ModelNameMap

Undefined = PydanticUndefined
Required = PydanticUndefined
UndefinedType = PydanticUndefinedType
evaluate_forwardref = eval_type_lenient
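
# Illustrative sketch, not part of the module: `evaluate_forwardref` lazily resolves string
# annotations. The call shape below assumes the internal `eval_type_lenient(value, globalns, localns)`
# signature; being internal to Pydantic, it may differ across releases:
#
#   from typing import ForwardRef
#   evaluate_forwardref(ForwardRef("list[int]"), globals(), locals())  # -> list[int]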

sequence_annotation_to_type = {
    Sequence: list,
    List: list,
    list: list,
    Tuple: tuple,
    tuple: tuple,
    Set: set,
    set: set,
    FrozenSet: frozenset,
    frozenset: frozenset,
    Deque: deque,
    deque: deque,
}

sequence_types = tuple(sequence_annotation_to_type.keys())
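
# Illustrative sketch, not part of the module: the mapping turns a typing-level sequence
# annotation into the matching concrete constructor, e.g.:
#
#   sequence_annotation_to_type[List](("a", "b"))   # -> ["a", "b"]
#   sequence_annotation_to_type[frozenset]([1, 2])  # -> frozenset({1, 2})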

RequestErrorModel: type[BaseModel] = create_model("Request")


class ErrorWrapper(Exception):
    pass


@dataclass
class ModelField:
    field_info: FieldInfo
    name: str
    mode: Literal["validation", "serialization"] = "validation"

    @property
    def alias(self) -> str:
        value = self.field_info.alias
        return value if value is not None else self.name

    @property
    def required(self) -> bool:
        return self.field_info.is_required()

    @property
    def default(self) -> Any:
        return self.get_default()

    @property
    def type_(self) -> Any:
        return self.field_info.annotation

    def __post_init__(self) -> None:
        self._type_adapter: TypeAdapter[Any] = TypeAdapter(
            Annotated[self.field_info.annotation, self.field_info],
        )

    def get_default(self) -> Any:
        if self.field_info.is_required():
            return Undefined
        return self.field_info.get_default(call_default_factory=True)

    def serialize(
        self,
        value: Any,
        *,
        mode: Literal["json", "python"] = "json",
        include: IncEx | None = None,
        exclude: IncEx | None = None,
        by_alias: bool = True,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
    ) -> Any:
        return self._type_adapter.dump_python(
            value,
            mode=mode,
            include=include,
            exclude=exclude,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
        )

    def validate(
        self,
        value: Any,
        *,
        loc: tuple[int | str, ...] = (),
    ) -> tuple[Any, list[dict[str, Any]] | None]:
        try:
            return (self._type_adapter.validate_python(value, from_attributes=True), None)
        except ValidationError as exc:
            return None, _regenerate_error_with_loc(errors=exc.errors(), loc_prefix=loc)

    def __hash__(self) -> int:
        # Each ModelField is unique for our purposes
        return id(self)
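
# Illustrative usage sketch, not part of the module, assuming Pydantic v2's public FieldInfo API:
#
#   from pydantic.fields import FieldInfo
#
#   field = ModelField(field_info=FieldInfo(annotation=int), name="age")
#   field.validate("42")  # -> (42, None)
#   field.serialize(42)   # -> 42
#   field.required        # -> True (no default was given)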


def get_schema_from_model_field(
    *,
    field: ModelField,
    model_name_map: ModelNameMap,
    field_mapping: dict[
        tuple[ModelField, Literal["validation", "serialization"]],
        JsonSchemaValue,
    ],
) -> dict[str, Any]:
    json_schema = field_mapping[(field, field.mode)]
    if "$ref" not in json_schema:
        # MAINTENANCE: remove when deprecating Pydantic v1
        # Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
        json_schema["title"] = field.field_info.title or field.alias.title().replace("_", " ")
    return json_schema


def get_definitions(
    *,
    fields: list[ModelField],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
) -> tuple[
    dict[
        tuple[ModelField, Literal["validation", "serialization"]],
        dict[str, Any],
    ],
    dict[str, dict[str, Any]],
]:
    inputs = [(field, field.mode, field._type_adapter.core_schema) for field in fields]
    field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs)
    return field_mapping, definitions


def get_compat_model_name_map(fields: list[ModelField]) -> ModelNameMap:
    # Pydantic v2's schema generator resolves model names itself, so no explicit map is needed
    return {}


def get_annotation_from_field_info(annotation: Any, field_info: FieldInfo, field_name: str) -> Any:
    # In Pydantic v2 the FieldInfo already carries its constraints, so the annotation is returned unchanged
    return annotation


def model_rebuild(model: type[BaseModel]) -> None:
    model.model_rebuild()


def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
    return type(field_info).from_annotation(annotation)


def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]:
    error = ValidationError.from_exception_data(
        "Field required",
        [{"type": "missing", "loc": loc, "input": {}}],
    ).errors()[0]
    error["input"] = None
    return error
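
# Illustrative sketch, not part of the module: the result is Pydantic v2's standard
# "missing" error with its input blanked out, e.g.:
#
#   get_missing_field_error(("body", "name"))
#   # -> {"type": "missing", "loc": ("body", "name"), "msg": "Field required", "input": None, ...}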


def is_scalar_field(field: ModelField) -> bool:
    # Imported locally to avoid a circular import with the params module
    from aws_lambda_powertools.event_handler.openapi.params import Body

    return field_annotation_is_scalar(field.field_info.annotation) and not isinstance(field.field_info, Body)


def is_scalar_sequence_field(field: ModelField) -> bool:
    return field_annotation_is_scalar_sequence(field.field_info.annotation)


def is_sequence_field(field: ModelField) -> bool:
    return field_annotation_is_sequence(field.field_info.annotation)


def is_bytes_field(field: ModelField) -> bool:
    return is_bytes_or_nonable_bytes_annotation(field.type_)


def is_bytes_sequence_field(field: ModelField) -> bool:
    return is_bytes_sequence_annotation(field.type_)


def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
    origin_type = get_origin(field.field_info.annotation) or field.field_info.annotation
    if not issubclass(origin_type, sequence_types):  # type: ignore[arg-type]
        raise AssertionError(f"Expected sequence type, got {origin_type}")
    return sequence_annotation_to_type[origin_type](value)  # type: ignore[no-any-return]
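
# Illustrative sketch, not part of the module: a value destined for a set[int] annotation
# is coerced into a real set before serialization, e.g.:
#
#   field = ModelField(field_info=FieldInfo(annotation=set[int]), name="ids")
#   serialize_sequence_value(field=field, value=[1, 2, 2])  # -> {1, 2}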


def _normalize_errors(errors: Sequence[Any]) -> list[dict[str, Any]]:
    # Pydantic v2 already emits errors as plain dicts, so this is a pass-through
    return errors  # type: ignore[return-value]


def create_body_model(*, fields: Sequence[ModelField], model_name: str) -> type[BaseModel]:
    field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
    model: type[BaseModel] = create_model(model_name, **field_params)
    return model
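
# Illustrative sketch, not part of the module, assuming Pydantic v2's FieldInfo API:
#
#   fields = [
#       ModelField(field_info=FieldInfo(annotation=str), name="name"),
#       ModelField(field_info=FieldInfo(annotation=int), name="age"),
#   ]
#   Body = create_body_model(fields=fields, model_name="Body")
#   Body(name="Jane", age=30)  # validates like any other BaseModel subclass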


def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any:
    return model.model_dump(mode=mode, **kwargs)


def model_json(model: BaseModel, **kwargs: Any) -> Any:
    return model.model_dump_json(**kwargs)


# Common code for both versions
def field_annotation_is_complex(annotation: type[Any] | None) -> bool:
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        return any(field_annotation_is_complex(arg) for arg in get_args(annotation))

    return (
        _annotation_is_complex(annotation)
        or _annotation_is_complex(origin)
        or hasattr(origin, "__pydantic_core_schema__")
        or hasattr(origin, "__get_pydantic_core_schema__")
    )


def field_annotation_is_scalar(annotation: Any) -> bool:
    # Ellipsis appears in variadic tuple annotations such as tuple[int, ...]
    return annotation is Ellipsis or not field_annotation_is_complex(annotation)
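
# Illustrative sketch, not part of the module: scalars are whatever is not complex, e.g.:
#
#   field_annotation_is_complex(dict[str, int])  # -> True (Mapping)
#   field_annotation_is_scalar(int)              # -> True
#   field_annotation_is_scalar(list[int])        # -> False (sequences are complex)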


def field_annotation_is_sequence(annotation: type[Any] | None) -> bool:
    return _annotation_is_sequence(annotation) or _annotation_is_sequence(get_origin(annotation))


def field_annotation_is_scalar_sequence(annotation: type[Any] | None) -> bool:
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        at_least_one_scalar_sequence = False
        for arg in get_args(annotation):
            if field_annotation_is_scalar_sequence(arg):
                at_least_one_scalar_sequence = True
                continue
            elif not field_annotation_is_scalar(arg):
                return False
        return at_least_one_scalar_sequence

    return field_annotation_is_sequence(annotation) and all(
        field_annotation_is_scalar(sub_annotation) for sub_annotation in get_args(annotation)
    )
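
# Illustrative sketch, not part of the module:
#
#   field_annotation_is_scalar_sequence(list[int])              # -> True
#   field_annotation_is_scalar_sequence(list[dict[str, int]])   # -> False
#   field_annotation_is_scalar_sequence(Union[list[str], int])  # -> True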


def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
    if lenient_issubclass(annotation, bytes):
        return True
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        for arg in get_args(annotation):
            if lenient_issubclass(arg, bytes):
                return True
    return False
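
# Illustrative sketch, not part of the module:
#
#   is_bytes_or_nonable_bytes_annotation(bytes)            # -> True
#   is_bytes_or_nonable_bytes_annotation(Optional[bytes])  # -> True
#   is_bytes_or_nonable_bytes_annotation(str)              # -> False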


def is_bytes_sequence_annotation(annotation: Any) -> bool:
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        at_least_one = False
        for arg in get_args(annotation):
            if is_bytes_sequence_annotation(arg):
                at_least_one = True
                break
        return at_least_one

    return field_annotation_is_sequence(annotation) and all(
        is_bytes_or_nonable_bytes_annotation(sub_annotation) for sub_annotation in get_args(annotation)
    )
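
# Illustrative sketch, not part of the module:
#
#   is_bytes_sequence_annotation(list[bytes])  # -> True
#   is_bytes_sequence_annotation(list[str])    # -> False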


def value_is_sequence(value: Any) -> bool:
    return isinstance(value, sequence_types) and not isinstance(value, (str, bytes))  # type: ignore[arg-type]


def _annotation_is_complex(annotation: type[Any] | None) -> bool:
    return (
        lenient_issubclass(annotation, (BaseModel, Mapping))  # Keep it to UploadFile
        or _annotation_is_sequence(annotation)
        or is_dataclass(annotation)
    )


def _annotation_is_sequence(annotation: type[Any] | None) -> bool:
    if lenient_issubclass(annotation, (str, bytes)):
        return False
    return lenient_issubclass(annotation, sequence_types)


def _regenerate_error_with_loc(*, errors: Sequence[Any], loc_prefix: tuple[str | int, ...]) -> list[dict[str, Any]]:
    updated_loc_errors: list[Any] = [
        {**err, "loc": loc_prefix + err.get("loc", ())} for err in _normalize_errors(errors)
    ]
    return updated_loc_errors
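
# Illustrative sketch, not part of the module: prefixing each error's location with where
# the value came from, e.g.:
#
#   errors = [{"type": "missing", "loc": ("name",), "msg": "Field required"}]
#   _regenerate_error_with_loc(errors=errors, loc_prefix=("body",))
#   # -> [{"type": "missing", "loc": ("body", "name"), "msg": "Field required"}]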