|
| 1 | +use infer::TypeInferenceBuilder; |
1 | 2 | use ruff_db::files::File;
|
2 | 3 | use ruff_python_ast as ast;
|
3 | 4 |
|
@@ -400,28 +401,42 @@ impl<'db> Type<'db> {
|
400 | 401 | /// for y in x:
|
401 | 402 | /// pass
|
402 | 403 | /// ```
|
403 |
| - /// |
404 |
| - /// Returns `None` if `self` represents a type that is not iterable. |
405 |
| - fn iterate(&self, db: &'db dyn Db) -> Option<Type<'db>> { |
| 404 | + fn iterate(&self, db: &'db dyn Db) -> IterationOutcome<'db> { |
406 | 405 | // `self` represents the type of the iterable;
|
407 | 406 | // `__iter__` and `__next__` are both looked up on the class of the iterable:
|
408 |
| - let type_of_class = self.to_meta_type(db); |
| 407 | + let iterable_meta_type = self.to_meta_type(db); |
409 | 408 |
|
410 |
| - let dunder_iter_method = type_of_class.member(db, "__iter__"); |
| 409 | + let dunder_iter_method = iterable_meta_type.member(db, "__iter__"); |
411 | 410 | if !dunder_iter_method.is_unbound() {
|
412 |
| - let iterator_ty = dunder_iter_method.call(db)?; |
| 411 | + let Some(iterator_ty) = dunder_iter_method.call(db) else { |
| 412 | + return IterationOutcome::NotIterable { |
| 413 | + not_iterable_ty: *self, |
| 414 | + }; |
| 415 | + }; |
| 416 | + |
413 | 417 | let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__");
|
414 |
| - return dunder_next_method.call(db); |
| 418 | + return dunder_next_method |
| 419 | + .call(db) |
| 420 | + .map(|element_ty| IterationOutcome::Iterable { element_ty }) |
| 421 | + .unwrap_or(IterationOutcome::NotIterable { |
| 422 | + not_iterable_ty: *self, |
| 423 | + }); |
415 | 424 | }
|
416 | 425 |
|
417 | 426 | // Although it's not considered great practice,
|
418 | 427 | // classes that define `__getitem__` are also iterable,
|
419 | 428 | // even if they do not define `__iter__`.
|
420 | 429 | //
|
421 |
| - // TODO this is only valid if the `__getitem__` method is annotated as |
| 430 | + // TODO(Alex) this is only valid if the `__getitem__` method is annotated as |
422 | 431 | // accepting `int` or `SupportsIndex`
|
423 |
| - let dunder_get_item_method = type_of_class.member(db, "__getitem__"); |
424 |
| - dunder_get_item_method.call(db) |
| 432 | + let dunder_get_item_method = iterable_meta_type.member(db, "__getitem__"); |
| 433 | + |
| 434 | + dunder_get_item_method |
| 435 | + .call(db) |
| 436 | + .map(|element_ty| IterationOutcome::Iterable { element_ty }) |
| 437 | + .unwrap_or(IterationOutcome::NotIterable { |
| 438 | + not_iterable_ty: *self, |
| 439 | + }) |
425 | 440 | }
|
426 | 441 |
|
427 | 442 | #[must_use]
|
@@ -463,6 +478,28 @@ impl<'db> Type<'db> {
|
463 | 478 | }
|
464 | 479 | }
|
465 | 480 |
|
| 481 | +#[derive(Debug, Clone, Copy, PartialEq, Eq)] |
| 482 | +enum IterationOutcome<'db> { |
| 483 | + Iterable { element_ty: Type<'db> }, |
| 484 | + NotIterable { not_iterable_ty: Type<'db> }, |
| 485 | +} |
| 486 | + |
| 487 | +impl<'db> IterationOutcome<'db> { |
| 488 | + fn unwrap_with_diagnostic( |
| 489 | + self, |
| 490 | + iterable_node: ast::AnyNodeRef, |
| 491 | + inference_builder: &mut TypeInferenceBuilder<'db>, |
| 492 | + ) -> Type<'db> { |
| 493 | + match self { |
| 494 | + Self::Iterable { element_ty } => element_ty, |
| 495 | + Self::NotIterable { not_iterable_ty } => { |
| 496 | + inference_builder.not_iterable_diagnostic(iterable_node, not_iterable_ty); |
| 497 | + Type::Unknown |
| 498 | + } |
| 499 | + } |
| 500 | + } |
| 501 | +} |
| 502 | + |
466 | 503 | #[salsa::interned]
|
467 | 504 | pub struct FunctionType<'db> {
|
468 | 505 | /// name of the function at definition
|
@@ -789,4 +826,65 @@ mod tests {
|
789 | 826 | &["Object of type 'NotIterable' is not iterable"],
|
790 | 827 | );
|
791 | 828 | }
|
| 829 | + |
| 830 | + #[test] |
| 831 | + fn starred_expressions_must_be_iterable() { |
| 832 | + let mut db = setup_db(); |
| 833 | + |
| 834 | + db.write_dedented( |
| 835 | + "src/a.py", |
| 836 | + " |
| 837 | + class NotIterable: pass |
| 838 | +
|
| 839 | + class Iterator: |
| 840 | + def __next__(self) -> int: |
| 841 | + return 42 |
| 842 | +
|
| 843 | + class Iterable: |
| 844 | + def __iter__(self) -> Iterator: |
| 845 | +
|
| 846 | + x = [*NotIterable()] |
| 847 | + y = [*Iterable()] |
| 848 | + ", |
| 849 | + ) |
| 850 | + .unwrap(); |
| 851 | + |
| 852 | + let a_file = system_path_to_file(&db, "/src/a.py").unwrap(); |
| 853 | + let a_file_diagnostics = super::check_types(&db, a_file); |
| 854 | + assert_diagnostic_messages( |
| 855 | + &a_file_diagnostics, |
| 856 | + &["Object of type 'NotIterable' is not iterable"], |
| 857 | + ); |
| 858 | + } |
| 859 | + |
| 860 | + #[test] |
| 861 | + fn yield_from_expression_must_be_iterable() { |
| 862 | + let mut db = setup_db(); |
| 863 | + |
| 864 | + db.write_dedented( |
| 865 | + "src/a.py", |
| 866 | + " |
| 867 | + class NotIterable: pass |
| 868 | +
|
| 869 | + class Iterator: |
| 870 | + def __next__(self) -> int: |
| 871 | + return 42 |
| 872 | +
|
| 873 | + class Iterable: |
| 874 | + def __iter__(self) -> Iterator: |
| 875 | +
|
| 876 | + def generator_function(): |
| 877 | + yield from Iterable() |
| 878 | + yield from NotIterable() |
| 879 | + ", |
| 880 | + ) |
| 881 | + .unwrap(); |
| 882 | + |
| 883 | + let a_file = system_path_to_file(&db, "/src/a.py").unwrap(); |
| 884 | + let a_file_diagnostics = super::check_types(&db, a_file); |
| 885 | + assert_diagnostic_messages( |
| 886 | + &a_file_diagnostics, |
| 887 | + &["Object of type 'NotIterable' is not iterable"], |
| 888 | + ); |
| 889 | + } |
792 | 890 | }
|
0 commit comments