Skip to content

Commit 0a6dc8e

Browse files
Support __getitem__ type inference for subscripts (#13579)
## Summary Follow-up to #13562, to add support for "arbitrary" subscript operations.
1 parent 8d54996 commit 0a6dc8e

File tree

2 files changed

+310
-1
lines changed

2 files changed

+310
-1
lines changed

crates/red_knot_python_semantic/src/types.rs

+10
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,16 @@ impl<'db> Type<'db> {
401401
}
402402
}
403403

404+
/// Return true if the type is a class or a union of classes.
405+
pub fn is_class(&self, db: &'db dyn Db) -> bool {
406+
match self {
407+
Type::Union(union) => union.elements(db).iter().all(|ty| ty.is_class(db)),
408+
Type::Class(_) => true,
409+
// / TODO include type[X], once we add that type
410+
_ => false,
411+
}
412+
}
413+
404414
/// Return true if this type is a [subtype of] type `target`.
405415
///
406416
/// [subtype of]: https://typing.readthedocs.io/en/latest/spec/concepts.html#subtype-supertype-and-type-equivalence

crates/red_knot_python_semantic/src/types/infer.rs

+300-1
Original file line numberDiff line numberDiff line change
@@ -1322,6 +1322,22 @@ impl<'db> TypeInferenceBuilder<'db> {
13221322
);
13231323
}
13241324

1325+
/// Emit a diagnostic declaring that a type does not support subscripting.
1326+
pub(super) fn non_subscriptable_diagnostic(
1327+
&mut self,
1328+
node: AnyNodeRef,
1329+
non_subscriptable_ty: Type<'db>,
1330+
) {
1331+
self.add_diagnostic(
1332+
node,
1333+
"non-subscriptable",
1334+
format_args!(
1335+
"Cannot subscript object of type '{}' with no `__getitem__` method.",
1336+
non_subscriptable_ty.display(self.db)
1337+
),
1338+
);
1339+
}
1340+
13251341
fn infer_for_statement_definition(
13261342
&mut self,
13271343
target: &ast::ExprName,
@@ -2588,7 +2604,35 @@ impl<'db> TypeInferenceBuilder<'db> {
25882604
Type::Unknown
25892605
})
25902606
}
2591-
_ => Type::Todo,
2607+
(value_ty, slice_ty) => {
2608+
// Resolve the value to its class.
2609+
let value_meta_ty = value_ty.to_meta_type(self.db);
2610+
2611+
// If the class defines `__getitem__`, return its return type.
2612+
//
2613+
// See: https://docs.python.org/3/reference/datamodel.html#class-getitem-versus-getitem
2614+
let dunder_getitem_method = value_meta_ty.member(self.db, "__getitem__");
2615+
if !dunder_getitem_method.is_unbound() {
2616+
return dunder_getitem_method
2617+
.call(self.db, &[slice_ty])
2618+
.unwrap_with_diagnostic(self.db, value.as_ref().into(), self);
2619+
}
2620+
2621+
// Otherwise, if the value is itself a class and defines `__class_getitem__`,
2622+
// return its return type.
2623+
if value_ty.is_class(self.db) {
2624+
let dunder_class_getitem_method = value_ty.member(self.db, "__class_getitem__");
2625+
if !dunder_class_getitem_method.is_unbound() {
2626+
return dunder_class_getitem_method
2627+
.call(self.db, &[slice_ty])
2628+
.unwrap_with_diagnostic(self.db, value.as_ref().into(), self);
2629+
}
2630+
}
2631+
2632+
// Otherwise, emit a diagnostic.
2633+
self.non_subscriptable_diagnostic((&**value).into(), value_ty);
2634+
Type::Unknown
2635+
}
25922636
}
25932637
}
25942638

@@ -6723,6 +6767,261 @@ mod tests {
67236767
Ok(())
67246768
}
67256769

6770+
#[test]
6771+
fn subscript_getitem_unbound() -> anyhow::Result<()> {
6772+
let mut db = setup_db();
6773+
6774+
db.write_dedented(
6775+
"/src/a.py",
6776+
"
6777+
class NotSubscriptable:
6778+
pass
6779+
6780+
a = NotSubscriptable()[0]
6781+
",
6782+
)?;
6783+
6784+
assert_public_ty(&db, "/src/a.py", "a", "Unknown");
6785+
assert_file_diagnostics(
6786+
&db,
6787+
"/src/a.py",
6788+
&["Cannot subscript object of type 'NotSubscriptable' with no `__getitem__` method."],
6789+
);
6790+
6791+
Ok(())
6792+
}
6793+
6794+
#[test]
6795+
fn subscript_not_callable_getitem() -> anyhow::Result<()> {
6796+
let mut db = setup_db();
6797+
6798+
db.write_dedented(
6799+
"/src/a.py",
6800+
"
6801+
class NotSubscriptable:
6802+
__getitem__ = None
6803+
6804+
a = NotSubscriptable()[0]
6805+
",
6806+
)?;
6807+
6808+
assert_public_ty(&db, "/src/a.py", "a", "Unknown");
6809+
assert_file_diagnostics(
6810+
&db,
6811+
"/src/a.py",
6812+
&["Object of type 'None' is not callable."],
6813+
);
6814+
6815+
Ok(())
6816+
}
6817+
6818+
#[test]
6819+
fn subscript_str_literal() -> anyhow::Result<()> {
6820+
let mut db = setup_db();
6821+
6822+
db.write_dedented(
6823+
"/src/a.py",
6824+
"
6825+
def add(x: int, y: int) -> int:
6826+
return x + y
6827+
6828+
a = 'abcde'[add(0, 1)]
6829+
",
6830+
)?;
6831+
6832+
assert_public_ty(&db, "/src/a.py", "a", "str");
6833+
6834+
Ok(())
6835+
}
6836+
6837+
#[test]
6838+
fn subscript_getitem() -> anyhow::Result<()> {
6839+
let mut db = setup_db();
6840+
6841+
db.write_dedented(
6842+
"/src/a.py",
6843+
"
6844+
class Identity:
6845+
def __getitem__(self, index: int) -> int:
6846+
return index
6847+
6848+
a = Identity()[0]
6849+
",
6850+
)?;
6851+
6852+
assert_public_ty(&db, "/src/a.py", "a", "int");
6853+
6854+
Ok(())
6855+
}
6856+
6857+
#[test]
6858+
fn subscript_class_getitem() -> anyhow::Result<()> {
6859+
let mut db = setup_db();
6860+
6861+
db.write_dedented(
6862+
"/src/a.py",
6863+
"
6864+
class Identity:
6865+
def __class_getitem__(cls, item: int) -> str:
6866+
return item
6867+
6868+
a = Identity[0]
6869+
",
6870+
)?;
6871+
6872+
assert_public_ty(&db, "/src/a.py", "a", "str");
6873+
6874+
Ok(())
6875+
}
6876+
6877+
#[test]
6878+
fn subscript_getitem_union() -> anyhow::Result<()> {
6879+
let mut db = setup_db();
6880+
6881+
db.write_dedented(
6882+
"/src/a.py",
6883+
"
6884+
flag = True
6885+
6886+
class Identity:
6887+
if flag:
6888+
def __getitem__(self, index: int) -> int:
6889+
return index
6890+
else:
6891+
def __getitem__(self, index: int) -> str:
6892+
return str(index)
6893+
6894+
a = Identity()[0]
6895+
",
6896+
)?;
6897+
6898+
assert_public_ty(&db, "/src/a.py", "a", "int | str");
6899+
6900+
Ok(())
6901+
}
6902+
6903+
#[test]
6904+
fn subscript_class_getitem_union() -> anyhow::Result<()> {
6905+
let mut db = setup_db();
6906+
6907+
db.write_dedented(
6908+
"/src/a.py",
6909+
"
6910+
flag = True
6911+
6912+
class Identity:
6913+
if flag:
6914+
def __class_getitem__(cls, item: int) -> str:
6915+
return item
6916+
else:
6917+
def __class_getitem__(cls, item: int) -> int:
6918+
return item
6919+
6920+
a = Identity[0]
6921+
",
6922+
)?;
6923+
6924+
assert_public_ty(&db, "/src/a.py", "a", "str | int");
6925+
6926+
Ok(())
6927+
}
6928+
6929+
#[test]
6930+
fn subscript_class_getitem_class_union() -> anyhow::Result<()> {
6931+
let mut db = setup_db();
6932+
6933+
db.write_dedented(
6934+
"/src/a.py",
6935+
"
6936+
flag = True
6937+
6938+
class Identity1:
6939+
def __class_getitem__(cls, item: int) -> str:
6940+
return item
6941+
6942+
class Identity2:
6943+
def __class_getitem__(cls, item: int) -> int:
6944+
return item
6945+
6946+
if flag:
6947+
a = Identity1
6948+
else:
6949+
a = Identity2
6950+
6951+
b = a[0]
6952+
",
6953+
)?;
6954+
6955+
assert_public_ty(&db, "/src/a.py", "a", "Literal[Identity1, Identity2]");
6956+
assert_public_ty(&db, "/src/a.py", "b", "str | int");
6957+
6958+
Ok(())
6959+
}
6960+
6961+
#[test]
6962+
fn subscript_class_getitem_unbound_method_union() -> anyhow::Result<()> {
6963+
let mut db = setup_db();
6964+
6965+
db.write_dedented(
6966+
"/src/a.py",
6967+
"
6968+
flag = True
6969+
6970+
if flag:
6971+
class Identity:
6972+
def __class_getitem__(self, x: int) -> str:
6973+
pass
6974+
else:
6975+
class Identity:
6976+
pass
6977+
6978+
a = Identity[42]
6979+
",
6980+
)?;
6981+
6982+
assert_public_ty(&db, "/src/a.py", "a", "str | Unknown");
6983+
6984+
assert_file_diagnostics(
6985+
&db,
6986+
"/src/a.py",
6987+
&["Object of type 'Literal[__class_getitem__] | Unbound' is not callable (due to union element 'Unbound')."],
6988+
);
6989+
6990+
Ok(())
6991+
}
6992+
6993+
#[test]
6994+
fn subscript_class_getitem_non_class_union() -> anyhow::Result<()> {
6995+
let mut db = setup_db();
6996+
6997+
db.write_dedented(
6998+
"/src/a.py",
6999+
"
7000+
flag = True
7001+
7002+
if flag:
7003+
class Identity:
7004+
def __class_getitem__(self, x: int) -> str:
7005+
pass
7006+
else:
7007+
Identity = 1
7008+
7009+
a = Identity[42]
7010+
",
7011+
)?;
7012+
7013+
// TODO this should _probably_ emit `str | Unknown` instead of `Unknown`.
7014+
assert_public_ty(&db, "/src/a.py", "a", "Unknown");
7015+
7016+
assert_file_diagnostics(
7017+
&db,
7018+
"/src/a.py",
7019+
&["Cannot subscript object of type 'Literal[Identity] | Literal[1]' with no `__getitem__` method."],
7020+
);
7021+
7022+
Ok(())
7023+
}
7024+
67267025
#[test]
67277026
fn dunder_call() -> anyhow::Result<()> {
67287027
let mut db = setup_db();

0 commit comments

Comments
 (0)