Skip to content

Commit 66f2890

Browse files
authored
is_instance JSON support (#278)
* is_instance json support * add benchmark, remove Option from json_type * remove defult impl of is_instance * confirm tuples work too * test for json string not is_instance
1 parent c13283d commit 66f2890

File tree

11 files changed

+182
-19
lines changed

11 files changed

+182
-19
lines changed

generate_self_schema.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from collections.abc import Callable
1212
from datetime import date, datetime, time, timedelta
1313
from pathlib import Path
14-
from typing import TYPE_CHECKING, Any, Dict, ForwardRef, List, Type, Union
14+
from typing import TYPE_CHECKING, Any, Dict, ForwardRef, List, Set, Type, Union
1515

1616
from typing_extensions import get_args, get_origin, is_typeddict
1717

@@ -67,6 +67,8 @@ def get_schema(obj) -> core_schema.CoreSchema:
6767
return {'type': 'literal', 'expected': expected}
6868
elif issubclass(origin, List):
6969
return {'type': 'list', 'items_schema': get_schema(obj.__args__[0])}
70+
elif issubclass(origin, Set):
71+
return {'type': 'set', 'items_schema': get_schema(obj.__args__[0])}
7072
elif issubclass(origin, Dict):
7173
return {
7274
'type': 'dict',

pydantic_core/core_schema.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import sys
44
from datetime import date, datetime, time, timedelta
5-
from typing import Any, Callable, Dict, List, Optional, Type, Union, overload
5+
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union, overload
66

77
if sys.version_info < (3, 11):
88
from typing_extensions import NotRequired, Protocol, Required
@@ -278,13 +278,21 @@ def literal_schema(*expected: Any, ref: str | None = None) -> LiteralSchema:
278278
return dict_not_none(type='literal', expected=expected, ref=ref)
279279

280280

281-
class IsInstanceSchema(TypedDict):
282-
type: Literal['is-instance']
283-
cls: Type[Any]
281+
# must match input/parse_json.rs::JsonType::try_from
282+
JsonType = Literal['null', 'bool', 'int', 'float', 'str', 'list', 'dict']
283+
284+
285+
class IsInstanceSchema(TypedDict, total=False):
286+
type: Required[Literal['is-instance']]
287+
cls: Required[Type[Any]]
288+
json_types: Set[JsonType]
289+
ref: str
284290

285291

286-
def is_instance_schema(cls: Type[Any]) -> IsInstanceSchema:
287-
return dict_not_none(type='is-instance', cls=cls)
292+
def is_instance_schema(
293+
cls: Type[Any], *, json_types: Set[JsonType] | None = None, ref: str | None = None
294+
) -> IsInstanceSchema:
295+
return dict_not_none(type='is-instance', cls=cls, json_types=json_types, ref=ref)
288296

289297

290298
class CallableSchema(TypedDict):

src/errors/location.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@ impl ToPyObject for LocItem {
5050
}
5151
}
5252

53-
impl LocItem {
54-
pub fn try_from(value: &PyAny) -> PyResult<Self> {
53+
impl TryFrom<&PyAny> for LocItem {
54+
type Error = PyErr;
55+
56+
fn try_from(value: &PyAny) -> PyResult<Self> {
5557
if let Ok(str) = value.extract::<String>() {
5658
Ok(str.into())
5759
} else {

src/input/input_abstract.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
3434
None
3535
}
3636

37-
fn is_instance(&self, _class: &PyType) -> PyResult<bool> {
38-
Ok(false)
39-
}
37+
fn is_instance(&self, class: &PyType, json_mask: u8) -> PyResult<bool>;
4038

4139
fn callable(&self) -> bool {
4240
false

src/input/input_json.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
use pyo3::prelude::*;
2+
use pyo3::types::PyType;
3+
14
use crate::errors::{ErrorKind, InputValue, LocItem, ValError, ValResult};
5+
use crate::input::JsonType;
26

37
use super::datetime::{
48
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration,
@@ -30,6 +34,23 @@ impl<'a> Input<'a> for JsonInput {
3034
matches!(self, JsonInput::Null)
3135
}
3236

37+
fn is_instance(&self, _class: &PyType, json_mask: u8) -> PyResult<bool> {
38+
if json_mask == 0 {
39+
Ok(false)
40+
} else {
41+
let json_type: JsonType = match self {
42+
JsonInput::Null => JsonType::Null,
43+
JsonInput::Bool(_) => JsonType::Bool,
44+
JsonInput::Int(_) => JsonType::Int,
45+
JsonInput::Float(_) => JsonType::Float,
46+
JsonInput::String(_) => JsonType::String,
47+
JsonInput::Array(_) => JsonType::Array,
48+
JsonInput::Object(_) => JsonType::Object,
49+
};
50+
Ok(json_type.matches(json_mask))
51+
}
52+
}
53+
3354
fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> {
3455
match self {
3556
JsonInput::Object(kwargs) => Ok(JsonArgs::new(None, Some(kwargs)).into()),
@@ -284,6 +305,14 @@ impl<'a> Input<'a> for String {
284305
false
285306
}
286307

308+
fn is_instance(&self, _class: &PyType, json_mask: u8) -> PyResult<bool> {
309+
if json_mask == 0 {
310+
Ok(false)
311+
} else {
312+
Ok(JsonType::String.matches(json_mask))
313+
}
314+
}
315+
287316
#[cfg_attr(has_no_coverage, no_coverage)]
288317
fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> {
289318
Err(ValError::new(ErrorKind::ArgumentsType, self))

src/input/input_python.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ impl<'a> Input<'a> for PyAny {
8181
self.getattr(name).ok()
8282
}
8383

84-
fn is_instance(&self, class: &PyType) -> PyResult<bool> {
84+
fn is_instance(&self, class: &PyType, _json_mask: u8) -> PyResult<bool> {
8585
self.is_instance(class)
8686
}
8787

src/input/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ mod shared;
1010

1111
pub use datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
1212
pub use input_abstract::Input;
13-
pub use parse_json::{JsonInput, JsonObject};
13+
pub use parse_json::{JsonInput, JsonObject, JsonType};
1414
pub use return_enums::{
1515
py_string_str, EitherBytes, EitherString, GenericArguments, GenericCollection, GenericIterator, GenericMapping,
1616
JsonArgs, PyArgs,

src/input/parse_json.rs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,50 @@ use std::fmt;
22

33
use indexmap::IndexMap;
44
use pyo3::prelude::*;
5-
use pyo3::types::PyDict;
5+
use pyo3::types::{PyDict, PySet};
66
use serde::de::{Deserialize, DeserializeSeed, Error as SerdeError, MapAccess, SeqAccess, Visitor};
77

8+
use crate::build_tools::py_error;
9+
10+
#[derive(Clone, Debug)]
11+
pub enum JsonType {
12+
Null = 0b10000000,
13+
Bool = 0b01000000,
14+
Int = 0b00100000,
15+
Float = 0b00010000,
16+
String = 0b00001000,
17+
Array = 0b00000100,
18+
Object = 0b00000010,
19+
}
20+
21+
impl JsonType {
22+
pub fn combine(set: &PySet) -> PyResult<u8> {
23+
set.iter().map(Self::try_from).try_fold(0u8, |a, b| Ok(a | b? as u8))
24+
}
25+
26+
pub fn matches(&self, mask: u8) -> bool {
27+
*self as u8 & mask > 0
28+
}
29+
}
30+
31+
impl TryFrom<&PyAny> for JsonType {
32+
type Error = PyErr;
33+
34+
fn try_from(value: &PyAny) -> PyResult<Self> {
35+
let s: &str = value.extract()?;
36+
match s {
37+
"null" => Ok(Self::Null),
38+
"bool" => Ok(Self::Bool),
39+
"int" => Ok(Self::Int),
40+
"float" => Ok(Self::Float),
41+
"str" => Ok(Self::String),
42+
"list" => Ok(Self::Array),
43+
"dict" => Ok(Self::Object),
44+
_ => py_error!("Invalid json type: {}", s),
45+
}
46+
}
47+
}
48+
849
/// similar to serde `Value` but with int and float split
950
#[derive(Clone, Debug)]
1051
pub enum JsonInput {

src/validators/is_instance.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
use pyo3::intern;
22
use pyo3::prelude::*;
3-
use pyo3::types::{PyDict, PyType};
3+
use pyo3::types::{PyDict, PySet, PyType};
44

55
use crate::build_tools::SchemaDict;
66
use crate::errors::{ErrorKind, ValError, ValResult};
7-
use crate::input::Input;
7+
use crate::input::{Input, JsonType};
88
use crate::recursion_guard::RecursionGuard;
99

1010
use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator};
1111

1212
#[derive(Debug, Clone)]
1313
pub struct IsInstanceValidator {
1414
class: Py<PyType>,
15+
json_types: u8,
1516
class_repr: String,
1617
name: String,
1718
}
@@ -27,8 +28,13 @@ impl BuildValidator for IsInstanceValidator {
2728
let class: &PyType = schema.get_as_req(intern!(schema.py(), "cls"))?;
2829
let class_repr = class.name()?.to_string();
2930
let name = format!("{}[{}]", Self::EXPECTED_TYPE, class_repr);
31+
let json_types = match schema.get_as::<&PySet>(intern!(schema.py(), "json_types"))? {
32+
Some(s) => JsonType::combine(s)?,
33+
None => 0,
34+
};
3035
Ok(Self {
3136
class: class.into(),
37+
json_types,
3238
class_repr,
3339
name,
3440
}
@@ -45,7 +51,7 @@ impl Validator for IsInstanceValidator {
4551
_slots: &'data [CombinedValidator],
4652
_recursion_guard: &'s mut RecursionGuard,
4753
) -> ValResult<'data, PyObject> {
48-
match input.is_instance(self.class.as_ref(py))? {
54+
match input.is_instance(self.class.as_ref(py), self.json_types)? {
4955
true => Ok(input.to_object(py)),
5056
false => Err(ValError::new(
5157
ErrorKind::IsInstanceOf {

tests/benchmarks/test_micro_benchmarks.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,3 +1003,15 @@ def test_generator_rust(benchmark):
10031003
assert sum(v.validate_python(input_value)) == 4950
10041004

10051005
benchmark(v.validate_python, input_value)
1006+
1007+
1008+
@pytest.mark.benchmark(group='isinstance-json')
1009+
def test_isinstance_json(benchmark):
1010+
validator = SchemaValidator(core_schema.is_instance_schema(str, json_types={'str'}))
1011+
assert validator.isinstance_json('"foo"') is True
1012+
assert validator.isinstance_json('123') is False
1013+
1014+
@benchmark
1015+
def t():
1016+
validator.isinstance_json('"foo"')
1017+
validator.isinstance_json('123')

tests/validators/test_is_instance.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import pytest
22

3-
from pydantic_core import SchemaError, SchemaValidator, ValidationError
3+
from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema
4+
5+
from ..conftest import plain_repr
46

57

68
class Foo:
@@ -118,3 +120,66 @@ def test_repr():
118120
def test_is_type(input_val, value):
119121
v = SchemaValidator({'type': 'is-instance', 'cls': type})
120122
assert v.isinstance_python(input_val) == value
123+
124+
125+
@pytest.mark.parametrize(
126+
'input_val,expected',
127+
[
128+
('null', False),
129+
('true', True),
130+
('1', False),
131+
('1.1', False),
132+
('"a string"', True),
133+
('["s"]', False),
134+
('{"s": 1}', False),
135+
],
136+
)
137+
def test_is_instance_json_string_bool(input_val, expected):
138+
v = SchemaValidator(core_schema.is_instance_schema(Foo, json_types={'str', 'bool'}))
139+
assert v.isinstance_json(input_val) == expected
140+
141+
142+
@pytest.mark.parametrize(
143+
'input_val,expected',
144+
[
145+
('null', False),
146+
('true', False),
147+
('1', False),
148+
('1.1', False),
149+
('"a string"', False),
150+
('["s"]', True),
151+
('{"s": 1}', False),
152+
],
153+
)
154+
def test_is_instance_json_list(input_val, expected):
155+
v = SchemaValidator(core_schema.is_instance_schema(Foo, json_types=('list',)))
156+
assert v.isinstance_json(input_val) == expected
157+
158+
159+
def test_is_instance_dict():
160+
v = SchemaValidator(
161+
core_schema.dict_schema(
162+
keys_schema=core_schema.is_instance_schema(str, json_types={'str'}),
163+
values_schema=core_schema.is_instance_schema(int, json_types={'int', 'dict'}),
164+
)
165+
)
166+
assert v.isinstance_python({'foo': 1}) is True
167+
assert v.isinstance_python({1: 1}) is False
168+
assert v.isinstance_json('{"foo": 1}') is True
169+
assert v.isinstance_json('{"foo": "1"}') is False
170+
assert v.isinstance_json('{"foo": {"a": 1}}') is True
171+
172+
173+
def test_is_instance_dict_not_str():
174+
v = SchemaValidator(core_schema.dict_schema(keys_schema=core_schema.is_instance_schema(int, json_types={'int'})))
175+
assert v.isinstance_python({1: 1}) is True
176+
assert v.isinstance_python({'foo': 1}) is False
177+
assert v.isinstance_json('{"foo": 1}') is False
178+
179+
180+
def test_json_mask():
181+
assert 'json_types:128' in plain_repr(SchemaValidator(core_schema.is_instance_schema(str, json_types={'null'})))
182+
assert 'json_types:0' in plain_repr(SchemaValidator(core_schema.is_instance_schema(str)))
183+
assert 'json_types:0' in plain_repr(SchemaValidator(core_schema.is_instance_schema(str, json_types=set())))
184+
v = SchemaValidator(core_schema.is_instance_schema(str, json_types={'list', 'dict'}))
185+
assert 'json_types:6' in plain_repr(v) # 2 + 4

0 commit comments

Comments
 (0)