Skip to content

Commit a87b27c

Browse files
AlexWaygoodcarljm
andauthored
[red-knot] Add support for relative imports (#12910)
Co-authored-by: Carl Meyer <[email protected]>
1 parent 9b73532 commit a87b27c

File tree

5 files changed

+230
-15
lines changed

5 files changed

+230
-15
lines changed

crates/red_knot_python_semantic/src/module_name.rs

+18
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,24 @@ impl ModuleName {
168168
};
169169
Some(Self(name))
170170
}
171+
172+
/// Extend `self` with the components of `other`
173+
///
174+
/// # Examples
175+
///
176+
/// ```
177+
/// use red_knot_python_semantic::ModuleName;
178+
///
179+
/// let mut module_name = ModuleName::new_static("foo").unwrap();
180+
/// module_name.extend(&ModuleName::new_static("bar").unwrap());
181+
/// assert_eq!(&module_name, "foo.bar");
182+
/// module_name.extend(&ModuleName::new_static("baz.eggs.ham").unwrap());
183+
/// assert_eq!(&module_name, "foo.bar.baz.eggs.ham");
184+
/// ```
185+
pub fn extend(&mut self, other: &ModuleName) {
186+
self.0.push('.');
187+
self.0.push_str(other);
188+
}
171189
}
172190

173191
impl Deref for ModuleName {

crates/red_knot_python_semantic/src/module_resolver/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::iter::FusedIterator;
22

33
pub(crate) use module::Module;
44
pub use resolver::resolve_module;
5-
pub(crate) use resolver::SearchPaths;
5+
pub(crate) use resolver::{file_to_module, SearchPaths};
66
use ruff_db::system::SystemPath;
77
pub use typeshed::vendored_typeshed_stubs;
88

crates/red_knot_python_semantic/src/module_resolver/module.rs

+6
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,9 @@ pub enum ModuleKind {
7777
/// A python package (`foo/__init__.py` or `foo/__init__.pyi`)
7878
Package,
7979
}
80+
81+
impl ModuleKind {
82+
pub const fn is_package(self) -> bool {
83+
matches!(self, ModuleKind::Package)
84+
}
85+
}

crates/red_knot_python_semantic/src/types/infer.rs

+203-12
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
//!
2121
//! Inferring types at any of the three region granularities returns a [`TypeInference`], which
2222
//! holds types for every [`Definition`] and expression within the inferred region.
23+
use std::num::NonZeroU32;
24+
2325
use rustc_hash::FxHashMap;
2426
use salsa;
2527
use salsa::plumbing::AsId;
@@ -31,7 +33,7 @@ use ruff_python_ast::{ExprContext, TypeParams};
3133

3234
use crate::builtins::builtins_scope;
3335
use crate::module_name::ModuleName;
34-
use crate::module_resolver::resolve_module;
36+
use crate::module_resolver::{file_to_module, resolve_module};
3537
use crate::semantic_index::ast_ids::{HasScopedAstId, HasScopedUseId, ScopedExpressionId};
3638
use crate::semantic_index::definition::{Definition, DefinitionKind, DefinitionNodeKey};
3739
use crate::semantic_index::expression::Expression;
@@ -822,7 +824,7 @@ impl<'db> TypeInferenceBuilder<'db> {
822824
asname: _,
823825
} = alias;
824826

825-
let module_ty = self.module_ty_from_name(name);
827+
let module_ty = self.module_ty_from_name(ModuleName::new(name));
826828
self.types.definitions.insert(definition, module_ty);
827829
}
828830

@@ -860,20 +862,68 @@ impl<'db> TypeInferenceBuilder<'db> {
860862
self.infer_optional_expression(cause.as_deref());
861863
}
862864

865+
/// Given a `from .foo import bar` relative import, resolve the relative module
866+
/// we're importing `bar` from into an absolute [`ModuleName`]
867+
/// using the name of the module we're currently analyzing.
868+
///
869+
/// - `level` is the number of dots at the beginning of the relative module name:
870+
/// - `from .foo.bar import baz` => `level == 1`
871+
/// - `from ...foo.bar import baz` => `level == 3`
872+
/// - `tail` is the relative module name stripped of all leading dots:
873+
/// - `from .foo import bar` => `tail == "foo"`
874+
/// - `from ..foo.bar import baz` => `tail == "foo.bar"`
875+
fn relative_module_name(&self, tail: Option<&str>, level: NonZeroU32) -> Option<ModuleName> {
876+
let Some(module) = file_to_module(self.db, self.file) else {
877+
tracing::debug!("Failed to resolve file {:?} to a module", self.file);
878+
return None;
879+
};
880+
let mut level = level.get();
881+
if module.kind().is_package() {
882+
level -= 1;
883+
}
884+
let mut module_name = module.name().to_owned();
885+
for _ in 0..level {
886+
module_name = module_name.parent()?;
887+
}
888+
if let Some(tail) = tail {
889+
if let Some(valid_tail) = ModuleName::new(tail) {
890+
module_name.extend(&valid_tail);
891+
} else {
892+
tracing::debug!("Failed to resolve relative import due to invalid syntax");
893+
return None;
894+
}
895+
}
896+
Some(module_name)
897+
}
898+
863899
fn infer_import_from_definition(
864900
&mut self,
865901
import_from: &ast::StmtImportFrom,
866902
alias: &ast::Alias,
867903
definition: Definition<'db>,
868904
) {
869-
let ast::StmtImportFrom { module, .. } = import_from;
870-
let module_ty = if let Some(module) = module {
871-
self.module_ty_from_name(module)
905+
// TODO:
906+
// - Absolute `*` imports (`from collections import *`)
907+
// - Relative `*` imports (`from ...foo import *`)
908+
// - Submodule imports (`from collections import abc`,
909+
// where `abc` is a submodule of the `collections` package)
910+
//
911+
// For the last item, see the currently skipped tests
912+
// `follow_relative_import_bare_to_module()` and
913+
// `follow_nonexistent_import_bare_to_module()`.
914+
let ast::StmtImportFrom { module, level, .. } = import_from;
915+
tracing::trace!("Resolving imported object {alias:?} from statement {import_from:?}");
916+
let module_name = if let Some(level) = NonZeroU32::new(*level) {
917+
self.relative_module_name(module.as_deref(), level)
872918
} else {
873-
// TODO support relative imports
874-
Type::Unknown
919+
let module_name = module
920+
.as_ref()
921+
.expect("Non-relative import should always have a non-None `module`!");
922+
ModuleName::new(module_name)
875923
};
876924

925+
let module_ty = self.module_ty_from_name(module_name);
926+
877927
let ast::Alias {
878928
range: _,
879929
name,
@@ -896,11 +946,10 @@ impl<'db> TypeInferenceBuilder<'db> {
896946
}
897947
}
898948

899-
fn module_ty_from_name(&self, name: &ast::Identifier) -> Type<'db> {
900-
let module = ModuleName::new(&name.id).and_then(|name| resolve_module(self.db, name));
901-
module
902-
.map(|module| Type::Module(module.file()))
903-
.unwrap_or(Type::Unbound)
949+
fn module_ty_from_name(&self, module_name: Option<ModuleName>) -> Type<'db> {
950+
module_name
951+
.and_then(|module_name| resolve_module(self.db, module_name))
952+
.map_or(Type::Unbound, |module| Type::Module(module.file()))
904953
}
905954

906955
fn infer_decorator(&mut self, decorator: &ast::Decorator) -> Type<'db> {
@@ -1710,6 +1759,148 @@ mod tests {
17101759
Ok(())
17111760
}
17121761

1762+
#[test]
1763+
fn follow_relative_import_simple() -> anyhow::Result<()> {
1764+
let mut db = setup_db();
1765+
1766+
db.write_files([
1767+
("src/package/__init__.py", ""),
1768+
("src/package/foo.py", "X = 42"),
1769+
("src/package/bar.py", "from .foo import X"),
1770+
])?;
1771+
1772+
assert_public_ty(&db, "src/package/bar.py", "X", "Literal[42]");
1773+
1774+
Ok(())
1775+
}
1776+
1777+
#[test]
1778+
fn follow_nonexistent_relative_import_simple() -> anyhow::Result<()> {
1779+
let mut db = setup_db();
1780+
1781+
db.write_files([
1782+
("src/package/__init__.py", ""),
1783+
("src/package/bar.py", "from .foo import X"),
1784+
])?;
1785+
1786+
assert_public_ty(&db, "src/package/bar.py", "X", "Unbound");
1787+
1788+
Ok(())
1789+
}
1790+
1791+
#[test]
1792+
fn follow_relative_import_dotted() -> anyhow::Result<()> {
1793+
let mut db = setup_db();
1794+
1795+
db.write_files([
1796+
("src/package/__init__.py", ""),
1797+
("src/package/foo/bar/baz.py", "X = 42"),
1798+
("src/package/bar.py", "from .foo.bar.baz import X"),
1799+
])?;
1800+
1801+
assert_public_ty(&db, "src/package/bar.py", "X", "Literal[42]");
1802+
1803+
Ok(())
1804+
}
1805+
1806+
#[test]
1807+
fn follow_relative_import_bare_to_package() -> anyhow::Result<()> {
1808+
let mut db = setup_db();
1809+
1810+
db.write_files([
1811+
("src/package/__init__.py", "X = 42"),
1812+
("src/package/bar.py", "from . import X"),
1813+
])?;
1814+
1815+
assert_public_ty(&db, "src/package/bar.py", "X", "Literal[42]");
1816+
1817+
Ok(())
1818+
}
1819+
1820+
#[test]
1821+
fn follow_nonexistent_relative_import_bare_to_package() -> anyhow::Result<()> {
1822+
let mut db = setup_db();
1823+
db.write_files([("src/package/bar.py", "from . import X")])?;
1824+
assert_public_ty(&db, "src/package/bar.py", "X", "Unbound");
1825+
Ok(())
1826+
}
1827+
1828+
#[ignore = "TODO: Submodule imports possibly not supported right now?"]
1829+
#[test]
1830+
fn follow_relative_import_bare_to_module() -> anyhow::Result<()> {
1831+
let mut db = setup_db();
1832+
1833+
db.write_files([
1834+
("src/package/__init__.py", ""),
1835+
("src/package/foo.py", "X = 42"),
1836+
("src/package/bar.py", "from . import foo; y = foo.X"),
1837+
])?;
1838+
1839+
assert_public_ty(&db, "src/package/bar.py", "y", "Literal[42]");
1840+
1841+
Ok(())
1842+
}
1843+
1844+
#[ignore = "TODO: Submodule imports possibly not supported right now?"]
1845+
#[test]
1846+
fn follow_nonexistent_import_bare_to_module() -> anyhow::Result<()> {
1847+
let mut db = setup_db();
1848+
1849+
db.write_files([
1850+
("src/package/__init__.py", ""),
1851+
("src/package/bar.py", "from . import foo"),
1852+
])?;
1853+
1854+
assert_public_ty(&db, "src/package/bar.py", "foo", "Unbound");
1855+
1856+
Ok(())
1857+
}
1858+
1859+
#[test]
1860+
fn follow_relative_import_from_dunder_init() -> anyhow::Result<()> {
1861+
let mut db = setup_db();
1862+
1863+
db.write_files([
1864+
("src/package/__init__.py", "from .foo import X"),
1865+
("src/package/foo.py", "X = 42"),
1866+
])?;
1867+
1868+
assert_public_ty(&db, "src/package/__init__.py", "X", "Literal[42]");
1869+
1870+
Ok(())
1871+
}
1872+
1873+
#[test]
1874+
fn follow_nonexistent_relative_import_from_dunder_init() -> anyhow::Result<()> {
1875+
let mut db = setup_db();
1876+
db.write_files([("src/package/__init__.py", "from .foo import X")])?;
1877+
assert_public_ty(&db, "src/package/__init__.py", "X", "Unbound");
1878+
Ok(())
1879+
}
1880+
1881+
#[test]
1882+
fn follow_very_relative_import() -> anyhow::Result<()> {
1883+
let mut db = setup_db();
1884+
1885+
db.write_files([
1886+
("src/package/__init__.py", ""),
1887+
("src/package/foo.py", "X = 42"),
1888+
(
1889+
"src/package/subpackage/subsubpackage/bar.py",
1890+
"from ...foo import X",
1891+
),
1892+
])?;
1893+
1894+
assert_public_ty(
1895+
&db,
1896+
"src/package/subpackage/subsubpackage/bar.py",
1897+
"X",
1898+
"Literal[42]",
1899+
);
1900+
1901+
Ok(())
1902+
}
1903+
17131904
#[test]
17141905
fn resolve_base_class_by_name() -> anyhow::Result<()> {
17151906
let mut db = setup_db();

crates/ruff_benchmark/benches/red_knot.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ fn benchmark_incremental(criterion: &mut Criterion) {
8989
let Case { db, parser, .. } = case;
9090
let result = db.check_file(*parser).unwrap();
9191

92-
assert_eq!(result.len(), 111);
92+
assert_eq!(result.len(), 29);
9393
},
9494
BatchSize::SmallInput,
9595
);
@@ -104,7 +104,7 @@ fn benchmark_cold(criterion: &mut Criterion) {
104104
let Case { db, parser, .. } = case;
105105
let result = db.check_file(*parser).unwrap();
106106

107-
assert_eq!(result.len(), 111);
107+
assert_eq!(result.len(), 29);
108108
},
109109
BatchSize::SmallInput,
110110
);

0 commit comments

Comments
 (0)