Skip to content

Commit 8144a11

Browse files
authored
[red-knot] Add definition for with items (#12920)
## Summary This PR adds symbols and definitions introduced by `with` statements. The symbols and definitions are introduced for each with item. The type inference is updated to call the definition region type inference instead. ## Test Plan Add test case to check for symbol table and definitions.
1 parent dce87c2 commit 8144a11

File tree

4 files changed

+151
-3
lines changed

4 files changed

+151
-3
lines changed

crates/red_knot_python_semantic/src/semantic_index.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,56 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
790790
assert_eq!(names(&inner_comprehension_symbol_table), vec!["x"]);
791791
}
792792

793+
#[test]
794+
fn with_item_definition() {
795+
let TestCase { db, file } = test_case(
796+
"
797+
with item1 as x, item2 as y:
798+
pass
799+
",
800+
);
801+
802+
let index = semantic_index(&db, file);
803+
let global_table = index.symbol_table(FileScopeId::global());
804+
805+
assert_eq!(names(&global_table), vec!["item1", "x", "item2", "y"]);
806+
807+
let use_def = index.use_def_map(FileScopeId::global());
808+
for name in ["x", "y"] {
809+
let Some(definition) = use_def.first_public_definition(
810+
global_table.symbol_id_by_name(name).expect("symbol exists"),
811+
) else {
812+
panic!("Expected with item definition for {name}");
813+
};
814+
assert!(matches!(definition.node(&db), DefinitionKind::WithItem(_)));
815+
}
816+
}
817+
818+
#[test]
819+
fn with_item_unpacked_definition() {
820+
let TestCase { db, file } = test_case(
821+
"
822+
with context() as (x, y):
823+
pass
824+
",
825+
);
826+
827+
let index = semantic_index(&db, file);
828+
let global_table = index.symbol_table(FileScopeId::global());
829+
830+
assert_eq!(names(&global_table), vec!["context", "x", "y"]);
831+
832+
let use_def = index.use_def_map(FileScopeId::global());
833+
for name in ["x", "y"] {
834+
let Some(definition) = use_def.first_public_definition(
835+
global_table.symbol_id_by_name(name).expect("symbol exists"),
836+
) else {
837+
panic!("Expected with item definition for {name}");
838+
};
839+
assert!(matches!(definition.node(&db), DefinitionKind::WithItem(_)));
840+
}
841+
}
842+
793843
#[test]
794844
fn dupes() {
795845
let TestCase { db, file } = test_case(

crates/red_knot_python_semantic/src/semantic_index/builder.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder};
2626
use crate::semantic_index::SemanticIndex;
2727
use crate::Db;
2828

29+
use super::definition::WithItemDefinitionNodeRef;
30+
2931
pub(super) struct SemanticIndexBuilder<'db> {
3032
// Builder state
3133
db: &'db dyn Db,
@@ -561,6 +563,18 @@ where
561563
self.flow_merge(break_state);
562564
}
563565
}
566+
ast::Stmt::With(ast::StmtWith { items, body, .. }) => {
567+
for item in items {
568+
self.visit_expr(&item.context_expr);
569+
if let Some(optional_vars) = item.optional_vars.as_deref() {
570+
self.add_standalone_expression(&item.context_expr);
571+
self.current_assignment = Some(item.into());
572+
self.visit_expr(optional_vars);
573+
self.current_assignment = None;
574+
}
575+
}
576+
self.visit_body(body);
577+
}
564578
ast::Stmt::Break(_) => {
565579
self.loop_break_states.push(self.flow_snapshot());
566580
}
@@ -622,6 +636,15 @@ where
622636
ComprehensionDefinitionNodeRef { node, first },
623637
);
624638
}
639+
Some(CurrentAssignment::WithItem(with_item)) => {
640+
self.add_definition(
641+
symbol,
642+
WithItemDefinitionNodeRef {
643+
node: with_item,
644+
target: name_node,
645+
},
646+
);
647+
}
625648
None => {}
626649
}
627650
}
@@ -778,6 +801,7 @@ enum CurrentAssignment<'a> {
778801
node: &'a ast::Comprehension,
779802
first: bool,
780803
},
804+
WithItem(&'a ast::WithItem),
781805
}
782806

783807
impl<'a> From<&'a ast::StmtAssign> for CurrentAssignment<'a> {
@@ -803,3 +827,9 @@ impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> {
803827
Self::Named(value)
804828
}
805829
}
830+
831+
impl<'a> From<&'a ast::WithItem> for CurrentAssignment<'a> {
832+
fn from(value: &'a ast::WithItem) -> Self {
833+
Self::WithItem(value)
834+
}
835+
}

crates/red_knot_python_semantic/src/semantic_index/definition.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ pub(crate) enum DefinitionNodeRef<'a> {
4747
AugmentedAssignment(&'a ast::StmtAugAssign),
4848
Comprehension(ComprehensionDefinitionNodeRef<'a>),
4949
Parameter(ast::AnyParameterRef<'a>),
50+
WithItem(WithItemDefinitionNodeRef<'a>),
5051
}
5152

5253
impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> {
@@ -97,6 +98,12 @@ impl<'a> From<AssignmentDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
9798
}
9899
}
99100

101+
impl<'a> From<WithItemDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
102+
fn from(node_ref: WithItemDefinitionNodeRef<'a>) -> Self {
103+
Self::WithItem(node_ref)
104+
}
105+
}
106+
100107
impl<'a> From<ComprehensionDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
101108
fn from(node: ComprehensionDefinitionNodeRef<'a>) -> Self {
102109
Self::Comprehension(node)
@@ -121,6 +128,12 @@ pub(crate) struct AssignmentDefinitionNodeRef<'a> {
121128
pub(crate) target: &'a ast::ExprName,
122129
}
123130

131+
#[derive(Copy, Clone, Debug)]
132+
pub(crate) struct WithItemDefinitionNodeRef<'a> {
133+
pub(crate) node: &'a ast::WithItem,
134+
pub(crate) target: &'a ast::ExprName,
135+
}
136+
124137
#[derive(Copy, Clone, Debug)]
125138
pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
126139
pub(crate) node: &'a ast::Comprehension,
@@ -175,6 +188,12 @@ impl DefinitionNodeRef<'_> {
175188
DefinitionKind::ParameterWithDefault(AstNodeRef::new(parsed, parameter))
176189
}
177190
},
191+
DefinitionNodeRef::WithItem(WithItemDefinitionNodeRef { node, target }) => {
192+
DefinitionKind::WithItem(WithItemDefinitionKind {
193+
node: AstNodeRef::new(parsed.clone(), node),
194+
target: AstNodeRef::new(parsed, target),
195+
})
196+
}
178197
}
179198
}
180199

@@ -198,6 +217,7 @@ impl DefinitionNodeRef<'_> {
198217
ast::AnyParameterRef::Variadic(parameter) => parameter.into(),
199218
ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(),
200219
},
220+
Self::WithItem(WithItemDefinitionNodeRef { node: _, target }) => target.into(),
201221
}
202222
}
203223
}
@@ -215,6 +235,7 @@ pub enum DefinitionKind {
215235
Comprehension(ComprehensionDefinitionKind),
216236
Parameter(AstNodeRef<ast::Parameter>),
217237
ParameterWithDefault(AstNodeRef<ast::ParameterWithDefault>),
238+
WithItem(WithItemDefinitionKind),
218239
}
219240

220241
#[derive(Clone, Debug)]
@@ -250,7 +271,6 @@ impl ImportFromDefinitionKind {
250271
}
251272

252273
#[derive(Clone, Debug)]
253-
#[allow(dead_code)]
254274
pub struct AssignmentDefinitionKind {
255275
assignment: AstNodeRef<ast::StmtAssign>,
256276
target: AstNodeRef<ast::ExprName>,
@@ -266,6 +286,22 @@ impl AssignmentDefinitionKind {
266286
}
267287
}
268288

289+
#[derive(Clone, Debug)]
290+
pub struct WithItemDefinitionKind {
291+
node: AstNodeRef<ast::WithItem>,
292+
target: AstNodeRef<ast::ExprName>,
293+
}
294+
295+
impl WithItemDefinitionKind {
296+
pub(crate) fn node(&self) -> &ast::WithItem {
297+
self.node.node()
298+
}
299+
300+
pub(crate) fn target(&self) -> &ast::ExprName {
301+
self.target.node()
302+
}
303+
}
304+
269305
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
270306
pub(crate) struct DefinitionNodeKey(NodeKey);
271307

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,9 @@ impl<'db> TypeInferenceBuilder<'db> {
333333
DefinitionKind::ParameterWithDefault(parameter_with_default) => {
334334
self.infer_parameter_with_default_definition(parameter_with_default, definition);
335335
}
336+
DefinitionKind::WithItem(with_item) => {
337+
self.infer_with_item_definition(with_item.target(), with_item.node(), definition);
338+
}
336339
}
337340
}
338341

@@ -618,13 +621,42 @@ impl<'db> TypeInferenceBuilder<'db> {
618621
} = with_statement;
619622

620623
for item in items {
621-
self.infer_expression(&item.context_expr);
622-
self.infer_optional_expression(item.optional_vars.as_deref());
624+
match item.optional_vars.as_deref() {
625+
Some(ast::Expr::Name(name)) => {
626+
self.infer_definition(name);
627+
}
628+
_ => {
629+
// TODO infer definitions in unpacking assignment
630+
self.infer_expression(&item.context_expr);
631+
}
632+
}
623633
}
624634

625635
self.infer_body(body);
626636
}
627637

638+
fn infer_with_item_definition(
639+
&mut self,
640+
target: &ast::ExprName,
641+
with_item: &ast::WithItem,
642+
definition: Definition<'db>,
643+
) {
644+
let expression = self.index.expression(&with_item.context_expr);
645+
let result = infer_expression_types(self.db, expression);
646+
self.extend(result);
647+
648+
// TODO(dhruvmanila): The correct type inference here is the return type of the __enter__
649+
// method of the context manager.
650+
let context_expr_ty = self
651+
.types
652+
.expression_ty(with_item.context_expr.scoped_ast_id(self.db, self.scope));
653+
654+
self.types
655+
.expressions
656+
.insert(target.scoped_ast_id(self.db, self.scope), context_expr_ty);
657+
self.types.definitions.insert(definition, context_expr_ty);
658+
}
659+
628660
fn infer_match_statement(&mut self, match_statement: &ast::StmtMatch) {
629661
let ast::StmtMatch {
630662
range: _,

0 commit comments

Comments
 (0)