Skip to content

Commit 57ba25c

Browse files
sharkdp and carljm authored
[red-knot] Type inference for comparisons involving intersection types (#14138)
## Summary

This adds type inference for comparison expressions involving intersection types. For example:

```py
x = get_random_int()

if x != 42:
    reveal_type(x == 42)  # revealed: Literal[False]
    reveal_type(x == 43)  # revealed: bool
```

closes #13854

## Test Plan

New Markdown-based tests.

---------

Co-authored-by: Carl Meyer <[email protected]>
1 parent 4f74db5 commit 57ba25c

File tree

2 files changed

+262
-4
lines changed

2 files changed

+262
-4
lines changed
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# Comparison: Intersections
2+
3+
## Positive contributions
4+
5+
If we have an intersection type `A & B` and we get a definitive true/false answer for one of the
6+
types, we can infer that the result for the intersection type is also true/false:
7+
8+
```py
9+
class Base: ...
10+
11+
class Child1(Base):
12+
def __eq__(self, other) -> Literal[True]:
13+
return True
14+
15+
class Child2(Base): ...
16+
17+
def get_base() -> Base: ...
18+
19+
x = get_base()
20+
c1 = Child1()
21+
22+
# Create an intersection type through narrowing:
23+
if isinstance(x, Child1):
24+
if isinstance(x, Child2):
25+
reveal_type(x) # revealed: Child1 & Child2
26+
27+
reveal_type(x == 1) # revealed: Literal[True]
28+
29+
# Other comparison operators fall back to the base type:
30+
reveal_type(x > 1) # revealed: bool
31+
reveal_type(x is c1) # revealed: bool
32+
```
33+
34+
## Negative contributions
35+
36+
Negative contributions to the intersection type only allow simplifications in a few special cases
37+
(equality and identity comparisons).
38+
39+
### Equality comparisons
40+
41+
#### Literal strings
42+
43+
```py
44+
x = "x" * 1_000_000_000
45+
y = "y" * 1_000_000_000
46+
reveal_type(x) # revealed: LiteralString
47+
48+
if x != "abc":
49+
reveal_type(x) # revealed: LiteralString & ~Literal["abc"]
50+
51+
reveal_type(x == "abc") # revealed: Literal[False]
52+
reveal_type("abc" == x) # revealed: Literal[False]
53+
reveal_type(x == "something else") # revealed: bool
54+
reveal_type("something else" == x) # revealed: bool
55+
56+
reveal_type(x != "abc") # revealed: Literal[True]
57+
reveal_type("abc" != x) # revealed: Literal[True]
58+
reveal_type(x != "something else") # revealed: bool
59+
reveal_type("something else" != x) # revealed: bool
60+
61+
reveal_type(x == y) # revealed: bool
62+
reveal_type(y == x) # revealed: bool
63+
reveal_type(x != y) # revealed: bool
64+
reveal_type(y != x) # revealed: bool
65+
66+
reveal_type(x >= "abc") # revealed: bool
67+
reveal_type("abc" >= x) # revealed: bool
68+
69+
reveal_type(x in "abc") # revealed: bool
70+
reveal_type("abc" in x) # revealed: bool
71+
```
72+
73+
#### Integers
74+
75+
```py
76+
def get_int() -> int: ...
77+
78+
x = get_int()
79+
80+
if x != 1:
81+
reveal_type(x) # revealed: int & ~Literal[1]
82+
83+
reveal_type(x != 1) # revealed: Literal[True]
84+
reveal_type(x != 2) # revealed: bool
85+
86+
reveal_type(x == 1) # revealed: Literal[False]
87+
reveal_type(x == 2) # revealed: bool
88+
```
89+
90+
### Identity comparisons
91+
92+
```py
93+
class A: ...
94+
95+
def get_object() -> object: ...
96+
97+
o = object()
98+
99+
a = A()
100+
n = None
101+
102+
if o is not None:
103+
reveal_type(o) # revealed: object & ~None
104+
105+
reveal_type(o is n) # revealed: Literal[False]
106+
reveal_type(o is not n) # revealed: Literal[True]
107+
```
108+
109+
## Diagnostics
110+
111+
### Unsupported operators for positive contributions
112+
113+
Raise an error if any of the positive contributions to the intersection type are unsupported for the
114+
given operator:
115+
116+
```py
117+
class Container:
118+
def __contains__(self, x) -> bool: ...
119+
120+
class NonContainer: ...
121+
122+
def get_object() -> object: ...
123+
124+
x = get_object()
125+
126+
if isinstance(x, Container):
127+
if isinstance(x, NonContainer):
128+
reveal_type(x) # revealed: Container & NonContainer
129+
130+
# error: [unsupported-operator] "Operator `in` is not supported for types `int` and `NonContainer`"
131+
reveal_type(2 in x) # revealed: bool
132+
```
133+
134+
### Unsupported operators for negative contributions
135+
136+
Do *not* raise an error if any of the negative contributions to the intersection type are
137+
unsupported for the given operator:
138+
139+
```py
140+
class Container:
141+
def __contains__(self, x) -> bool: ...
142+
143+
class NonContainer: ...
144+
145+
def get_object() -> object: ...
146+
147+
x = get_object()
148+
149+
if isinstance(x, Container):
150+
if not isinstance(x, NonContainer):
151+
reveal_type(x) # revealed: Container & ~NonContainer
152+
153+
# No error here!
154+
reveal_type(2 in x) # revealed: bool
155+
```

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 107 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ use crate::types::unpacker::{UnpackResult, Unpacker};
5757
use crate::types::{
5858
bindings_ty, builtins_symbol, declarations_ty, global_symbol, symbol, typing_extensions_symbol,
5959
Boundness, BytesLiteralType, Class, ClassLiteralType, FunctionType, InstanceType,
60-
IterationOutcome, KnownClass, KnownFunction, KnownInstance, MetaclassErrorKind,
61-
SliceLiteralType, StringLiteralType, Symbol, Truthiness, TupleType, Type, TypeArrayDisplay,
62-
UnionBuilder, UnionType,
60+
IntersectionBuilder, IntersectionType, IterationOutcome, KnownClass, KnownFunction,
61+
KnownInstance, MetaclassErrorKind, SliceLiteralType, StringLiteralType, Symbol, Truthiness,
62+
TupleType, Type, TypeArrayDisplay, UnionBuilder, UnionType,
6363
};
6464
use crate::unpack::Unpack;
6565
use crate::util::subscript::{PyIndex, PySlice};
@@ -266,6 +266,13 @@ impl<'db> TypeInference<'db> {
266266
}
267267
}
268268

269+
/// Which operand of a binary comparison holds the intersection type.
///
/// Selects the argument order used when each element of the intersection is compared
/// against the other operand.
270+
#[derive(Debug, Clone, Copy)]
271+
enum IntersectionOn {
272+
/// The intersection is the left operand: elements are compared as `element op other`.
Left,
273+
/// The intersection is the right operand: elements are compared as `other op element`.
Right,
274+
}
275+
269276
/// Builder to infer all types in a region.
270277
///
271278
/// A builder is used by creating it with [`new()`](TypeInferenceBuilder::new), and then calling
@@ -3086,7 +3093,7 @@ impl<'db> TypeInferenceBuilder<'db> {
30863093

30873094
// https://docs.python.org/3/reference/expressions.html#comparisons
30883095
// > Formally, if `a, b, c, …, y, z` are expressions and `op1, op2, …, opN` are comparison
3089-
// > operators, then `a op1 b op2 c ... y opN z` is equivalent to a `op1 b and b op2 c and
3096+
// > operators, then `a op1 b op2 c ... y opN z` is equivalent to `a op1 b and b op2 c and
30903097
// > ... y opN z`, except that each expression is evaluated at most once.
30913098
//
30923099
// As some operators (==, !=, <, <=, >, >=) *can* return an arbitrary type, the logic below
@@ -3140,6 +3147,87 @@ impl<'db> TypeInferenceBuilder<'db> {
31403147
)
31413148
}
31423149

3150+
/// Infers the result type of comparing an intersection type with `other` using `op`.
///
/// `intersection_on` states which side of the comparison the intersection is on, and
/// therefore the operand order passed to `infer_binary_type_comparison`.
///
/// Returns a `BooleanLiteral` as soon as a positive element yields one, or when a
/// negative element settles an `==` / `!=` / `is` / `is not` comparison. Otherwise the
/// result is the intersection of the comparison results of all positive elements.
///
/// # Errors
///
/// Propagates the `CompareUnsupportedError` of any positive element that is reached.
/// Errors produced for negative elements are ignored.
fn infer_binary_intersection_type_comparison(
3151+
&mut self,
3152+
intersection: IntersectionType<'db>,
3153+
op: ast::CmpOp,
3154+
other: Type<'db>,
3155+
intersection_on: IntersectionOn,
3156+
) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
3157+
// If a comparison yields a definitive true/false answer on a (positive) part
3158+
// of an intersection type, it will also yield a definitive answer on the full
3159+
// intersection type, which is even more specific.
3160+
for pos in intersection.positive(self.db) {
3161+
let result = match intersection_on {
3162+
IntersectionOn::Left => self.infer_binary_type_comparison(*pos, op, other)?,
3163+
IntersectionOn::Right => self.infer_binary_type_comparison(other, op, *pos)?,
3164+
};
3165+
if let Type::BooleanLiteral(b) = result {
3166+
return Ok(Type::BooleanLiteral(b));
3167+
}
3168+
}
3169+
3170+
// For negative contributions to the intersection type, there are only a few
3171+
// special cases that allow us to narrow down the result type of the comparison.
3172+
for neg in intersection.negative(self.db) {
3173+
// Errors are deliberately dropped via `.ok()`: an operator that is unsupported
// for a negative element must not make the whole comparison fail.
let result = match intersection_on {
3174+
IntersectionOn::Left => self.infer_binary_type_comparison(*neg, op, other).ok(),
3175+
IntersectionOn::Right => self.infer_binary_type_comparison(other, op, *neg).ok(),
3176+
};
3177+
3178+
// The answer for the excluded type is inverted: if the excluded type always
// compares equal/identical, the intersection never does (and vice versa).
match (op, result) {
3179+
(ast::CmpOp::Eq, Some(Type::BooleanLiteral(true))) => {
3180+
return Ok(Type::BooleanLiteral(false));
3181+
}
3182+
(ast::CmpOp::NotEq, Some(Type::BooleanLiteral(false))) => {
3183+
return Ok(Type::BooleanLiteral(true));
3184+
}
3185+
(ast::CmpOp::Is, Some(Type::BooleanLiteral(true))) => {
3186+
return Ok(Type::BooleanLiteral(false));
3187+
}
3188+
(ast::CmpOp::IsNot, Some(Type::BooleanLiteral(false))) => {
3189+
return Ok(Type::BooleanLiteral(true));
3190+
}
3191+
_ => {}
3192+
}
3193+
}
3194+
3195+
// If none of the simplifications above apply, we still need to return *some*
3196+
// result type for the comparison 'T_inter `op` T_other' (or reversed), where
3197+
//
3198+
// T_inter = P1 & P2 & ... & Pn & ~N1 & ~N2 & ... & ~Nm
3199+
//
3200+
// is the intersection type. If f(T) is the function that computes the result
3201+
// type of a `op`-comparison with `T_other`, we are interested in f(T_inter).
3202+
// Since we can't compute it exactly, we return the following approximation:
3203+
//
3204+
// f(T_inter) = f(P1) & f(P2) & ... & f(Pn)
3205+
//
3206+
// The reason for this is the following: In general, for any function 'f', the
3207+
// set f(A) & f(B) can be *larger than* the set f(A & B). This means that we
3208+
// will return a type that is too wide, which is not necessarily problematic.
3209+
//
3210+
// However, we do have to leave out the negative contributions. If we were to
3211+
// add a contribution like ~f(N1), we would potentially infer result types
3212+
// that are too narrow, since ~f(A) can exclude values that f(~A) contains.
3213+
//
3214+
// As an example for this, consider the intersection type `int & ~Literal[1]`.
3215+
// If 'f' were the `==`-comparison with 2, we obviously can't tell if that
3216+
// answer would be true or false, so we need to return `bool`. However, if we
3217+
// compute f(int) & ~f(Literal[1]), we get `bool & ~Literal[False]`, which can
3218+
// be simplified to `Literal[True]` -- a type that is too narrow.
3219+
// NOTE(review): this loop appears to repeat the same positive-element comparisons
// as the first loop above; caching those results would avoid the repeated work.
let mut builder = IntersectionBuilder::new(self.db);
3220+
for pos in intersection.positive(self.db) {
3221+
let result = match intersection_on {
3222+
IntersectionOn::Left => self.infer_binary_type_comparison(*pos, op, other)?,
3223+
IntersectionOn::Right => self.infer_binary_type_comparison(other, op, *pos)?,
3224+
};
3225+
builder = builder.add_positive(result);
3226+
}
3227+
3228+
Ok(builder.build())
3229+
}
3230+
31433231
/// Infers the type of a binary comparison (e.g. 'left == right'). See
31443232
/// `infer_compare_expression` for the higher level logic dealing with multi-comparison
31453233
/// expressions.
@@ -3172,6 +3260,21 @@ impl<'db> TypeInferenceBuilder<'db> {
31723260
Ok(builder.build())
31733261
}
31743262

3263+
(Type::Intersection(intersection), right) => self
3264+
.infer_binary_intersection_type_comparison(
3265+
intersection,
3266+
op,
3267+
right,
3268+
IntersectionOn::Left,
3269+
),
3270+
(left, Type::Intersection(intersection)) => self
3271+
.infer_binary_intersection_type_comparison(
3272+
intersection,
3273+
op,
3274+
left,
3275+
IntersectionOn::Right,
3276+
),
3277+
31753278
(Type::IntLiteral(n), Type::IntLiteral(m)) => match op {
31763279
ast::CmpOp::Eq => Ok(Type::BooleanLiteral(n == m)),
31773280
ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(n != m)),

0 commit comments

Comments
 (0)