@@ -57,9 +57,9 @@ use crate::types::unpacker::{UnpackResult, Unpacker};
57
57
use crate :: types:: {
58
58
bindings_ty, builtins_symbol, declarations_ty, global_symbol, symbol, typing_extensions_symbol,
59
59
Boundness , BytesLiteralType , Class , ClassLiteralType , FunctionType , InstanceType ,
60
- IterationOutcome , KnownClass , KnownFunction , KnownInstance , MetaclassErrorKind ,
61
- SliceLiteralType , StringLiteralType , Symbol , Truthiness , TupleType , Type , TypeArrayDisplay ,
62
- UnionBuilder , UnionType ,
60
+ IntersectionBuilder , IntersectionType , IterationOutcome , KnownClass , KnownFunction ,
61
+ KnownInstance , MetaclassErrorKind , SliceLiteralType , StringLiteralType , Symbol , Truthiness ,
62
+ TupleType , Type , TypeArrayDisplay , UnionBuilder , UnionType ,
63
63
} ;
64
64
use crate :: unpack:: Unpack ;
65
65
use crate :: util:: subscript:: { PyIndex , PySlice } ;
@@ -266,6 +266,13 @@ impl<'db> TypeInference<'db> {
266
266
}
267
267
}
268
268
269
+ /// Whether the intersection type is on the left or right side of the comparison.
270
+ #[ derive( Debug , Clone , Copy ) ]
271
+ enum IntersectionOn {
272
+ Left ,
273
+ Right ,
274
+ }
275
+
269
276
/// Builder to infer all types in a region.
270
277
///
271
278
/// A builder is used by creating it with [`new()`](TypeInferenceBuilder::new), and then calling
@@ -3086,7 +3093,7 @@ impl<'db> TypeInferenceBuilder<'db> {
3086
3093
3087
3094
// https://docs.python.org/3/reference/expressions.html#comparisons
3088
3095
// > Formally, if `a, b, c, …, y, z` are expressions and `op1, op2, …, opN` are comparison
3089
- // > operators, then `a op1 b op2 c ... y opN z` is equivalent to a ` op1 b and b op2 c and
3096
+ // > operators, then `a op1 b op2 c ... y opN z` is equivalent to `a op1 b and b op2 c and
3090
3097
// ... > y opN z`, except that each expression is evaluated at most once.
3091
3098
//
3092
3099
// As some operators (==, !=, <, <=, >, >=) *can* return an arbitrary type, the logic below
@@ -3140,6 +3147,87 @@ impl<'db> TypeInferenceBuilder<'db> {
3140
3147
)
3141
3148
}
3142
3149
3150
+ fn infer_binary_intersection_type_comparison (
3151
+ & mut self ,
3152
+ intersection : IntersectionType < ' db > ,
3153
+ op : ast:: CmpOp ,
3154
+ other : Type < ' db > ,
3155
+ intersection_on : IntersectionOn ,
3156
+ ) -> Result < Type < ' db > , CompareUnsupportedError < ' db > > {
3157
+ // If a comparison yields a definitive true/false answer on a (positive) part
3158
+ // of an intersection type, it will also yield a definitive answer on the full
3159
+ // intersection type, which is even more specific.
3160
+ for pos in intersection. positive ( self . db ) {
3161
+ let result = match intersection_on {
3162
+ IntersectionOn :: Left => self . infer_binary_type_comparison ( * pos, op, other) ?,
3163
+ IntersectionOn :: Right => self . infer_binary_type_comparison ( other, op, * pos) ?,
3164
+ } ;
3165
+ if let Type :: BooleanLiteral ( b) = result {
3166
+ return Ok ( Type :: BooleanLiteral ( b) ) ;
3167
+ }
3168
+ }
3169
+
3170
+ // For negative contributions to the intersection type, there are only a few
3171
+ // special cases that allow us to narrow down the result type of the comparison.
3172
+ for neg in intersection. negative ( self . db ) {
3173
+ let result = match intersection_on {
3174
+ IntersectionOn :: Left => self . infer_binary_type_comparison ( * neg, op, other) . ok ( ) ,
3175
+ IntersectionOn :: Right => self . infer_binary_type_comparison ( other, op, * neg) . ok ( ) ,
3176
+ } ;
3177
+
3178
+ match ( op, result) {
3179
+ ( ast:: CmpOp :: Eq , Some ( Type :: BooleanLiteral ( true ) ) ) => {
3180
+ return Ok ( Type :: BooleanLiteral ( false ) ) ;
3181
+ }
3182
+ ( ast:: CmpOp :: NotEq , Some ( Type :: BooleanLiteral ( false ) ) ) => {
3183
+ return Ok ( Type :: BooleanLiteral ( true ) ) ;
3184
+ }
3185
+ ( ast:: CmpOp :: Is , Some ( Type :: BooleanLiteral ( true ) ) ) => {
3186
+ return Ok ( Type :: BooleanLiteral ( false ) ) ;
3187
+ }
3188
+ ( ast:: CmpOp :: IsNot , Some ( Type :: BooleanLiteral ( false ) ) ) => {
3189
+ return Ok ( Type :: BooleanLiteral ( true ) ) ;
3190
+ }
3191
+ _ => { }
3192
+ }
3193
+ }
3194
+
3195
+ // If none of the simplifications above apply, we still need to return *some*
3196
+ // result type for the comparison 'T_inter `op` T_other' (or reversed), where
3197
+ //
3198
+ // T_inter = P1 & P2 & ... & Pn & ~N1 & ~N2 & ... & ~Nm
3199
+ //
3200
+ // is the intersection type. If f(T) is the function that computes the result
3201
+ // type of a `op`-comparison with `T_other`, we are interested in f(T_inter).
3202
+ // Since we can't compute it exactly, we return the following approximation:
3203
+ //
3204
+ // f(T_inter) = f(P1) & f(P2) & ... & f(Pn)
3205
+ //
3206
+ // The reason for this is the following: In general, for any function 'f', the
3207
+ // set f(A) & f(B) can be *larger than* the set f(A & B). This means that we
3208
+ // will return a type that is too wide, which is not necessarily problematic.
3209
+ //
3210
+ // However, we do have to leave out the negative contributions. If we were to
3211
+ // add a contribution like ~f(N1), we would potentially infer result types
3212
+ // that are too narrow, since ~f(A) can be larger than f(~A).
3213
+ //
3214
+ // As an example for this, consider the intersection type `int & ~Literal[1]`.
3215
+ // If 'f' would be the `==`-comparison with 2, we obviously can't tell if that
3216
+ // answer would be true or false, so we need to return `bool`. However, if we
3217
+ // compute f(int) & ~f(Literal[1]), we get `bool & ~Literal[False]`, which can
3218
+ // be simplified to `Literal[True]` -- a type that is too narrow.
3219
+ let mut builder = IntersectionBuilder :: new ( self . db ) ;
3220
+ for pos in intersection. positive ( self . db ) {
3221
+ let result = match intersection_on {
3222
+ IntersectionOn :: Left => self . infer_binary_type_comparison ( * pos, op, other) ?,
3223
+ IntersectionOn :: Right => self . infer_binary_type_comparison ( other, op, * pos) ?,
3224
+ } ;
3225
+ builder = builder. add_positive ( result) ;
3226
+ }
3227
+
3228
+ Ok ( builder. build ( ) )
3229
+ }
3230
+
3143
3231
/// Infers the type of a binary comparison (e.g. 'left == right'). See
3144
3232
/// `infer_compare_expression` for the higher level logic dealing with multi-comparison
3145
3233
/// expressions.
@@ -3172,6 +3260,21 @@ impl<'db> TypeInferenceBuilder<'db> {
3172
3260
Ok ( builder. build ( ) )
3173
3261
}
3174
3262
3263
+ ( Type :: Intersection ( intersection) , right) => self
3264
+ . infer_binary_intersection_type_comparison (
3265
+ intersection,
3266
+ op,
3267
+ right,
3268
+ IntersectionOn :: Left ,
3269
+ ) ,
3270
+ ( left, Type :: Intersection ( intersection) ) => self
3271
+ . infer_binary_intersection_type_comparison (
3272
+ intersection,
3273
+ op,
3274
+ left,
3275
+ IntersectionOn :: Right ,
3276
+ ) ,
3277
+
3175
3278
( Type :: IntLiteral ( n) , Type :: IntLiteral ( m) ) => match op {
3176
3279
ast:: CmpOp :: Eq => Ok ( Type :: BooleanLiteral ( n == m) ) ,
3177
3280
ast:: CmpOp :: NotEq => Ok ( Type :: BooleanLiteral ( n != m) ) ,
0 commit comments