Skip to content

Commit 907b6ed

Browse files
[red-knot] Type narrowing for assertions (#17149)
## Summary Fixes #17147 ## Test Plan Add new narrow/assert.md test file --------- Co-authored-by: Carl Meyer <[email protected]>
1 parent fd9882a commit 907b6ed

File tree

3 files changed

+65
-2
lines changed

3 files changed

+65
-2
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Narrowing with assert statements
2+
3+
## `assert` a value `is None` or `is not None`
4+
5+
```py
6+
def _(x: str | None, y: str | None):
7+
assert x is not None
8+
reveal_type(x) # revealed: str
9+
assert y is None
10+
reveal_type(y) # revealed: None
11+
```
12+
13+
## `assert` a value is truthy or falsy
14+
15+
```py
16+
def _(x: bool, y: bool):
17+
assert x
18+
reveal_type(x) # revealed: Literal[True]
19+
assert not y
20+
reveal_type(y) # revealed: Literal[False]
21+
```
22+
23+
## `assert` with `is` and `==` for literals
24+
25+
```py
26+
from typing import Literal
27+
28+
def _(x: Literal[1, 2, 3], y: Literal[1, 2, 3]):
29+
assert x is 2
30+
reveal_type(x) # revealed: Literal[2]
31+
assert y == 2
32+
reveal_type(y) # revealed: Literal[1, 2, 3]
33+
```
34+
35+
## `assert` with `isinstance`
36+
37+
```py
38+
def _(x: int | str):
39+
assert isinstance(x, int)
40+
reveal_type(x) # revealed: int
41+
```
42+
43+
## `assert` a value `in` a tuple
44+
45+
```py
46+
from typing import Literal
47+
48+
def _(x: Literal[1, 2, 3], y: Literal[1, 2, 3]):
49+
assert x in (1, 2)
50+
reveal_type(x) # revealed: Literal[1, 2]
51+
assert y not in (1, 2)
52+
reveal_type(y) # revealed: Literal[3]
53+
```

crates/red_knot_python_semantic/src/semantic_index/builder.rs

+11-1
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,6 @@ impl<'db> SemanticIndexBuilder<'db> {
534534
}
535535

536536
/// Records a visibility constraint by applying it to all live bindings and declarations.
537-
#[must_use = "A visibility constraint must always be negated after it is added"]
538537
fn record_visibility_constraint(
539538
&mut self,
540539
predicate: Predicate<'db>,
@@ -1292,6 +1291,17 @@ where
12921291
);
12931292
}
12941293
}
1294+
1295+
ast::Stmt::Assert(node) => {
1296+
self.visit_expr(&node.test);
1297+
let predicate = self.record_expression_narrowing_constraint(&node.test);
1298+
self.record_visibility_constraint(predicate);
1299+
1300+
if let Some(msg) = &node.msg {
1301+
self.visit_expr(msg);
1302+
}
1303+
}
1304+
12951305
ast::Stmt::Assign(node) => {
12961306
debug_assert_eq!(&self.current_assignments, &[]);
12971307

crates/red_knot_python_semantic/src/types/infer.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -3188,7 +3188,7 @@ impl<'db> TypeInferenceBuilder<'db> {
31883188
msg,
31893189
} = assert;
31903190

3191-
let test_ty = self.infer_expression(test);
3191+
let test_ty = self.infer_standalone_expression(test);
31923192

31933193
if let Err(err) = test_ty.try_bool(self.db()) {
31943194
err.report_diagnostic(&self.context, &**test);

0 commit comments

Comments
 (0)