Skip to content

Commit bd4a947

Browse files
authored
[red-knot] Add symbol and definition for parameters (#12862)
## Summary This PR adds support for adding symbols and definitions for function and lambda parameters to the semantic index. ### Notes * The default expression of a parameter is evaluated in the enclosing scope (not the type parameter or function scope). * The annotation expression of a parameter is evaluated in the type parameter scope if they're present other in the enclosing scope. * The symbols and definitions are added in the function parameter scope. ### Type Inference There are two definitions `Parameter` and `ParameterWithDefault` and their respective `*_definition` methods on the type inference builder. These methods are preferred and are re-used when checking from a different region. ## Test Plan Add test case for validating that the parameters are defined in the function / lambda scope. ### Benchmark update Validated the difference in diagnostics for benchmark code between `main` and this branch. All of them are either directly or indirectly referencing one of the function parameters. The diff is in the PR description.
1 parent f121f8b commit bd4a947

File tree

5 files changed

+227
-5
lines changed

5 files changed

+227
-5
lines changed

crates/red_knot_python_semantic/src/semantic_index.rs

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,103 @@ y = 2
528528
));
529529
}
530530

531+
#[test]
532+
fn function_parameter_symbols() {
533+
let TestCase { db, file } = test_case(
534+
"
535+
def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
536+
pass
537+
",
538+
);
539+
540+
let index = semantic_index(&db, file);
541+
let global_table = symbol_table(&db, global_scope(&db, file));
542+
543+
assert_eq!(names(&global_table), vec!["f", "str", "int"]);
544+
545+
let [(function_scope_id, _function_scope)] = index
546+
.child_scopes(FileScopeId::global())
547+
.collect::<Vec<_>>()[..]
548+
else {
549+
panic!("Expected a function scope")
550+
};
551+
552+
let function_table = index.symbol_table(function_scope_id);
553+
assert_eq!(
554+
names(&function_table),
555+
vec!["a", "b", "c", "args", "d", "kwargs"],
556+
);
557+
558+
let use_def = index.use_def_map(function_scope_id);
559+
for name in ["a", "b", "c", "d"] {
560+
let [definition] = use_def.public_definitions(
561+
function_table
562+
.symbol_id_by_name(name)
563+
.expect("symbol exists"),
564+
) else {
565+
panic!("Expected parameter definition for {name}");
566+
};
567+
assert!(matches!(
568+
definition.node(&db),
569+
DefinitionKind::ParameterWithDefault(_)
570+
));
571+
}
572+
for name in ["args", "kwargs"] {
573+
let [definition] = use_def.public_definitions(
574+
function_table
575+
.symbol_id_by_name(name)
576+
.expect("symbol exists"),
577+
) else {
578+
panic!("Expected parameter definition for {name}");
579+
};
580+
assert!(matches!(definition.node(&db), DefinitionKind::Parameter(_)));
581+
}
582+
}
583+
584+
#[test]
585+
fn lambda_parameter_symbols() {
586+
let TestCase { db, file } = test_case("lambda a, b, c=1, *args, d=2, **kwargs: None");
587+
588+
let index = semantic_index(&db, file);
589+
let global_table = symbol_table(&db, global_scope(&db, file));
590+
591+
assert!(names(&global_table).is_empty());
592+
593+
let [(lambda_scope_id, _lambda_scope)] = index
594+
.child_scopes(FileScopeId::global())
595+
.collect::<Vec<_>>()[..]
596+
else {
597+
panic!("Expected a lambda scope")
598+
};
599+
600+
let lambda_table = index.symbol_table(lambda_scope_id);
601+
assert_eq!(
602+
names(&lambda_table),
603+
vec!["a", "b", "c", "args", "d", "kwargs"],
604+
);
605+
606+
let use_def = index.use_def_map(lambda_scope_id);
607+
for name in ["a", "b", "c", "d"] {
608+
let [definition] = use_def
609+
.public_definitions(lambda_table.symbol_id_by_name(name).expect("symbol exists"))
610+
else {
611+
panic!("Expected parameter definition for {name}");
612+
};
613+
assert!(matches!(
614+
definition.node(&db),
615+
DefinitionKind::ParameterWithDefault(_)
616+
));
617+
}
618+
for name in ["args", "kwargs"] {
619+
let [definition] = use_def
620+
.public_definitions(lambda_table.symbol_id_by_name(name).expect("symbol exists"))
621+
else {
622+
panic!("Expected parameter definition for {name}");
623+
};
624+
assert!(matches!(definition.node(&db), DefinitionKind::Parameter(_)));
625+
}
626+
}
627+
531628
/// Test case to validate that the comprehension scope is correctly identified and that the target
532629
/// variable is defined only in the comprehension scope and not in the global scope.
533630
#[test]

crates/red_knot_python_semantic/src/semantic_index/builder.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,16 @@ where
368368
.add_or_update_symbol(function_def.name.id.clone(), SymbolFlags::IS_DEFINED);
369369
self.add_definition(symbol, function_def);
370370

371+
// The default value of the parameters needs to be evaluated in the
372+
// enclosing scope.
373+
for default in function_def
374+
.parameters
375+
.iter_non_variadic_params()
376+
.filter_map(|param| param.default.as_deref())
377+
{
378+
self.visit_expr(default);
379+
}
380+
371381
self.with_type_params(
372382
NodeWithScopeRef::FunctionTypeParameters(function_def),
373383
function_def.type_params.as_deref(),
@@ -378,6 +388,16 @@ where
378388
}
379389

380390
builder.push_scope(NodeWithScopeRef::Function(function_def));
391+
392+
// Add symbols and definitions for the parameters to the function scope.
393+
for parameter in &*function_def.parameters {
394+
let symbol = builder.add_or_update_symbol(
395+
parameter.name().id().clone(),
396+
SymbolFlags::IS_DEFINED,
397+
);
398+
builder.add_definition(symbol, parameter);
399+
}
400+
381401
builder.visit_body(&function_def.body);
382402
builder.pop_scope()
383403
},
@@ -574,9 +594,29 @@ where
574594
}
575595
ast::Expr::Lambda(lambda) => {
576596
if let Some(parameters) = &lambda.parameters {
597+
// The default value of the parameters needs to be evaluated in the
598+
// enclosing scope.
599+
for default in parameters
600+
.iter_non_variadic_params()
601+
.filter_map(|param| param.default.as_deref())
602+
{
603+
self.visit_expr(default);
604+
}
577605
self.visit_parameters(parameters);
578606
}
579607
self.push_scope(NodeWithScopeRef::Lambda(lambda));
608+
609+
// Add symbols and definitions for the parameters to the lambda scope.
610+
if let Some(parameters) = &lambda.parameters {
611+
for parameter in &**parameters {
612+
let symbol = self.add_or_update_symbol(
613+
parameter.name().id().clone(),
614+
SymbolFlags::IS_DEFINED,
615+
);
616+
self.add_definition(symbol, parameter);
617+
}
618+
}
619+
580620
self.visit_expr(lambda.body.as_ref());
581621
}
582622
ast::Expr::If(ast::ExprIf {
@@ -654,6 +694,14 @@ where
654694
self.pop_scope();
655695
}
656696
}
697+
698+
fn visit_parameters(&mut self, parameters: &'ast ruff_python_ast::Parameters) {
699+
// Intentionally avoid walking default expressions, as we handle them in the enclosing
700+
// scope.
701+
for parameter in parameters.iter().map(ast::AnyParameterRef::as_parameter) {
702+
self.visit_parameter(parameter);
703+
}
704+
}
657705
}
658706

659707
#[derive(Copy, Clone, Debug)]

crates/red_knot_python_semantic/src/semantic_index/definition.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ pub(crate) enum DefinitionNodeRef<'a> {
4545
Assignment(AssignmentDefinitionNodeRef<'a>),
4646
AnnotatedAssignment(&'a ast::StmtAnnAssign),
4747
Comprehension(ComprehensionDefinitionNodeRef<'a>),
48+
Parameter(ast::AnyParameterRef<'a>),
4849
}
4950

5051
impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> {
@@ -95,6 +96,12 @@ impl<'a> From<ComprehensionDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
9596
}
9697
}
9798

99+
impl<'a> From<ast::AnyParameterRef<'a>> for DefinitionNodeRef<'a> {
100+
fn from(node: ast::AnyParameterRef<'a>) -> Self {
101+
Self::Parameter(node)
102+
}
103+
}
104+
98105
#[derive(Copy, Clone, Debug)]
99106
pub(crate) struct ImportFromDefinitionNodeRef<'a> {
100107
pub(crate) node: &'a ast::StmtImportFrom,
@@ -150,6 +157,14 @@ impl DefinitionNodeRef<'_> {
150157
first,
151158
})
152159
}
160+
DefinitionNodeRef::Parameter(parameter) => match parameter {
161+
ast::AnyParameterRef::Variadic(parameter) => {
162+
DefinitionKind::Parameter(AstNodeRef::new(parsed, parameter))
163+
}
164+
ast::AnyParameterRef::NonVariadic(parameter) => {
165+
DefinitionKind::ParameterWithDefault(AstNodeRef::new(parsed, parameter))
166+
}
167+
},
153168
}
154169
}
155170

@@ -168,6 +183,10 @@ impl DefinitionNodeRef<'_> {
168183
}) => target.into(),
169184
Self::AnnotatedAssignment(node) => node.into(),
170185
Self::Comprehension(ComprehensionDefinitionNodeRef { node, first: _ }) => node.into(),
186+
Self::Parameter(node) => match node {
187+
ast::AnyParameterRef::Variadic(parameter) => parameter.into(),
188+
ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(),
189+
},
171190
}
172191
}
173192
}
@@ -182,6 +201,8 @@ pub enum DefinitionKind {
182201
Assignment(AssignmentDefinitionKind),
183202
AnnotatedAssignment(AstNodeRef<ast::StmtAnnAssign>),
184203
Comprehension(ComprehensionDefinitionKind),
204+
Parameter(AstNodeRef<ast::Parameter>),
205+
ParameterWithDefault(AstNodeRef<ast::ParameterWithDefault>),
185206
}
186207

187208
#[derive(Clone, Debug)]
@@ -273,3 +294,15 @@ impl From<&ast::Comprehension> for DefinitionNodeKey {
273294
Self(NodeKey::from_node(node))
274295
}
275296
}
297+
298+
impl From<&ast::Parameter> for DefinitionNodeKey {
299+
fn from(node: &ast::Parameter) -> Self {
300+
Self(NodeKey::from_node(node))
301+
}
302+
}
303+
304+
impl From<&ast::ParameterWithDefault> for DefinitionNodeKey {
305+
fn from(node: &ast::ParameterWithDefault) -> Self {
306+
Self(NodeKey::from_node(node))
307+
}
308+
}

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,12 @@ impl<'db> TypeInferenceBuilder<'db> {
307307
definition,
308308
);
309309
}
310+
DefinitionKind::Parameter(parameter) => {
311+
self.infer_parameter_definition(parameter, definition);
312+
}
313+
DefinitionKind::ParameterWithDefault(parameter_with_default) => {
314+
self.infer_parameter_with_default_definition(parameter_with_default, definition);
315+
}
310316
}
311317
}
312318

@@ -421,6 +427,13 @@ impl<'db> TypeInferenceBuilder<'db> {
421427
.map(|decorator| self.infer_decorator(decorator))
422428
.collect();
423429

430+
for default in parameters
431+
.iter_non_variadic_params()
432+
.filter_map(|param| param.default.as_deref())
433+
{
434+
self.infer_expression(default);
435+
}
436+
424437
// If there are type params, parameters and returns are evaluated in that scope.
425438
if type_params.is_none() {
426439
self.infer_parameters(parameters);
@@ -458,10 +471,12 @@ impl<'db> TypeInferenceBuilder<'db> {
458471
let ast::ParameterWithDefault {
459472
range: _,
460473
parameter,
461-
default,
474+
default: _,
462475
} = parameter_with_default;
463-
self.infer_parameter(parameter);
464-
self.infer_optional_expression(default.as_deref());
476+
477+
self.infer_optional_expression(parameter.annotation.as_deref());
478+
479+
self.infer_definition(parameter_with_default);
465480
}
466481

467482
fn infer_parameter(&mut self, parameter: &ast::Parameter) {
@@ -470,7 +485,29 @@ impl<'db> TypeInferenceBuilder<'db> {
470485
name: _,
471486
annotation,
472487
} = parameter;
488+
473489
self.infer_optional_expression(annotation.as_deref());
490+
491+
self.infer_definition(parameter);
492+
}
493+
494+
fn infer_parameter_with_default_definition(
495+
&mut self,
496+
_parameter_with_default: &ast::ParameterWithDefault,
497+
definition: Definition<'db>,
498+
) {
499+
// TODO(dhruvmanila): Infer types from annotation or default expression
500+
self.types.definitions.insert(definition, Type::Unknown);
501+
}
502+
503+
fn infer_parameter_definition(
504+
&mut self,
505+
_parameter: &ast::Parameter,
506+
definition: Definition<'db>,
507+
) {
508+
// TODO(dhruvmanila): Annotation expression is resolved at the enclosing scope, infer the
509+
// parameter type from there
510+
self.types.definitions.insert(definition, Type::Unknown);
474511
}
475512

476513
fn infer_class_definition_statement(&mut self, class: &ast::StmtClassDef) {
@@ -1277,6 +1314,13 @@ impl<'db> TypeInferenceBuilder<'db> {
12771314
} = lambda_expression;
12781315

12791316
if let Some(parameters) = parameters {
1317+
for default in parameters
1318+
.iter_non_variadic_params()
1319+
.filter_map(|param| param.default.as_deref())
1320+
{
1321+
self.infer_expression(default);
1322+
}
1323+
12801324
self.infer_parameters(parameters);
12811325
}
12821326

crates/ruff_benchmark/benches/red_knot.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ fn benchmark_incremental(criterion: &mut Criterion) {
8989
let Case { db, parser, .. } = case;
9090
let result = db.check_file(*parser).unwrap();
9191

92-
assert_eq!(result.len(), 402);
92+
assert_eq!(result.len(), 111);
9393
},
9494
BatchSize::SmallInput,
9595
);
@@ -104,7 +104,7 @@ fn benchmark_cold(criterion: &mut Criterion) {
104104
let Case { db, parser, .. } = case;
105105
let result = db.check_file(*parser).unwrap();
106106

107-
assert_eq!(result.len(), 402);
107+
assert_eq!(result.len(), 111);
108108
},
109109
BatchSize::SmallInput,
110110
);

0 commit comments

Comments
 (0)