Skip to content

Commit f389728

Browse files
Fix equality checks for primitives in literals (#1459)
1 parent 4aa52a8 commit f389728

File tree

2 files changed

+43
-22
lines changed

2 files changed

+43
-22
lines changed

src/validators/literal.rs

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ pub struct LiteralLookup<T: Debug> {
3535
expected_py_dict: Option<Py<PyDict>>,
3636
// Catch all for unhashable types like list
3737
expected_py_values: Option<Vec<(Py<PyAny>, usize)>>,
38+
// Fallback for ints, bools, and strings to use Python hash and equality checks
39+
// which we can't mix with `expected_py_dict`, as there may be conflicts
40+
// for an example, see tests/test_validators/test_literal.py::test_mix_int_enum_with_int
41+
expected_py_primitives: Option<Py<PyDict>>,
3842

3943
pub values: Vec<T>,
4044
}
@@ -46,20 +50,24 @@ impl<T: Debug> LiteralLookup<T> {
4650
let mut expected_str: AHashMap<String, usize> = AHashMap::new();
4751
let expected_py_dict = PyDict::new_bound(py);
4852
let mut expected_py_values = Vec::new();
53+
let expected_py_primitives = PyDict::new_bound(py);
4954
let mut values = Vec::new();
5055
for (k, v) in expected {
5156
let id = values.len();
5257
values.push(v);
53-
if let Ok(bool) = k.validate_bool(true) {
54-
if bool.into_inner() {
58+
59+
if let Ok(bool_value) = k.validate_bool(true) {
60+
if bool_value.into_inner() {
5561
expected_bool.true_id = Some(id);
5662
} else {
5763
expected_bool.false_id = Some(id);
5864
}
65+
expected_py_primitives.set_item(&k, id)?;
5966
}
6067
if k.is_exact_instance_of::<PyInt>() {
6168
if let Ok(int_64) = k.extract::<i64>() {
6269
expected_int.insert(int_64, id);
70+
expected_py_primitives.set_item(&k, id)?;
6371
} else {
6472
// cover the case of an int that's > i64::MAX etc.
6573
expected_py_dict.set_item(k, id)?;
@@ -69,32 +77,20 @@ impl<T: Debug> LiteralLookup<T> {
6977
.as_cow()
7078
.map_err(|_| py_schema_error_type!("error extracting str {:?}", k))?;
7179
expected_str.insert(str.to_string(), id);
80+
expected_py_primitives.set_item(&k, id)?;
7281
} else if expected_py_dict.set_item(&k, id).is_err() {
7382
expected_py_values.push((k.as_unbound().clone_ref(py), id));
7483
}
7584
}
7685

7786
Ok(Self {
78-
expected_bool: match expected_bool.true_id.is_some() || expected_bool.false_id.is_some() {
79-
true => Some(expected_bool),
80-
false => None,
81-
},
82-
expected_int: match expected_int.is_empty() {
83-
true => None,
84-
false => Some(expected_int),
85-
},
86-
expected_str: match expected_str.is_empty() {
87-
true => None,
88-
false => Some(expected_str),
89-
},
90-
expected_py_dict: match expected_py_dict.is_empty() {
91-
true => None,
92-
false => Some(expected_py_dict.into()),
93-
},
94-
expected_py_values: match expected_py_values.is_empty() {
95-
true => None,
96-
false => Some(expected_py_values),
97-
},
87+
expected_bool: (expected_bool.true_id.is_some() || expected_bool.false_id.is_some())
88+
.then_some(expected_bool),
89+
expected_int: (!expected_int.is_empty()).then_some(expected_int),
90+
expected_str: (!expected_str.is_empty()).then_some(expected_str),
91+
expected_py_dict: (!expected_py_dict.is_empty()).then_some(expected_py_dict.into()),
92+
expected_py_values: (!expected_py_values.is_empty()).then_some(expected_py_values),
93+
expected_py_primitives: (!expected_py_primitives.is_empty()).then_some(expected_py_primitives.into()),
9894
values,
9995
})
10096
}
@@ -162,6 +158,19 @@ impl<T: Debug> LiteralLookup<T> {
162158
}
163159
}
164160
};
161+
162+
// this one must be last to avoid conflicts with the other lookups, think of this
163+
// almost as a lax fallback
164+
if let Some(expected_py_primitives) = &self.expected_py_primitives {
165+
let py_input = py_input.get_or_insert_with(|| input.to_object(py));
166+
// We don't use ? to unpack the result of `get_item` in the next line because unhashable
167+
// inputs will produce a TypeError, which in this case we just want to treat equivalently
168+
// to a failed lookup
169+
if let Ok(Some(v)) = expected_py_primitives.bind(py).get_item(&*py_input) {
170+
let id: usize = v.extract().unwrap();
171+
return Ok(Some((input, &self.values[id])));
172+
}
173+
};
165174
Ok(None)
166175
}
167176

tests/validators/test_literal.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,3 +389,15 @@ def test_big_int():
389389
m = r'Input should be 18446744073709551617 or 340282366920938463463374607431768211457 \[type=literal_error'
390390
with pytest.raises(ValidationError, match=m):
391391
v.validate_python(37)
392+
393+
394+
def test_enum_for_str() -> None:
395+
class S(str, Enum):
396+
a = 'a'
397+
398+
val_enum = SchemaValidator(core_schema.literal_schema([S.a]))
399+
val_str = SchemaValidator(core_schema.literal_schema(['a']))
400+
401+
for val in [val_enum, val_str]:
402+
assert val.validate_python('a') == 'a'
403+
assert val.validate_python(S.a) == 'a'

0 commit comments

Comments
 (0)