Skip to content

Commit c173ec5

Browse files
authored
[red-knot] support for typing.reveal_type (#13384)
Add support for the `typing.reveal_type` function, emitting a diagnostic revealing the type of its single argument. This is a necessary piece for the planned testing framework. This puts the cart slightly in front of the horse, in that we don't yet have proper support for validating call signatures / argument types. But it's easy to do just enough to make `reveal_type` work. This PR includes support for calling union types (this is necessary because we don't yet support `sys.version_info` checks, so `typing.reveal_type` itself is a union type), plus some nice consolidated error messages for calls to unions where some elements are not callable. This is mostly to demonstrate the flexibility in diagnostics that we get from the `CallOutcome` enum.
1 parent 44d916f commit c173ec5

File tree

3 files changed

+384
-51
lines changed

3 files changed

+384
-51
lines changed

crates/red_knot_python_semantic/src/types.rs

Lines changed: 209 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,8 @@ pub enum Type<'db> {
238238
None,
239239
/// a specific function object
240240
Function(FunctionType<'db>),
241+
/// The `typing.reveal_type` function, which has special `__call__` behavior.
242+
RevealTypeFunction(FunctionType<'db>),
241243
/// a specific module object
242244
Module(File),
243245
/// a specific class object
@@ -324,14 +326,16 @@ impl<'db> Type<'db> {
324326

325327
pub const fn into_function_type(self) -> Option<FunctionType<'db>> {
326328
match self {
327-
Type::Function(function_type) => Some(function_type),
329+
Type::Function(function_type) | Type::RevealTypeFunction(function_type) => {
330+
Some(function_type)
331+
}
328332
_ => None,
329333
}
330334
}
331335

332336
pub fn expect_function(self) -> FunctionType<'db> {
333337
self.into_function_type()
334-
.expect("Expected a Type::Function variant")
338+
.expect("Expected a variant wrapping a FunctionType")
335339
}
336340

337341
pub const fn into_int_literal_type(self) -> Option<i64> {
@@ -367,6 +371,16 @@ impl<'db> Type<'db> {
367371
}
368372
}
369373

374+
pub fn is_stdlib_symbol(&self, db: &'db dyn Db, module_name: &str, name: &str) -> bool {
375+
match self {
376+
Type::Class(class) => class.is_stdlib_symbol(db, module_name, name),
377+
Type::Function(function) | Type::RevealTypeFunction(function) => {
378+
function.is_stdlib_symbol(db, module_name, name)
379+
}
380+
_ => false,
381+
}
382+
}
383+
370384
/// Return true if this type is [assignable to] type `target`.
371385
///
372386
/// [assignable to]: https://typing.readthedocs.io/en/latest/spec/concepts.html#the-assignable-to-or-consistent-subtyping-relation
@@ -436,7 +450,7 @@ impl<'db> Type<'db> {
436450
// TODO: attribute lookup on None type
437451
Type::Unknown
438452
}
439-
Type::Function(_) => {
453+
Type::Function(_) | Type::RevealTypeFunction(_) => {
440454
// TODO: attribute lookup on function type
441455
Type::Unknown
442456
}
@@ -482,26 +496,39 @@ impl<'db> Type<'db> {
482496
///
483497
/// Returns `None` if `self` is not a callable type.
484498
#[must_use]
485-
pub fn call(&self, db: &'db dyn Db) -> Option<Type<'db>> {
499+
fn call(self, db: &'db dyn Db, arg_types: &[Type<'db>]) -> CallOutcome<'db> {
486500
match self {
487-
Type::Function(function_type) => Some(function_type.return_type(db)),
501+
// TODO validate typed call arguments vs callable signature
502+
Type::Function(function_type) => CallOutcome::callable(function_type.return_type(db)),
503+
Type::RevealTypeFunction(function_type) => CallOutcome::revealed(
504+
function_type.return_type(db),
505+
*arg_types.first().unwrap_or(&Type::Unknown),
506+
),
488507

489508
// TODO annotated return type on `__new__` or metaclass `__call__`
490-
Type::Class(class) => Some(Type::Instance(*class)),
509+
Type::Class(class) => CallOutcome::callable(Type::Instance(class)),
491510

492-
// TODO: handle classes which implement the Callable protocol
493-
Type::Instance(_instance_ty) => Some(Type::Unknown),
511+
// TODO: handle classes which implement the `__call__` protocol
512+
Type::Instance(_instance_ty) => CallOutcome::callable(Type::Unknown),
494513

495514
// `Any` is callable, and its return type is also `Any`.
496-
Type::Any => Some(Type::Any),
515+
Type::Any => CallOutcome::callable(Type::Any),
497516

498-
Type::Unknown => Some(Type::Unknown),
517+
Type::Unknown => CallOutcome::callable(Type::Unknown),
499518

500-
// TODO: union and intersection types, if they reduce to `Callable`
501-
Type::Union(_) => Some(Type::Unknown),
502-
Type::Intersection(_) => Some(Type::Unknown),
519+
Type::Union(union) => CallOutcome::union(
520+
self,
521+
union
522+
.elements(db)
523+
.iter()
524+
.map(|elem| elem.call(db, arg_types))
525+
.collect::<Box<[CallOutcome<'db>]>>(),
526+
),
503527

504-
_ => None,
528+
// TODO: intersection types
529+
Type::Intersection(_) => CallOutcome::callable(Type::Unknown),
530+
531+
_ => CallOutcome::not_callable(self),
505532
}
506533
}
507534

@@ -513,7 +540,7 @@ impl<'db> Type<'db> {
513540
/// for y in x:
514541
/// pass
515542
/// ```
516-
fn iterate(&self, db: &'db dyn Db) -> IterationOutcome<'db> {
543+
fn iterate(self, db: &'db dyn Db) -> IterationOutcome<'db> {
517544
if let Type::Tuple(tuple_type) = self {
518545
return IterationOutcome::Iterable {
519546
element_ty: UnionType::from_elements(db, &**tuple_type.elements(db)),
@@ -526,18 +553,22 @@ impl<'db> Type<'db> {
526553

527554
let dunder_iter_method = iterable_meta_type.member(db, "__iter__");
528555
if !dunder_iter_method.is_unbound() {
529-
let Some(iterator_ty) = dunder_iter_method.call(db) else {
556+
let CallOutcome::Callable {
557+
return_ty: iterator_ty,
558+
} = dunder_iter_method.call(db, &[])
559+
else {
530560
return IterationOutcome::NotIterable {
531-
not_iterable_ty: *self,
561+
not_iterable_ty: self,
532562
};
533563
};
534564

535565
let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__");
536566
return dunder_next_method
537-
.call(db)
567+
.call(db, &[])
568+
.return_ty(db)
538569
.map(|element_ty| IterationOutcome::Iterable { element_ty })
539570
.unwrap_or(IterationOutcome::NotIterable {
540-
not_iterable_ty: *self,
571+
not_iterable_ty: self,
541572
});
542573
}
543574

@@ -550,10 +581,11 @@ impl<'db> Type<'db> {
550581
let dunder_get_item_method = iterable_meta_type.member(db, "__getitem__");
551582

552583
dunder_get_item_method
553-
.call(db)
584+
.call(db, &[])
585+
.return_ty(db)
554586
.map(|element_ty| IterationOutcome::Iterable { element_ty })
555587
.unwrap_or(IterationOutcome::NotIterable {
556-
not_iterable_ty: *self,
588+
not_iterable_ty: self,
557589
})
558590
}
559591

@@ -573,6 +605,7 @@ impl<'db> Type<'db> {
573605
Type::BooleanLiteral(_)
574606
| Type::BytesLiteral(_)
575607
| Type::Function(_)
608+
| Type::RevealTypeFunction(_)
576609
| Type::Instance(_)
577610
| Type::Module(_)
578611
| Type::IntLiteral(_)
@@ -595,7 +628,7 @@ impl<'db> Type<'db> {
595628
Type::BooleanLiteral(_) => builtins_symbol_ty(db, "bool"),
596629
Type::BytesLiteral(_) => builtins_symbol_ty(db, "bytes"),
597630
Type::IntLiteral(_) => builtins_symbol_ty(db, "int"),
598-
Type::Function(_) => types_symbol_ty(db, "FunctionType"),
631+
Type::Function(_) | Type::RevealTypeFunction(_) => types_symbol_ty(db, "FunctionType"),
599632
Type::Module(_) => types_symbol_ty(db, "ModuleType"),
600633
Type::None => typeshed_symbol_ty(db, "NoneType"),
601634
// TODO not accurate if there's a custom metaclass...
@@ -619,6 +652,152 @@ impl<'db> From<&Type<'db>> for Type<'db> {
619652
}
620653
}
621654

655+
#[derive(Debug, Clone, PartialEq, Eq)]
656+
enum CallOutcome<'db> {
657+
Callable {
658+
return_ty: Type<'db>,
659+
},
660+
RevealType {
661+
return_ty: Type<'db>,
662+
revealed_ty: Type<'db>,
663+
},
664+
NotCallable {
665+
not_callable_ty: Type<'db>,
666+
},
667+
Union {
668+
called_ty: Type<'db>,
669+
outcomes: Box<[CallOutcome<'db>]>,
670+
},
671+
}
672+
673+
impl<'db> CallOutcome<'db> {
674+
/// Create a new `CallOutcome::Callable` with given return type.
675+
fn callable(return_ty: Type<'db>) -> CallOutcome {
676+
CallOutcome::Callable { return_ty }
677+
}
678+
679+
/// Create a new `CallOutcome::NotCallable` with given not-callable type.
680+
fn not_callable(not_callable_ty: Type<'db>) -> CallOutcome {
681+
CallOutcome::NotCallable { not_callable_ty }
682+
}
683+
684+
/// Create a new `CallOutcome::RevealType` with given revealed and return types.
685+
fn revealed(return_ty: Type<'db>, revealed_ty: Type<'db>) -> CallOutcome<'db> {
686+
CallOutcome::RevealType {
687+
return_ty,
688+
revealed_ty,
689+
}
690+
}
691+
692+
/// Create a new `CallOutcome::Union` with given wrapped outcomes.
693+
fn union(called_ty: Type<'db>, outcomes: impl Into<Box<[CallOutcome<'db>]>>) -> CallOutcome {
694+
CallOutcome::Union {
695+
called_ty,
696+
outcomes: outcomes.into(),
697+
}
698+
}
699+
700+
/// Get the return type of the call, or `None` if not callable.
701+
fn return_ty(&self, db: &'db dyn Db) -> Option<Type<'db>> {
702+
match self {
703+
Self::Callable { return_ty } => Some(*return_ty),
704+
Self::RevealType {
705+
return_ty,
706+
revealed_ty: _,
707+
} => Some(*return_ty),
708+
Self::NotCallable { not_callable_ty: _ } => None,
709+
Self::Union {
710+
outcomes,
711+
called_ty: _,
712+
} => outcomes
713+
.iter()
714+
// If all outcomes are NotCallable, we return None; if some outcomes are callable
715+
// and some are not, we return a union including Unknown.
716+
.fold(None, |acc, outcome| {
717+
let ty = outcome.return_ty(db);
718+
match (acc, ty) {
719+
(None, None) => None,
720+
(None, Some(ty)) => Some(UnionBuilder::new(db).add(ty)),
721+
(Some(builder), ty) => Some(builder.add(ty.unwrap_or(Type::Unknown))),
722+
}
723+
})
724+
.map(UnionBuilder::build),
725+
}
726+
}
727+
728+
/// Get the return type of the call, emitting diagnostics if needed.
729+
fn unwrap_with_diagnostic<'a>(
730+
&self,
731+
db: &'db dyn Db,
732+
node: ast::AnyNodeRef,
733+
builder: &'a mut TypeInferenceBuilder<'db>,
734+
) -> Type<'db> {
735+
match self {
736+
Self::Callable { return_ty } => *return_ty,
737+
Self::RevealType {
738+
return_ty,
739+
revealed_ty,
740+
} => {
741+
builder.add_diagnostic(
742+
node,
743+
"revealed-type",
744+
format_args!("Revealed type is '{}'.", revealed_ty.display(db)),
745+
);
746+
*return_ty
747+
}
748+
Self::NotCallable { not_callable_ty } => {
749+
builder.add_diagnostic(
750+
node,
751+
"call-non-callable",
752+
format_args!(
753+
"Object of type '{}' is not callable.",
754+
not_callable_ty.display(db)
755+
),
756+
);
757+
Type::Unknown
758+
}
759+
Self::Union {
760+
outcomes,
761+
called_ty,
762+
} => {
763+
let mut not_callable = vec![];
764+
let mut union_builder = UnionBuilder::new(db);
765+
for outcome in &**outcomes {
766+
let return_ty = if let Self::NotCallable { not_callable_ty } = outcome {
767+
not_callable.push(*not_callable_ty);
768+
Type::Unknown
769+
} else {
770+
outcome.unwrap_with_diagnostic(db, node, builder)
771+
};
772+
union_builder = union_builder.add(return_ty);
773+
}
774+
match not_callable[..] {
775+
[] => {}
776+
[elem] => builder.add_diagnostic(
777+
node,
778+
"call-non-callable",
779+
format_args!(
780+
"Union element '{}' of type '{}' is not callable.",
781+
elem.display(db),
782+
called_ty.display(db)
783+
),
784+
),
785+
_ => builder.add_diagnostic(
786+
node,
787+
"call-non-callable",
788+
format_args!(
789+
"Union elements {} of type '{}' are not callable.",
790+
not_callable.display(db),
791+
called_ty.display(db)
792+
),
793+
),
794+
}
795+
union_builder.build()
796+
}
797+
}
798+
}
799+
}
800+
622801
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
623802
enum IterationOutcome<'db> {
624803
Iterable { element_ty: Type<'db> },
@@ -654,6 +833,14 @@ pub struct FunctionType<'db> {
654833
}
655834

656835
impl<'db> FunctionType<'db> {
836+
/// Return true if this is a standard library function with given module name and name.
837+
pub(crate) fn is_stdlib_symbol(self, db: &'db dyn Db, module_name: &str, name: &str) -> bool {
838+
name == self.name(db)
839+
&& file_to_module(db, self.definition(db).file(db)).is_some_and(|module| {
840+
module.search_path().is_standard_library() && module.name() == module_name
841+
})
842+
}
843+
657844
pub fn has_decorator(self, db: &dyn Db, decorator: Type<'_>) -> bool {
658845
self.decorators(db).contains(&decorator)
659846
}

crates/red_knot_python_semantic/src/types/display.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ impl Display for DisplayType<'_> {
3636
| Type::BytesLiteral(_)
3737
| Type::Class(_)
3838
| Type::Function(_)
39+
| Type::RevealTypeFunction(_)
3940
) {
4041
write!(f, "Literal[{representation}]",)
4142
} else {
@@ -72,7 +73,9 @@ impl Display for DisplayRepresentation<'_> {
7273
// TODO functions and classes should display using a fully qualified name
7374
Type::Class(class) => f.write_str(class.name(self.db)),
7475
Type::Instance(class) => f.write_str(class.name(self.db)),
75-
Type::Function(function) => f.write_str(function.name(self.db)),
76+
Type::Function(function) | Type::RevealTypeFunction(function) => {
77+
f.write_str(function.name(self.db))
78+
}
7679
Type::Union(union) => union.display(self.db).fmt(f),
7780
Type::Intersection(intersection) => intersection.display(self.db).fmt(f),
7881
Type::IntLiteral(n) => n.fmt(f),
@@ -191,7 +194,7 @@ impl TryFrom<Type<'_>> for LiteralTypeKind {
191194
fn try_from(value: Type<'_>) -> Result<Self, Self::Error> {
192195
match value {
193196
Type::Class(_) => Ok(Self::Class),
194-
Type::Function(_) => Ok(Self::Function),
197+
Type::Function(_) | Type::RevealTypeFunction(_) => Ok(Self::Function),
195198
Type::IntLiteral(_) => Ok(Self::IntLiteral),
196199
Type::StringLiteral(_) => Ok(Self::StringLiteral),
197200
Type::BytesLiteral(_) => Ok(Self::BytesLiteral),

0 commit comments

Comments
 (0)