Skip to content

Commit a48d779

Browse files
authored
[red-knot] function signature representation (#14304)
## Summary Add a typed representation of function signatures (parameters and return type) and infer it correctly from a function. Convert existing usage of function return types to use the signature representation. This does not yet add inferred types for parameters within function body scopes based on the annotations, but it should be easy to add as a next step. Part of #14161 and #13693. ## Test Plan Added tests.
1 parent ba6c7f6 commit a48d779

File tree

8 files changed

+559
-67
lines changed

8 files changed

+559
-67
lines changed

crates/red_knot_python_semantic/resources/mdtest/call/function.md

+9
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@ async def get_int_async() -> int:
1919
reveal_type(get_int_async()) # revealed: @Todo
2020
```
2121

22+
## Generic
23+
24+
```py
25+
def get_int[T]() -> int:
26+
return 42
27+
28+
reveal_type(get_int()) # revealed: int
29+
```
30+
2231
## Decorated
2332

2433
```py

crates/red_knot_python_semantic/resources/mdtest/exception/basic.md

+2-3
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,10 @@ except EXCEPTIONS as f:
4141
## Dynamic exception types
4242

4343
```py
44-
# TODO: we should not emit these `call-possibly-unbound-method` errors for `tuple.__class_getitem__`
4544
def foo(
4645
x: type[AttributeError],
47-
y: tuple[type[OSError], type[RuntimeError]], # error: [call-possibly-unbound-method]
48-
z: tuple[type[BaseException], ...], # error: [call-possibly-unbound-method]
46+
y: tuple[type[OSError], type[RuntimeError]],
47+
z: tuple[type[BaseException], ...],
4948
):
5049
try:
5150
help()

crates/red_knot_python_semantic/resources/mdtest/generics.md

+5-5
Original file line numberDiff line numberDiff line change
@@ -65,31 +65,31 @@ A PEP695 type variable defines a value of type `typing.TypeVar` with attributes
6565

6666
```py
6767
def f[T, U: A, V: (A, B), W = A, X: A = A1]():
68-
reveal_type(T) # revealed: TypeVar
68+
reveal_type(T) # revealed: T
6969
reveal_type(T.__name__) # revealed: Literal["T"]
7070
reveal_type(T.__bound__) # revealed: None
7171
reveal_type(T.__constraints__) # revealed: tuple[()]
7272
reveal_type(T.__default__) # revealed: NoDefault
7373

74-
reveal_type(U) # revealed: TypeVar
74+
reveal_type(U) # revealed: U
7575
reveal_type(U.__name__) # revealed: Literal["U"]
7676
reveal_type(U.__bound__) # revealed: type[A]
7777
reveal_type(U.__constraints__) # revealed: tuple[()]
7878
reveal_type(U.__default__) # revealed: NoDefault
7979

80-
reveal_type(V) # revealed: TypeVar
80+
reveal_type(V) # revealed: V
8181
reveal_type(V.__name__) # revealed: Literal["V"]
8282
reveal_type(V.__bound__) # revealed: None
8383
reveal_type(V.__constraints__) # revealed: tuple[type[A], type[B]]
8484
reveal_type(V.__default__) # revealed: NoDefault
8585

86-
reveal_type(W) # revealed: TypeVar
86+
reveal_type(W) # revealed: W
8787
reveal_type(W.__name__) # revealed: Literal["W"]
8888
reveal_type(W.__bound__) # revealed: None
8989
reveal_type(W.__constraints__) # revealed: tuple[()]
9090
reveal_type(W.__default__) # revealed: type[A]
9191

92-
reveal_type(X) # revealed: TypeVar
92+
reveal_type(X) # revealed: X
9393
reveal_type(X.__name__) # revealed: Literal["X"]
9494
reveal_type(X.__bound__) # revealed: type[A]
9595
reveal_type(X.__constraints__) # revealed: tuple[()]

crates/red_knot_python_semantic/src/types.rs

+50-29
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ pub(crate) use self::display::TypeArrayDisplay;
1212
pub(crate) use self::infer::{
1313
infer_deferred_types, infer_definition_types, infer_expression_types, infer_scope_types,
1414
};
15+
pub(crate) use self::signatures::Signature;
1516
use crate::module_resolver::file_to_module;
1617
use crate::semantic_index::ast_ids::HasScopedAstId;
1718
use crate::semantic_index::definition::Definition;
@@ -35,6 +36,7 @@ mod display;
3536
mod infer;
3637
mod mro;
3738
mod narrow;
39+
mod signatures;
3840
mod unpacker;
3941

4042
#[salsa::tracked(return_ref)]
@@ -1271,11 +1273,11 @@ impl<'db> Type<'db> {
12711273
Type::FunctionLiteral(function_type) => {
12721274
if function_type.is_known(db, KnownFunction::RevealType) {
12731275
CallOutcome::revealed(
1274-
function_type.return_ty(db),
1276+
function_type.signature(db).return_ty,
12751277
*arg_types.first().unwrap_or(&Type::Unknown),
12761278
)
12771279
} else {
1278-
CallOutcome::callable(function_type.return_ty(db))
1280+
CallOutcome::callable(function_type.signature(db).return_ty)
12791281
}
12801282
}
12811283

@@ -1461,6 +1463,24 @@ impl<'db> Type<'db> {
14611463
}
14621464
}
14631465

1466+
/// If we see a value of this type used as a type expression, what type does it name?
1467+
///
1468+
/// For example, the builtin `int` as a value expression is of type
1469+
/// `Type::ClassLiteral(builtins.int)`, that is, it is the `int` class itself. As a type
1470+
/// expression, it names the type `Type::Instance(builtins.int)`, that is, all objects whose
1471+
/// `__class__` is `int`.
1472+
#[must_use]
1473+
pub fn in_type_expression(&self, db: &'db dyn Db) -> Type<'db> {
1474+
match self {
1475+
Type::ClassLiteral(_) | Type::SubclassOf(_) => self.to_instance(db),
1476+
Type::Union(union) => union.map(db, |element| element.in_type_expression(db)),
1477+
Type::Unknown => Type::Unknown,
1478+
// TODO map this to a new `Type::TypeVar` variant
1479+
Type::KnownInstance(KnownInstanceType::TypeVar(_)) => *self,
1480+
_ => Type::Todo,
1481+
}
1482+
}
1483+
14641484
/// The type `NoneType` / `None`
14651485
pub fn none(db: &'db dyn Db) -> Type<'db> {
14661486
KnownClass::NoneType.to_instance(db)
@@ -2322,7 +2342,10 @@ impl<'db> FunctionType<'db> {
23222342
self.decorators(db).contains(&decorator)
23232343
}
23242344

2325-
/// inferred return type for this function
2345+
/// Typed externally-visible signature for this function.
2346+
///
2347+
/// This is the signature as seen by external callers, possibly modified by decorators and/or
2348+
/// overloaded.
23262349
///
23272350
/// ## Why is this a salsa query?
23282351
///
@@ -2331,34 +2354,32 @@ impl<'db> FunctionType<'db> {
23312354
///
23322355
/// Were this not a salsa query, then the calling query
23332356
/// would depend on the function's AST and rerun for every change in that file.
2334-
#[salsa::tracked]
2335-
pub fn return_ty(self, db: &'db dyn Db) -> Type<'db> {
2357+
#[salsa::tracked(return_ref)]
2358+
pub fn signature(self, db: &'db dyn Db) -> Signature<'db> {
2359+
let function_stmt_node = self.body_scope(db).node(db).expect_function();
2360+
let internal_signature = self.internal_signature(db);
2361+
if function_stmt_node.decorator_list.is_empty() {
2362+
return internal_signature;
2363+
}
2364+
// TODO process the effect of decorators on the signature
2365+
Signature::todo()
2366+
}
2367+
2368+
/// Typed internally-visible signature for this function.
2369+
///
2370+
/// This represents the annotations on the function itself, unmodified by decorators and
2371+
/// overloads.
2372+
///
2373+
/// These are the parameter and return types that should be used for type checking the body of
2374+
/// the function.
2375+
///
2376+
/// Don't call this when checking any other file; only when type-checking the function body
2377+
/// scope.
2378+
fn internal_signature(self, db: &'db dyn Db) -> Signature<'db> {
23362379
let scope = self.body_scope(db);
23372380
let function_stmt_node = scope.node(db).expect_function();
2338-
2339-
// TODO if a function `bar` is decorated by `foo`,
2340-
// where `foo` is annotated as returning a type `X` that is a subtype of `Callable`,
2341-
// we need to infer the return type from `X`'s return annotation
2342-
// rather than from `bar`'s return annotation
2343-
// in order to determine the type that `bar` returns
2344-
if !function_stmt_node.decorator_list.is_empty() {
2345-
return Type::Todo;
2346-
}
2347-
2348-
function_stmt_node
2349-
.returns
2350-
.as_ref()
2351-
.map(|returns| {
2352-
if function_stmt_node.is_async {
2353-
// TODO: generic `types.CoroutineType`!
2354-
Type::Todo
2355-
} else {
2356-
let definition =
2357-
semantic_index(db, scope.file(db)).definition(function_stmt_node);
2358-
definition_expression_ty(db, definition, returns.as_ref())
2359-
}
2360-
})
2361-
.unwrap_or(Type::Unknown)
2381+
let definition = semantic_index(db, scope.file(db)).definition(function_stmt_node);
2382+
Signature::from_function(db, definition, function_stmt_node)
23622383
}
23632384

23642385
pub fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool {

crates/red_knot_python_semantic/src/types/display.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ impl Display for DisplayRepresentation<'_> {
8585
Type::SubclassOf(SubclassOfType { class }) => {
8686
write!(f, "type[{}]", class.name(self.db))
8787
}
88-
Type::KnownInstance(known_instance) => f.write_str(known_instance.as_str()),
88+
Type::KnownInstance(known_instance) => f.write_str(known_instance.repr(self.db)),
8989
Type::FunctionLiteral(function) => f.write_str(function.name(self.db)),
9090
Type::Union(union) => union.display(self.db).fmt(f),
9191
Type::Intersection(intersection) => intersection.display(self.db).fmt(f),

crates/red_knot_python_semantic/src/types/infer.rs

+12-28
Original file line numberDiff line numberDiff line change
@@ -822,8 +822,7 @@ impl<'db> TypeInferenceBuilder<'db> {
822822
.as_deref()
823823
.expect("function type params scope without type params");
824824

825-
// TODO: defer annotation resolution in stubs, with __future__.annotations, or stringified
826-
self.infer_optional_expression(function.returns.as_deref());
825+
self.infer_optional_annotation_expression(function.returns.as_deref());
827826
self.infer_type_parameters(type_params);
828827
self.infer_parameters(&function.parameters);
829828
}
@@ -915,13 +914,11 @@ impl<'db> TypeInferenceBuilder<'db> {
915914
// If there are type params, parameters and returns are evaluated in that scope, that is, in
916915
// `infer_function_type_params`, rather than here.
917916
if type_params.is_none() {
918-
self.infer_parameters(parameters);
919-
920-
// TODO: this should also be applied to parameter annotations.
921917
if self.are_all_types_deferred() {
922918
self.types.has_deferred = true;
923919
} else {
924920
self.infer_optional_annotation_expression(returns.as_deref());
921+
self.infer_parameters(parameters);
925922
}
926923
}
927924

@@ -971,7 +968,7 @@ impl<'db> TypeInferenceBuilder<'db> {
971968
default: _,
972969
} = parameter_with_default;
973970

974-
self.infer_optional_expression(parameter.annotation.as_deref());
971+
self.infer_optional_annotation_expression(parameter.annotation.as_deref());
975972
}
976973

977974
fn infer_parameter(&mut self, parameter: &ast::Parameter) {
@@ -981,7 +978,7 @@ impl<'db> TypeInferenceBuilder<'db> {
981978
annotation,
982979
} = parameter;
983980

984-
self.infer_optional_expression(annotation.as_deref());
981+
self.infer_optional_annotation_expression(annotation.as_deref());
985982
}
986983

987984
fn infer_parameter_with_default_definition(
@@ -1069,6 +1066,7 @@ impl<'db> TypeInferenceBuilder<'db> {
10691066

10701067
fn infer_function_deferred(&mut self, function: &ast::StmtFunctionDef) {
10711068
self.infer_optional_annotation_expression(function.returns.as_deref());
1069+
self.infer_parameters(function.parameters.as_ref());
10721070
}
10731071

10741072
fn infer_class_deferred(&mut self, class: &ast::StmtClassDef) {
@@ -4099,15 +4097,17 @@ impl<'db> TypeInferenceBuilder<'db> {
40994097

41004098
match expression {
41014099
ast::Expr::Name(name) => match name.ctx {
4102-
ast::ExprContext::Load => self.infer_name_expression(name).to_instance(self.db),
4100+
ast::ExprContext::Load => {
4101+
self.infer_name_expression(name).in_type_expression(self.db)
4102+
}
41034103
ast::ExprContext::Invalid => Type::Unknown,
41044104
ast::ExprContext::Store | ast::ExprContext::Del => Type::Todo,
41054105
},
41064106

41074107
ast::Expr::Attribute(attribute_expression) => match attribute_expression.ctx {
41084108
ast::ExprContext::Load => self
41094109
.infer_attribute_expression(attribute_expression)
4110-
.to_instance(self.db),
4110+
.in_type_expression(self.db),
41114111
ast::ExprContext::Invalid => Type::Unknown,
41124112
ast::ExprContext::Store | ast::ExprContext::Del => Type::Todo,
41134113
},
@@ -5019,24 +5019,8 @@ mod tests {
50195019
",
50205020
)?;
50215021

5022-
// TODO: sys.version_info, and need to understand @final and @type_check_only
5023-
assert_public_ty(&db, "src/a.py", "x", "EllipsisType | Unknown");
5024-
5025-
Ok(())
5026-
}
5027-
5028-
#[test]
5029-
fn function_return_type() -> anyhow::Result<()> {
5030-
let mut db = setup_db();
5031-
5032-
db.write_file("src/a.py", "def example() -> int: return 42")?;
5033-
5034-
let mod_file = system_path_to_file(&db, "src/a.py").unwrap();
5035-
let function = global_symbol(&db, mod_file, "example")
5036-
.expect_type()
5037-
.expect_function_literal();
5038-
let returns = function.return_ty(&db);
5039-
assert_eq!(returns.display(&db).to_string(), "int");
5022+
// TODO: sys.version_info
5023+
assert_public_ty(&db, "src/a.py", "x", "EllipsisType | ellipsis");
50405024

50415025
Ok(())
50425026
}
@@ -5251,7 +5235,7 @@ mod tests {
52515235
fn deferred_annotations_regular_source_fails() -> anyhow::Result<()> {
52525236
let mut db = setup_db();
52535237

5254-
// In (regular) source files, deferred annotations are *not* resolved
5238+
// In (regular) source files, annotations are *not* deferred
52555239
// Also tests imports from `__future__` that are not annotations
52565240
db.write_dedented(
52575241
"/src/source.py",

0 commit comments

Comments
 (0)