Skip to content

Commit 0512428

Browse files
authored
[red-knot] Emit a diagnostic if the value of a starred expression or a yield from expression is not iterable (#13240)
1 parent 46a4573 commit 0512428

File tree

2 files changed

+133
-25
lines changed

2 files changed

+133
-25
lines changed

crates/red_knot_python_semantic/src/types.rs

Lines changed: 108 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use infer::TypeInferenceBuilder;
12
use ruff_db::files::File;
23
use ruff_python_ast as ast;
34

@@ -400,28 +401,42 @@ impl<'db> Type<'db> {
400401
/// for y in x:
401402
/// pass
402403
/// ```
403-
///
404-
/// Returns `None` if `self` represents a type that is not iterable.
405-
fn iterate(&self, db: &'db dyn Db) -> Option<Type<'db>> {
404+
fn iterate(&self, db: &'db dyn Db) -> IterationOutcome<'db> {
406405
// `self` represents the type of the iterable;
407406
// `__iter__` and `__next__` are both looked up on the class of the iterable:
408-
let type_of_class = self.to_meta_type(db);
407+
let iterable_meta_type = self.to_meta_type(db);
409408

410-
let dunder_iter_method = type_of_class.member(db, "__iter__");
409+
let dunder_iter_method = iterable_meta_type.member(db, "__iter__");
411410
if !dunder_iter_method.is_unbound() {
412-
let iterator_ty = dunder_iter_method.call(db)?;
411+
let Some(iterator_ty) = dunder_iter_method.call(db) else {
412+
return IterationOutcome::NotIterable {
413+
not_iterable_ty: *self,
414+
};
415+
};
416+
413417
let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__");
414-
return dunder_next_method.call(db);
418+
return dunder_next_method
419+
.call(db)
420+
.map(|element_ty| IterationOutcome::Iterable { element_ty })
421+
.unwrap_or(IterationOutcome::NotIterable {
422+
not_iterable_ty: *self,
423+
});
415424
}
416425

417426
// Although it's not considered great practice,
418427
// classes that define `__getitem__` are also iterable,
419428
// even if they do not define `__iter__`.
420429
//
421-
// TODO this is only valid if the `__getitem__` method is annotated as
430+
// TODO(Alex) this is only valid if the `__getitem__` method is annotated as
422431
// accepting `int` or `SupportsIndex`
423-
let dunder_get_item_method = type_of_class.member(db, "__getitem__");
424-
dunder_get_item_method.call(db)
432+
let dunder_get_item_method = iterable_meta_type.member(db, "__getitem__");
433+
434+
dunder_get_item_method
435+
.call(db)
436+
.map(|element_ty| IterationOutcome::Iterable { element_ty })
437+
.unwrap_or(IterationOutcome::NotIterable {
438+
not_iterable_ty: *self,
439+
})
425440
}
426441

427442
#[must_use]
@@ -463,6 +478,28 @@ impl<'db> Type<'db> {
463478
}
464479
}
465480

481+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
482+
enum IterationOutcome<'db> {
483+
Iterable { element_ty: Type<'db> },
484+
NotIterable { not_iterable_ty: Type<'db> },
485+
}
486+
487+
impl<'db> IterationOutcome<'db> {
488+
fn unwrap_with_diagnostic(
489+
self,
490+
iterable_node: ast::AnyNodeRef,
491+
inference_builder: &mut TypeInferenceBuilder<'db>,
492+
) -> Type<'db> {
493+
match self {
494+
Self::Iterable { element_ty } => element_ty,
495+
Self::NotIterable { not_iterable_ty } => {
496+
inference_builder.not_iterable_diagnostic(iterable_node, not_iterable_ty);
497+
Type::Unknown
498+
}
499+
}
500+
}
501+
}
502+
466503
#[salsa::interned]
467504
pub struct FunctionType<'db> {
468505
/// name of the function at definition
@@ -789,4 +826,65 @@ mod tests {
789826
&["Object of type 'NotIterable' is not iterable"],
790827
);
791828
}
829+
830+
#[test]
831+
fn starred_expressions_must_be_iterable() {
832+
let mut db = setup_db();
833+
834+
db.write_dedented(
835+
"src/a.py",
836+
"
837+
class NotIterable: pass
838+
839+
class Iterator:
840+
def __next__(self) -> int:
841+
return 42
842+
843+
class Iterable:
844+
def __iter__(self) -> Iterator:
845+
846+
x = [*NotIterable()]
847+
y = [*Iterable()]
848+
",
849+
)
850+
.unwrap();
851+
852+
let a_file = system_path_to_file(&db, "/src/a.py").unwrap();
853+
let a_file_diagnostics = super::check_types(&db, a_file);
854+
assert_diagnostic_messages(
855+
&a_file_diagnostics,
856+
&["Object of type 'NotIterable' is not iterable"],
857+
);
858+
}
859+
860+
#[test]
861+
fn yield_from_expression_must_be_iterable() {
862+
let mut db = setup_db();
863+
864+
db.write_dedented(
865+
"src/a.py",
866+
"
867+
class NotIterable: pass
868+
869+
class Iterator:
870+
def __next__(self) -> int:
871+
return 42
872+
873+
class Iterable:
874+
def __iter__(self) -> Iterator:
875+
876+
def generator_function():
877+
yield from Iterable()
878+
yield from NotIterable()
879+
",
880+
)
881+
.unwrap();
882+
883+
let a_file = system_path_to_file(&db, "/src/a.py").unwrap();
884+
let a_file_diagnostics = super::check_types(&db, a_file);
885+
assert_diagnostic_messages(
886+
&a_file_diagnostics,
887+
&["Object of type 'NotIterable' is not iterable"],
888+
);
889+
}
792890
}

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ impl<'db> TypeInference<'db> {
243243
/// Similarly, when we encounter a standalone-inferable expression (right-hand side of an
244244
/// assignment, type narrowing guard), we use the [`infer_expression_types()`] query to ensure we
245245
/// don't infer its types more than once.
246-
struct TypeInferenceBuilder<'db> {
246+
pub(super) struct TypeInferenceBuilder<'db> {
247247
db: &'db dyn Db,
248248
index: &'db SemanticIndex<'db>,
249249
region: InferenceRegion<'db>,
@@ -1029,6 +1029,18 @@ impl<'db> TypeInferenceBuilder<'db> {
10291029
self.infer_body(orelse);
10301030
}
10311031

1032+
/// Emit a diagnostic declaring that the object represented by `node` is not iterable
1033+
pub(super) fn not_iterable_diagnostic(&mut self, node: AnyNodeRef, not_iterable_ty: Type<'db>) {
1034+
self.add_diagnostic(
1035+
node,
1036+
"not-iterable",
1037+
format_args!(
1038+
"Object of type '{}' is not iterable",
1039+
not_iterable_ty.display(self.db)
1040+
),
1041+
);
1042+
}
1043+
10321044
fn infer_for_statement_definition(
10331045
&mut self,
10341046
target: &ast::ExprName,
@@ -1042,17 +1054,9 @@ impl<'db> TypeInferenceBuilder<'db> {
10421054
.types
10431055
.expression_ty(iterable.scoped_ast_id(self.db, self.scope));
10441056

1045-
let loop_var_value_ty = iterable_ty.iterate(self.db).unwrap_or_else(|| {
1046-
self.add_diagnostic(
1047-
iterable.into(),
1048-
"not-iterable",
1049-
format_args!(
1050-
"Object of type '{}' is not iterable",
1051-
iterable_ty.display(self.db)
1052-
),
1053-
);
1054-
Type::Unknown
1055-
});
1057+
let loop_var_value_ty = iterable_ty
1058+
.iterate(self.db)
1059+
.unwrap_with_diagnostic(iterable.into(), self);
10561060

10571061
self.types
10581062
.expressions
@@ -1812,7 +1816,10 @@ impl<'db> TypeInferenceBuilder<'db> {
18121816
ctx: _,
18131817
} = starred;
18141818

1815-
self.infer_expression(value);
1819+
let iterable_ty = self.infer_expression(value);
1820+
iterable_ty
1821+
.iterate(self.db)
1822+
.unwrap_with_diagnostic(value.as_ref().into(), self);
18161823

18171824
// TODO
18181825
Type::Unknown
@@ -1830,9 +1837,12 @@ impl<'db> TypeInferenceBuilder<'db> {
18301837
fn infer_yield_from_expression(&mut self, yield_from: &ast::ExprYieldFrom) -> Type<'db> {
18311838
let ast::ExprYieldFrom { range: _, value } = yield_from;
18321839

1833-
self.infer_expression(value);
1840+
let iterable_ty = self.infer_expression(value);
1841+
iterable_ty
1842+
.iterate(self.db)
1843+
.unwrap_with_diagnostic(value.as_ref().into(), self);
18341844

1835-
// TODO get type from awaitable
1845+
// TODO get type from `ReturnType` of generator
18361846
Type::Unknown
18371847
}
18381848

0 commit comments

Comments
 (0)