@@ -35,6 +35,10 @@ pub struct LiteralLookup<T: Debug> {
35
35
expected_py_dict : Option < Py < PyDict > > ,
36
36
// Catch all for unhashable types like list
37
37
expected_py_values : Option < Vec < ( Py < PyAny > , usize ) > > ,
38
+ // Fallback for ints, bools, and strings to use Python hash and equality checks
39
+ // which we can't mix with `expected_py_dict`, as there may be conflicts
40
+ // for an example, see tests/test_validators/test_literal.py::test_mix_int_enum_with_int
41
+ expected_py_primitives : Option < Py < PyDict > > ,
38
42
39
43
pub values : Vec < T > ,
40
44
}
@@ -46,20 +50,24 @@ impl<T: Debug> LiteralLookup<T> {
46
50
let mut expected_str: AHashMap < String , usize > = AHashMap :: new ( ) ;
47
51
let expected_py_dict = PyDict :: new_bound ( py) ;
48
52
let mut expected_py_values = Vec :: new ( ) ;
53
+ let expected_py_primitives = PyDict :: new_bound ( py) ;
49
54
let mut values = Vec :: new ( ) ;
50
55
for ( k, v) in expected {
51
56
let id = values. len ( ) ;
52
57
values. push ( v) ;
53
- if let Ok ( bool) = k. validate_bool ( true ) {
54
- if bool. into_inner ( ) {
58
+
59
+ if let Ok ( bool_value) = k. validate_bool ( true ) {
60
+ if bool_value. into_inner ( ) {
55
61
expected_bool. true_id = Some ( id) ;
56
62
} else {
57
63
expected_bool. false_id = Some ( id) ;
58
64
}
65
+ expected_py_primitives. set_item ( & k, id) ?;
59
66
}
60
67
if k. is_exact_instance_of :: < PyInt > ( ) {
61
68
if let Ok ( int_64) = k. extract :: < i64 > ( ) {
62
69
expected_int. insert ( int_64, id) ;
70
+ expected_py_primitives. set_item ( & k, id) ?;
63
71
} else {
64
72
// cover the case of an int that's > i64::MAX etc.
65
73
expected_py_dict. set_item ( k, id) ?;
@@ -69,32 +77,20 @@ impl<T: Debug> LiteralLookup<T> {
69
77
. as_cow ( )
70
78
. map_err ( |_| py_schema_error_type ! ( "error extracting str {:?}" , k) ) ?;
71
79
expected_str. insert ( str. to_string ( ) , id) ;
80
+ expected_py_primitives. set_item ( & k, id) ?;
72
81
} else if expected_py_dict. set_item ( & k, id) . is_err ( ) {
73
82
expected_py_values. push ( ( k. as_unbound ( ) . clone_ref ( py) , id) ) ;
74
83
}
75
84
}
76
85
77
86
Ok ( Self {
78
- expected_bool : match expected_bool. true_id . is_some ( ) || expected_bool. false_id . is_some ( ) {
79
- true => Some ( expected_bool) ,
80
- false => None ,
81
- } ,
82
- expected_int : match expected_int. is_empty ( ) {
83
- true => None ,
84
- false => Some ( expected_int) ,
85
- } ,
86
- expected_str : match expected_str. is_empty ( ) {
87
- true => None ,
88
- false => Some ( expected_str) ,
89
- } ,
90
- expected_py_dict : match expected_py_dict. is_empty ( ) {
91
- true => None ,
92
- false => Some ( expected_py_dict. into ( ) ) ,
93
- } ,
94
- expected_py_values : match expected_py_values. is_empty ( ) {
95
- true => None ,
96
- false => Some ( expected_py_values) ,
97
- } ,
87
+ expected_bool : ( expected_bool. true_id . is_some ( ) || expected_bool. false_id . is_some ( ) )
88
+ . then_some ( expected_bool) ,
89
+ expected_int : ( !expected_int. is_empty ( ) ) . then_some ( expected_int) ,
90
+ expected_str : ( !expected_str. is_empty ( ) ) . then_some ( expected_str) ,
91
+ expected_py_dict : ( !expected_py_dict. is_empty ( ) ) . then_some ( expected_py_dict. into ( ) ) ,
92
+ expected_py_values : ( !expected_py_values. is_empty ( ) ) . then_some ( expected_py_values) ,
93
+ expected_py_primitives : ( !expected_py_primitives. is_empty ( ) ) . then_some ( expected_py_primitives. into ( ) ) ,
98
94
values,
99
95
} )
100
96
}
@@ -162,6 +158,19 @@ impl<T: Debug> LiteralLookup<T> {
162
158
}
163
159
}
164
160
} ;
161
+
162
+ // this one must be last to avoid conflicts with the other lookups, think of this
163
+ // almost as a lax fallback
164
+ if let Some ( expected_py_primitives) = & self . expected_py_primitives {
165
+ let py_input = py_input. get_or_insert_with ( || input. to_object ( py) ) ;
166
+ // We don't use ? to unpack the result of `get_item` in the next line because unhashable
167
+ // inputs will produce a TypeError, which in this case we just want to treat equivalently
168
+ // to a failed lookup
169
+ if let Ok ( Some ( v) ) = expected_py_primitives. bind ( py) . get_item ( & * py_input) {
170
+ let id: usize = v. extract ( ) . unwrap ( ) ;
171
+ return Ok ( Some ( ( input, & self . values [ id] ) ) ) ;
172
+ }
173
+ } ;
165
174
Ok ( None )
166
175
}
167
176
0 commit comments