Skip to content

Commit 7aae809

Browse files
carljmAlexWaygood
andauthored
[red-knot] add support for typing_extensions.reveal_type (#13397)
Before `typing.reveal_type` existed, there was `typing_extensions.reveal_type`. We should support both. Also adds a test to verify that we can handle aliasing of `reveal_type` to a different name. Adds a bit of code to ensure that if we have a union of different `reveal_type` functions (e.g. a union containing both `typing_extensions.reveal_type` and `typing.reveal_type`) we still emit the reveal-type diagnostic only once. This is probably unlikely in practice, but it doesn't hurt to handle it smoothly. (It comes up now because we don't support `version_info` checks yet, so `typing_extensions.reveal_type` is actually that union.) --------- Co-authored-by: Alex Waygood <[email protected]>
1 parent 4aca9b9 commit 7aae809

File tree

2 files changed

+66
-6
lines changed

2 files changed

+66
-6
lines changed

crates/red_knot_python_semantic/src/types.rs

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -762,12 +762,25 @@ impl<'db> CallOutcome<'db> {
762762
} => {
763763
let mut not_callable = vec![];
764764
let mut union_builder = UnionBuilder::new(db);
765+
let mut revealed = false;
765766
for outcome in &**outcomes {
766-
let return_ty = if let Self::NotCallable { not_callable_ty } = outcome {
767-
not_callable.push(*not_callable_ty);
768-
Type::Unknown
769-
} else {
770-
outcome.unwrap_with_diagnostic(db, node, builder)
767+
let return_ty = match outcome {
768+
Self::NotCallable { not_callable_ty } => {
769+
not_callable.push(*not_callable_ty);
770+
Type::Unknown
771+
}
772+
Self::RevealType {
773+
return_ty,
774+
revealed_ty: _,
775+
} => {
776+
if revealed {
777+
*return_ty
778+
} else {
779+
revealed = true;
780+
outcome.unwrap_with_diagnostic(db, node, builder)
781+
}
782+
}
783+
_ => outcome.unwrap_with_diagnostic(db, node, builder),
771784
};
772785
union_builder = union_builder.add(return_ty);
773786
}
@@ -841,6 +854,15 @@ impl<'db> FunctionType<'db> {
841854
})
842855
}
843856

857+
/// Return true if this is a symbol with given name from `typing` or `typing_extensions`.
858+
pub(crate) fn is_typing_symbol(self, db: &'db dyn Db, name: &str) -> bool {
859+
name == self.name(db)
860+
&& file_to_module(db, self.definition(db).file(db)).is_some_and(|module| {
861+
module.search_path().is_standard_library()
862+
&& matches!(&**module.name(), "typing" | "typing_extensions")
863+
})
864+
}
865+
844866
pub fn has_decorator(self, db: &dyn Db, decorator: Type<'_>) -> bool {
845867
self.decorators(db).contains(&decorator)
846868
}

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,7 @@ impl<'db> TypeInferenceBuilder<'db> {
705705
}
706706

707707
let function_type = FunctionType::new(self.db, name.id.clone(), definition, decorator_tys);
708-
let function_ty = if function_type.is_stdlib_symbol(self.db, "typing", "reveal_type") {
708+
let function_ty = if function_type.is_typing_symbol(self.db, "reveal_type") {
709709
Type::RevealTypeFunction(function_type)
710710
} else {
711711
Type::Function(function_type)
@@ -2761,6 +2761,44 @@ mod tests {
27612761
Ok(())
27622762
}
27632763

2764+
#[test]
2765+
fn reveal_type_aliased() -> anyhow::Result<()> {
2766+
let mut db = setup_db();
2767+
2768+
db.write_dedented(
2769+
"/src/a.py",
2770+
"
2771+
from typing import reveal_type as rt
2772+
2773+
x = 1
2774+
rt(x)
2775+
",
2776+
)?;
2777+
2778+
assert_file_diagnostics(&db, "/src/a.py", &["Revealed type is 'Literal[1]'."]);
2779+
2780+
Ok(())
2781+
}
2782+
2783+
#[test]
2784+
fn reveal_type_typing_extensions() -> anyhow::Result<()> {
2785+
let mut db = setup_db();
2786+
2787+
db.write_dedented(
2788+
"/src/a.py",
2789+
"
2790+
import typing_extensions
2791+
2792+
x = 1
2793+
typing_extensions.reveal_type(x)
2794+
",
2795+
)?;
2796+
2797+
assert_file_diagnostics(&db, "/src/a.py", &["Revealed type is 'Literal[1]'."]);
2798+
2799+
Ok(())
2800+
}
2801+
27642802
#[test]
27652803
fn follow_import_to_class() -> anyhow::Result<()> {
27662804
let mut db = setup_db();

0 commit comments

Comments
 (0)