Skip to content

Commit e50fc04

Browse files
[red-knot] visibility_constraint analysis for match cases (#17077)
## Summary Add visibility constraint analysis for pattern predicate kinds `Singleton`, `Or`, and `Class`. ## Test Plan update conditional/match.md
1 parent 66355a6 commit e50fc04

File tree

2 files changed

+327
-27
lines changed

2 files changed

+327
-27
lines changed

crates/red_knot_python_semantic/resources/mdtest/conditional/match.md

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,240 @@ def _(target: int):
4444
reveal_type(y) # revealed: Literal[2, 3, 4]
4545
```
4646

47+
## Value match
48+
49+
A value pattern matches based on equality: the first `case` branch here will be taken if `subject`
50+
is equal to `2`, even if `subject` is not an instance of `int`. We can't know whether `C` here has a
51+
custom `__eq__` implementation that might cause it to compare equal to `2`, so we have to consider
52+
the possibility that the `case` branch might be taken even though the type `C` is disjoint from the
53+
type `Literal[2]`.
54+
55+
This leads us to infer `Literal[1, 3]` as the type of `y` after the `match` statement, rather than
56+
`Literal[1]`:
57+
58+
```py
59+
from typing import final
60+
61+
@final
62+
class C:
63+
pass
64+
65+
def _(subject: C):
66+
y = 1
67+
match subject:
68+
case 2:
69+
y = 3
70+
reveal_type(y) # revealed: Literal[1, 3]
71+
```
72+
73+
## Class match
74+
75+
A `case` branch with a class pattern is taken if the subject is an instance of the given class, and
76+
all subpatterns in the class pattern match.
77+
78+
```py
79+
from typing import final
80+
81+
class Foo:
82+
pass
83+
84+
class FooSub(Foo):
85+
pass
86+
87+
class Bar:
88+
pass
89+
90+
@final
91+
class Baz:
92+
pass
93+
94+
def _(target: FooSub):
95+
y = 1
96+
97+
match target:
98+
case Baz():
99+
y = 2
100+
case Foo():
101+
y = 3
102+
case Bar():
103+
y = 4
104+
105+
reveal_type(y) # revealed: Literal[3]
106+
107+
def _(target: FooSub):
108+
y = 1
109+
110+
match target:
111+
case Baz():
112+
y = 2
113+
case Bar():
114+
y = 3
115+
case Foo():
116+
y = 4
117+
118+
reveal_type(y) # revealed: Literal[3, 4]
119+
120+
def _(target: FooSub | str):
121+
y = 1
122+
123+
match target:
124+
case Baz():
125+
y = 2
126+
case Foo():
127+
y = 3
128+
case Bar():
129+
y = 4
130+
131+
reveal_type(y) # revealed: Literal[1, 3, 4]
132+
```
133+
134+
## Singleton match
135+
136+
Singleton patterns are matched based on identity, not equality comparisons or `isinstance()` checks.
137+
138+
```py
139+
from typing import Literal
140+
141+
def _(target: Literal[True, False]):
142+
y = 1
143+
144+
match target:
145+
case True:
146+
y = 2
147+
case False:
148+
y = 3
149+
case None:
150+
y = 4
151+
152+
# TODO: with exhaustiveness checking, this should be Literal[2, 3]
153+
reveal_type(y) # revealed: Literal[1, 2, 3]
154+
155+
def _(target: bool):
156+
y = 1
157+
158+
match target:
159+
case True:
160+
y = 2
161+
case False:
162+
y = 3
163+
case None:
164+
y = 4
165+
166+
# TODO: with exhaustiveness checking, this should be Literal[2, 3]
167+
reveal_type(y) # revealed: Literal[1, 2, 3]
168+
169+
def _(target: None):
170+
y = 1
171+
172+
match target:
173+
case True:
174+
y = 2
175+
case False:
176+
y = 3
177+
case None:
178+
y = 4
179+
180+
reveal_type(y) # revealed: Literal[4]
181+
182+
def _(target: None | Literal[True]):
183+
y = 1
184+
185+
match target:
186+
case True:
187+
y = 2
188+
case False:
189+
y = 3
190+
case None:
191+
y = 4
192+
193+
# TODO: with exhaustiveness checking, this should be Literal[2, 4]
194+
reveal_type(y) # revealed: Literal[1, 2, 4]
195+
196+
# bool is an int subclass
197+
def _(target: int):
198+
y = 1
199+
200+
match target:
201+
case True:
202+
y = 2
203+
case False:
204+
y = 3
205+
case None:
206+
y = 4
207+
208+
reveal_type(y) # revealed: Literal[1, 2, 3]
209+
210+
def _(target: str):
211+
y = 1
212+
213+
match target:
214+
case True:
215+
y = 2
216+
case False:
217+
y = 3
218+
case None:
219+
y = 4
220+
221+
reveal_type(y) # revealed: Literal[1]
222+
```
223+
224+
## Or match
225+
226+
A `|` pattern matches if any of the subpatterns match.
227+
228+
```py
229+
from typing import Literal, final
230+
231+
def _(target: Literal["foo", "baz"]):
232+
y = 1
233+
234+
match target:
235+
case "foo" | "bar":
236+
y = 2
237+
case "baz":
238+
y = 3
239+
240+
# TODO: with exhaustiveness, this should be Literal[2, 3]
241+
reveal_type(y) # revealed: Literal[1, 2, 3]
242+
243+
def _(target: None):
244+
y = 1
245+
246+
match target:
247+
case None | 3:
248+
y = 2
249+
case "foo" | 4 | True:
250+
y = 3
251+
252+
reveal_type(y) # revealed: Literal[2]
253+
254+
@final
255+
class Baz:
256+
pass
257+
258+
def _(target: int | None | float):
259+
y = 1
260+
261+
match target:
262+
case None | 3:
263+
y = 2
264+
case Baz():
265+
y = 3
266+
267+
reveal_type(y) # revealed: Literal[1, 2]
268+
269+
def _(target: None | str):
270+
y = 1
271+
272+
match target:
273+
case Baz() | True | False:
274+
y = 2
275+
case int():
276+
y = 3
277+
278+
reveal_type(y) # revealed: Literal[1, 3]
279+
```
280+
47281
## Guard with object that implements `__bool__` incorrectly
48282

49283
```py

crates/red_knot_python_semantic/src/semantic_index/visibility_constraints.rs

Lines changed: 93 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,11 @@ use std::cmp::Ordering;
178178
use ruff_index::{Idx, IndexVec};
179179
use rustc_hash::FxHashMap;
180180

181+
use crate::semantic_index::expression::Expression;
181182
use crate::semantic_index::predicate::{
182-
PatternPredicateKind, Predicate, PredicateNode, Predicates, ScopedPredicateId,
183+
PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, Predicates, ScopedPredicateId,
183184
};
184-
use crate::types::{infer_expression_type, Truthiness};
185+
use crate::types::{infer_expression_type, Truthiness, Type};
185186
use crate::Db;
186187

187188
/// A ternary formula that defines under what conditions a binding is visible. (A ternary formula
@@ -553,37 +554,102 @@ impl VisibilityConstraints {
553554
}
554555
}
555556

557+
fn analyze_single_pattern_predicate_kind<'db>(
558+
db: &'db dyn Db,
559+
predicate_kind: &PatternPredicateKind<'db>,
560+
subject: Expression<'db>,
561+
) -> Truthiness {
562+
match predicate_kind {
563+
PatternPredicateKind::Value(value) => {
564+
let subject_ty = infer_expression_type(db, subject);
565+
let value_ty = infer_expression_type(db, *value);
566+
567+
if subject_ty.is_single_valued(db) {
568+
Truthiness::from(subject_ty.is_equivalent_to(db, value_ty))
569+
} else {
570+
Truthiness::Ambiguous
571+
}
572+
}
573+
PatternPredicateKind::Singleton(singleton) => {
574+
let subject_ty = infer_expression_type(db, subject);
575+
576+
let singleton_ty = match singleton {
577+
ruff_python_ast::Singleton::None => Type::none(db),
578+
ruff_python_ast::Singleton::True => Type::BooleanLiteral(true),
579+
ruff_python_ast::Singleton::False => Type::BooleanLiteral(false),
580+
};
581+
582+
debug_assert!(singleton_ty.is_singleton(db));
583+
584+
if subject_ty.is_equivalent_to(db, singleton_ty) {
585+
Truthiness::AlwaysTrue
586+
} else if subject_ty.is_disjoint_from(db, singleton_ty) {
587+
Truthiness::AlwaysFalse
588+
} else {
589+
Truthiness::Ambiguous
590+
}
591+
}
592+
PatternPredicateKind::Or(predicates) => {
593+
use std::ops::ControlFlow;
594+
let (ControlFlow::Break(truthiness) | ControlFlow::Continue(truthiness)) =
595+
predicates
596+
.iter()
597+
.map(|p| Self::analyze_single_pattern_predicate_kind(db, p, subject))
598+
// this is just a "max", but with a slight optimization: `AlwaysTrue` is the "greatest" possible element, so we short-circuit if we get there
599+
.try_fold(Truthiness::AlwaysFalse, |acc, next| match (acc, next) {
600+
(Truthiness::AlwaysTrue, _) | (_, Truthiness::AlwaysTrue) => {
601+
ControlFlow::Break(Truthiness::AlwaysTrue)
602+
}
603+
(Truthiness::Ambiguous, _) | (_, Truthiness::Ambiguous) => {
604+
ControlFlow::Continue(Truthiness::Ambiguous)
605+
}
606+
(Truthiness::AlwaysFalse, Truthiness::AlwaysFalse) => {
607+
ControlFlow::Continue(Truthiness::AlwaysFalse)
608+
}
609+
});
610+
truthiness
611+
}
612+
PatternPredicateKind::Class(class_expr) => {
613+
let subject_ty = infer_expression_type(db, subject);
614+
let class_ty = infer_expression_type(db, *class_expr).to_instance(db);
615+
616+
class_ty.map_or(Truthiness::Ambiguous, |class_ty| {
617+
if subject_ty.is_subtype_of(db, class_ty) {
618+
Truthiness::AlwaysTrue
619+
} else if subject_ty.is_disjoint_from(db, class_ty) {
620+
Truthiness::AlwaysFalse
621+
} else {
622+
Truthiness::Ambiguous
623+
}
624+
})
625+
}
626+
PatternPredicateKind::Unsupported => Truthiness::Ambiguous,
627+
}
628+
}
629+
630+
fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness {
631+
let truthiness = Self::analyze_single_pattern_predicate_kind(
632+
db,
633+
predicate.kind(db),
634+
predicate.subject(db),
635+
);
636+
637+
if truthiness == Truthiness::AlwaysTrue && predicate.guard(db).is_some() {
638+
// Fall back to ambiguous, the guard might change the result.
639+
// TODO: actually analyze guard truthiness
640+
Truthiness::Ambiguous
641+
} else {
642+
truthiness
643+
}
644+
}
645+
556646
fn analyze_single(db: &dyn Db, predicate: &Predicate) -> Truthiness {
557647
match predicate.node {
558648
PredicateNode::Expression(test_expr) => {
559649
let ty = infer_expression_type(db, test_expr);
560650
ty.bool(db).negate_if(!predicate.is_positive)
561651
}
562-
PredicateNode::Pattern(inner) => match inner.kind(db) {
563-
PatternPredicateKind::Value(value) => {
564-
let subject_expression = inner.subject(db);
565-
let subject_ty = infer_expression_type(db, subject_expression);
566-
let value_ty = infer_expression_type(db, *value);
567-
568-
if subject_ty.is_single_valued(db) {
569-
let truthiness =
570-
Truthiness::from(subject_ty.is_equivalent_to(db, value_ty));
571-
572-
if truthiness.is_always_true() && inner.guard(db).is_some() {
573-
// Fall back to ambiguous, the guard might change the result.
574-
Truthiness::Ambiguous
575-
} else {
576-
truthiness
577-
}
578-
} else {
579-
Truthiness::Ambiguous
580-
}
581-
}
582-
PatternPredicateKind::Singleton(..)
583-
| PatternPredicateKind::Class(..)
584-
| PatternPredicateKind::Or(..)
585-
| PatternPredicateKind::Unsupported => Truthiness::Ambiguous,
586-
},
652+
PredicateNode::Pattern(inner) => Self::analyze_single_pattern_predicate(db, inner),
587653
}
588654
}
589655
}

0 commit comments

Comments
 (0)