Skip to content

Commit 6ffb961

Browse files
authored
red-knot: Change resolve_global_symbol to take Module as an argument (#11723)
1 parent 64165be commit 6ffb961

File tree

4 files changed

+39
-24
lines changed

4 files changed

+39
-24
lines changed

crates/red_knot/src/lint.rs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use ruff_python_parser::Parsed;
1010
use crate::cache::KeyValueCache;
1111
use crate::db::{LintDb, LintJar, QueryResult};
1212
use crate::files::FileId;
13-
use crate::module::ModuleName;
13+
use crate::module::{resolve_module, ModuleName};
1414
use crate::parse::parse;
1515
use crate::source::{source_text, Source};
1616
use crate::symbols::{
@@ -145,9 +145,7 @@ fn lint_bad_overrides(context: &SemanticLintContext) -> QueryResult<()> {
145145
// TODO we should have a special marker on the real typing module (from typeshed) so if you
146146
// have your own "typing" module in your project, we don't consider it THE typing module (and
147147
// same for other stdlib modules that our lint rules care about)
148-
let Some(typing_override) =
149-
resolve_global_symbol(context.db.upcast(), ModuleName::new("typing"), "override")?
150-
else {
148+
let Some(typing_override) = context.resolve_global_symbol("typing", "override")? else {
151149
// TODO once we bundle typeshed, this should be unreachable!()
152150
return Ok(());
153151
};
@@ -235,6 +233,18 @@ impl<'a> SemanticLintContext<'a> {
235233
pub fn extend_diagnostics(&mut self, diagnostics: impl IntoIterator<Item = String>) {
236234
self.diagnostics.get_mut().extend(diagnostics);
237235
}
236+
237+
pub fn resolve_global_symbol(
238+
&self,
239+
module: &str,
240+
symbol_name: &str,
241+
) -> QueryResult<Option<GlobalSymbolId>> {
242+
let Some(module) = resolve_module(self.db.upcast(), ModuleName::new(module))? else {
243+
return Ok(None);
244+
};
245+
246+
resolve_global_symbol(self.db.upcast(), module, symbol_name)
247+
}
238248
}
239249

240250
#[derive(Debug)]

crates/red_knot/src/symbols.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use crate::ast_ids::{NodeKey, TypedNodeKey};
1818
use crate::cache::KeyValueCache;
1919
use crate::db::{QueryResult, SemanticDb, SemanticJar};
2020
use crate::files::FileId;
21-
use crate::module::{resolve_module, ModuleName};
21+
use crate::module::{Module, ModuleName};
2222
use crate::parse::parse;
2323
use crate::Name;
2424

@@ -35,13 +35,10 @@ pub fn symbol_table(db: &dyn SemanticDb, file_id: FileId) -> QueryResult<Arc<Sym
3535
#[tracing::instrument(level = "debug", skip(db))]
3636
pub fn resolve_global_symbol(
3737
db: &dyn SemanticDb,
38-
module: ModuleName,
38+
module: Module,
3939
name: &str,
4040
) -> QueryResult<Option<GlobalSymbolId>> {
41-
let Some(typing_module) = resolve_module(db, module)? else {
42-
return Ok(None);
43-
};
44-
let typing_file = typing_module.path(db)?.file();
41+
let typing_file = module.path(db)?.file();
4542
let typing_table = symbol_table(db, typing_file)?;
4643
let Some(typing_override) = typing_table.root_symbol_id_by_name(name) else {
4744
return Ok(None);

crates/red_knot/src/types.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ impl ModuleTypeId {
392392
}
393393

394394
fn get_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult<Option<Type>> {
395-
if let Some(symbol_id) = resolve_global_symbol(db, self.name(db)?, name)? {
395+
if let Some(symbol_id) = resolve_global_symbol(db, self.module, name)? {
396396
Ok(Some(infer_symbol_public_type(db, symbol_id)?))
397397
} else {
398398
Ok(None)

crates/red_knot/src/types/infer.rs

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,11 @@ pub fn infer_definition_type(
9696
// TODO relative imports
9797
assert!(matches!(level, 0));
9898
let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports"));
99-
if let Some(remote_symbol) = resolve_global_symbol(db, module_name, &name)? {
99+
let Some(module) = resolve_module(db, module_name.clone())? else {
100+
return Ok(Type::Unknown);
101+
};
102+
103+
if let Some(remote_symbol) = resolve_global_symbol(db, module, &name)? {
100104
infer_symbol_public_type(db, remote_symbol)
101105
} else {
102106
Ok(Type::Unknown)
@@ -248,30 +252,34 @@ mod tests {
248252
Ok(TestCase { temp_dir, db, src })
249253
}
250254

251-
fn write_to_path(case: &TestCase, relpath: &str, contents: &str) -> anyhow::Result<()> {
252-
let path = case.src.path().join(relpath);
255+
fn write_to_path(case: &TestCase, relative_path: &str, contents: &str) -> anyhow::Result<()> {
256+
let path = case.src.path().join(relative_path);
253257
std::fs::write(path, contents)?;
254258
Ok(())
255259
}
256260

257-
fn get_public_type(case: &TestCase, modname: &str, varname: &str) -> anyhow::Result<Type> {
261+
fn get_public_type(
262+
case: &TestCase,
263+
module_name: &str,
264+
variable_name: &str,
265+
) -> anyhow::Result<Type> {
258266
let db = &case.db;
259-
let symbol =
260-
resolve_global_symbol(db, ModuleName::new(modname), varname)?.expect("symbol to exist");
267+
let module = resolve_module(db, ModuleName::new(module_name))?.expect("Module to exist");
268+
let symbol = resolve_global_symbol(db, module, variable_name)?.expect("symbol to exist");
261269

262270
Ok(infer_symbol_public_type(db, symbol)?)
263271
}
264272

265273
fn assert_public_type(
266274
case: &TestCase,
267-
modname: &str,
268-
varname: &str,
269-
tyname: &str,
275+
module_name: &str,
276+
variable_name: &str,
277+
type_name: &str,
270278
) -> anyhow::Result<()> {
271-
let ty = get_public_type(case, modname, varname)?;
279+
let ty = get_public_type(case, module_name, variable_name)?;
272280

273281
let jar = HasJar::<SemanticJar>::jar(&case.db)?;
274-
assert_eq!(format!("{}", ty.display(&jar.type_store)), tyname);
282+
assert_eq!(format!("{}", ty.display(&jar.type_store)), type_name);
275283
Ok(())
276284
}
277285

@@ -399,8 +407,8 @@ mod tests {
399407
.expect("module should be found")
400408
.path(db)?
401409
.file();
402-
let syms = symbol_table(db, file)?;
403-
let x_sym = syms
410+
let symbols = symbol_table(db, file)?;
411+
let x_sym = symbols
404412
.root_symbol_id_by_name("x")
405413
.expect("x symbol should be found");
406414

0 commit comments

Comments
 (0)