Skip to content

Commit 82dd5e6

Browse files
authored
[red-knot] resolve class members (#11256)
1 parent 6a1e555 commit 82dd5e6

File tree

3 files changed

+116
-27
lines changed

3 files changed

+116
-27
lines changed

crates/red_knot/src/symbols.rs

+31-14
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ pub(crate) struct Scope {
6868
name: Name,
6969
kind: ScopeKind,
7070
child_scopes: Vec<ScopeId>,
71-
// symbol IDs, hashed by symbol name
71+
/// symbol IDs, hashed by symbol name
7272
symbols_by_name: Map<SymbolId, ()>,
7373
}
7474

@@ -107,6 +107,7 @@ bitflags! {
107107
pub(crate) struct Symbol {
108108
name: Name,
109109
flags: SymbolFlags,
110+
scope_id: ScopeId,
110111
// kind: Kind,
111112
}
112113

@@ -141,7 +142,7 @@ pub(crate) enum Definition {
141142
// the small amount of information we need from the AST.
142143
Import(ImportDefinition),
143144
ImportFrom(ImportFromDefinition),
144-
ClassDef(TypedNodeKey<ast::StmtClassDef>),
145+
ClassDef(ClassDefinition),
145146
FunctionDef(TypedNodeKey<ast::StmtFunctionDef>),
146147
Assignment(TypedNodeKey<ast::StmtAssign>),
147148
AnnotatedAssignment(TypedNodeKey<ast::StmtAnnAssign>),
@@ -174,6 +175,12 @@ impl ImportFromDefinition {
174175
}
175176
}
176177

178+
#[derive(Clone, Debug)]
179+
pub(crate) struct ClassDefinition {
180+
pub(crate) node_key: TypedNodeKey<ast::StmtClassDef>,
181+
pub(crate) scope_id: ScopeId,
182+
}
183+
177184
#[derive(Debug, Clone)]
178185
pub enum Dependency {
179186
Module(ModuleName),
@@ -332,7 +339,11 @@ impl SymbolTable {
332339
*entry.key()
333340
}
334341
RawEntryMut::Vacant(entry) => {
335-
let id = self.symbols_by_id.push(Symbol { name, flags });
342+
let id = self.symbols_by_id.push(Symbol {
343+
name,
344+
flags,
345+
scope_id,
346+
});
336347
entry.insert_with_hasher(hash, id, (), |_| hash);
337348
id
338349
}
@@ -459,8 +470,8 @@ impl SymbolTableBuilder {
459470
symbol_id
460471
}
461472

462-
fn push_scope(&mut self, child_of: ScopeId, name: &str, kind: ScopeKind) -> ScopeId {
463-
let scope_id = self.table.add_child_scope(child_of, name, kind);
473+
fn push_scope(&mut self, name: &str, kind: ScopeKind) -> ScopeId {
474+
let scope_id = self.table.add_child_scope(self.cur_scope(), name, kind);
464475
self.scopes.push(scope_id);
465476
scope_id
466477
}
@@ -482,10 +493,10 @@ impl SymbolTableBuilder {
482493
&mut self,
483494
name: &str,
484495
params: &Option<Box<ast::TypeParams>>,
485-
nested: impl FnOnce(&mut Self),
486-
) {
496+
nested: impl FnOnce(&mut Self) -> ScopeId,
497+
) -> ScopeId {
487498
if let Some(type_params) = params {
488-
self.push_scope(self.cur_scope(), name, ScopeKind::Annotation);
499+
self.push_scope(name, ScopeKind::Annotation);
489500
for type_param in &type_params.type_params {
490501
let name = match type_param {
491502
ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, .. }) => name,
@@ -495,10 +506,11 @@ impl SymbolTableBuilder {
495506
self.add_or_update_symbol(name, SymbolFlags::IS_DEFINED);
496507
}
497508
}
498-
nested(self);
509+
let scope_id = nested(self);
499510
if params.is_some() {
500511
self.pop_scope();
501512
}
513+
scope_id
502514
}
503515
}
504516

@@ -525,21 +537,26 @@ impl PreorderVisitor<'_> for SymbolTableBuilder {
525537
// TODO need to capture more definition statements here
526538
match stmt {
527539
ast::Stmt::ClassDef(node) => {
528-
let def = Definition::ClassDef(TypedNodeKey::from_node(node));
529-
self.add_or_update_symbol_with_def(&node.name, def);
530-
self.with_type_params(&node.name, &node.type_params, |builder| {
531-
builder.push_scope(builder.cur_scope(), &node.name, ScopeKind::Class);
540+
let scope_id = self.with_type_params(&node.name, &node.type_params, |builder| {
541+
let scope_id = builder.push_scope(&node.name, ScopeKind::Class);
532542
ast::visitor::preorder::walk_stmt(builder, stmt);
533543
builder.pop_scope();
544+
scope_id
534545
});
546+
let def = Definition::ClassDef(ClassDefinition {
547+
node_key: TypedNodeKey::from_node(node),
548+
scope_id,
549+
});
550+
self.add_or_update_symbol_with_def(&node.name, def);
535551
}
536552
ast::Stmt::FunctionDef(node) => {
537553
let def = Definition::FunctionDef(TypedNodeKey::from_node(node));
538554
self.add_or_update_symbol_with_def(&node.name, def);
539555
self.with_type_params(&node.name, &node.type_params, |builder| {
540-
builder.push_scope(builder.cur_scope(), &node.name, ScopeKind::Function);
556+
let scope_id = builder.push_scope(&node.name, ScopeKind::Function);
541557
ast::visitor::preorder::walk_stmt(builder, stmt);
542558
builder.pop_scope();
559+
scope_id
543560
});
544561
}
545562
ast::Stmt::Import(ast::StmtImport { names, .. }) => {

crates/red_knot/src/types.rs

+42-10
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#![allow(dead_code)]
22
use crate::ast_ids::NodeKey;
3+
use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar};
34
use crate::files::FileId;
4-
use crate::symbols::SymbolId;
5+
use crate::symbols::{ScopeId, SymbolId};
56
use crate::{FxDashMap, FxIndexSet, Name};
67
use ruff_index::{newtype_index, IndexVec};
78
use rustc_hash::FxHashMap;
@@ -124,8 +125,15 @@ impl TypeStore {
124125
.add_function(name, decorators)
125126
}
126127

127-
fn add_class(&self, file_id: FileId, name: &str, bases: Vec<Type>) -> ClassTypeId {
128-
self.add_or_get_module(file_id).add_class(name, bases)
128+
fn add_class(
129+
&self,
130+
file_id: FileId,
131+
name: &str,
132+
scope_id: ScopeId,
133+
bases: Vec<Type>,
134+
) -> ClassTypeId {
135+
self.add_or_get_module(file_id)
136+
.add_class(name, scope_id, bases)
129137
}
130138

131139
fn add_union(&mut self, file_id: FileId, elems: &[Type]) -> UnionTypeId {
@@ -253,6 +261,24 @@ pub struct ClassTypeId {
253261
class_id: ModuleClassTypeId,
254262
}
255263

264+
impl ClassTypeId {
265+
fn get_own_class_member<Db>(self, db: &Db, name: &Name) -> QueryResult<Option<Type>>
266+
where
267+
Db: SemanticDb + HasJar<SemanticJar>,
268+
{
269+
// TODO: this should distinguish instance-only members (e.g. `x: int`) and not return them
270+
let ClassType { scope_id, .. } = *db.jar()?.type_store.get_class(self);
271+
let table = db.symbol_table(self.file_id)?;
272+
if let Some(symbol_id) = table.symbol_id_by_name(scope_id, name) {
273+
Ok(Some(db.infer_symbol_type(self.file_id, symbol_id)?))
274+
} else {
275+
Ok(None)
276+
}
277+
}
278+
279+
// TODO: get_own_instance_member, get_class_member, get_instance_member
280+
}
281+
256282
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
257283
pub struct UnionTypeId {
258284
file_id: FileId,
@@ -318,9 +344,10 @@ impl ModuleTypeStore {
318344
}
319345
}
320346

321-
fn add_class(&mut self, name: &str, bases: Vec<Type>) -> ClassTypeId {
347+
fn add_class(&mut self, name: &str, scope_id: ScopeId, bases: Vec<Type>) -> ClassTypeId {
322348
let class_id = self.classes.push(ClassType {
323349
name: Name::new(name),
350+
scope_id,
324351
// TODO: if no bases are given, that should imply [object]
325352
bases,
326353
});
@@ -405,7 +432,11 @@ impl std::fmt::Display for DisplayType<'_> {
405432

406433
#[derive(Debug)]
407434
pub(crate) struct ClassType {
435+
/// Name of the class at definition
408436
name: Name,
437+
/// `ScopeId` of the class body
438+
pub(crate) scope_id: ScopeId,
439+
/// Types of all class bases
409440
bases: Vec<Type>,
410441
}
411442

@@ -496,6 +527,7 @@ impl IntersectionType {
496527
#[cfg(test)]
497528
mod tests {
498529
use crate::files::Files;
530+
use crate::symbols::SymbolTable;
499531
use crate::types::{Type, TypeStore};
500532
use crate::FxIndexSet;
501533
use std::path::Path;
@@ -505,7 +537,7 @@ mod tests {
505537
let store = TypeStore::default();
506538
let files = Files::default();
507539
let file_id = files.intern(Path::new("/foo"));
508-
let id = store.add_class(file_id, "C", Vec::new());
540+
let id = store.add_class(file_id, "C", SymbolTable::root_scope_id(), Vec::new());
509541
assert_eq!(store.get_class(id).name(), "C");
510542
let inst = Type::Instance(id);
511543
assert_eq!(format!("{}", inst.display(&store)), "C");
@@ -528,8 +560,8 @@ mod tests {
528560
let mut store = TypeStore::default();
529561
let files = Files::default();
530562
let file_id = files.intern(Path::new("/foo"));
531-
let c1 = store.add_class(file_id, "C1", Vec::new());
532-
let c2 = store.add_class(file_id, "C2", Vec::new());
563+
let c1 = store.add_class(file_id, "C1", SymbolTable::root_scope_id(), Vec::new());
564+
let c2 = store.add_class(file_id, "C2", SymbolTable::root_scope_id(), Vec::new());
533565
let elems = vec![Type::Instance(c1), Type::Instance(c2)];
534566
let id = store.add_union(file_id, &elems);
535567
assert_eq!(
@@ -545,9 +577,9 @@ mod tests {
545577
let mut store = TypeStore::default();
546578
let files = Files::default();
547579
let file_id = files.intern(Path::new("/foo"));
548-
let c1 = store.add_class(file_id, "C1", Vec::new());
549-
let c2 = store.add_class(file_id, "C2", Vec::new());
550-
let c3 = store.add_class(file_id, "C3", Vec::new());
580+
let c1 = store.add_class(file_id, "C1", SymbolTable::root_scope_id(), Vec::new());
581+
let c2 = store.add_class(file_id, "C2", SymbolTable::root_scope_id(), Vec::new());
582+
let c3 = store.add_class(file_id, "C3", SymbolTable::root_scope_id(), Vec::new());
551583
let pos = vec![Type::Instance(c1), Type::Instance(c2)];
552584
let neg = vec![Type::Instance(c3)];
553585
let id = store.add_intersection(file_id, &pos, &neg);

crates/red_knot/src/types/infer.rs

+43-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use ruff_python_ast::AstNode;
44

55
use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar};
66
use crate::module::ModuleName;
7-
use crate::symbols::{Definition, ImportFromDefinition, SymbolId};
7+
use crate::symbols::{ClassDefinition, Definition, ImportFromDefinition, SymbolId};
88
use crate::types::Type;
99
use crate::FileId;
1010
use ruff_python_ast as ast;
@@ -51,7 +51,7 @@ where
5151
Type::Unknown
5252
}
5353
}
54-
Definition::ClassDef(node_key) => {
54+
Definition::ClassDef(ClassDefinition { node_key, scope_id }) => {
5555
if let Some(ty) = type_store.get_cached_node_type(file_id, node_key.erased()) {
5656
ty
5757
} else {
@@ -65,7 +65,8 @@ where
6565
bases.push(infer_expr_type(db, file_id, base)?);
6666
}
6767

68-
let ty = Type::Class(type_store.add_class(file_id, &node.name.id, bases));
68+
let ty =
69+
Type::Class(type_store.add_class(file_id, &node.name.id, *scope_id, bases));
6970
type_store.cache_node_type(file_id, *node_key.erased(), ty);
7071
ty
7172
}
@@ -133,6 +134,7 @@ mod tests {
133134
use crate::db::{HasJar, SemanticDb, SemanticJar};
134135
use crate::module::{ModuleName, ModuleSearchPath, ModuleSearchPathKind};
135136
use crate::types::Type;
137+
use crate::Name;
136138

137139
// TODO with virtual filesystem we shouldn't have to write files to disk for these
138140
// tests
@@ -222,4 +224,42 @@ mod tests {
222224

223225
Ok(())
224226
}
227+
228+
#[test]
229+
fn resolve_method() -> anyhow::Result<()> {
230+
let case = create_test()?;
231+
let db = &case.db;
232+
233+
let path = case.src.path().join("mod.py");
234+
std::fs::write(path, "class C:\n def f(self): pass")?;
235+
let file = db
236+
.resolve_module(ModuleName::new("mod"))?
237+
.expect("module should be found")
238+
.path(db)?
239+
.file();
240+
let syms = db.symbol_table(file)?;
241+
let sym = syms
242+
.root_symbol_id_by_name("C")
243+
.expect("C symbol should be found");
244+
245+
let ty = db.infer_symbol_type(file, sym)?;
246+
247+
let Type::Class(class_id) = ty else {
248+
panic!("C is not a Class");
249+
};
250+
251+
let member_ty = class_id
252+
.get_own_class_member(db, &Name::new("f"))
253+
.expect("C.f to resolve");
254+
255+
let Some(Type::Function(func_id)) = member_ty else {
256+
panic!("C.f is not a Function");
257+
};
258+
259+
let jar = HasJar::<SemanticJar>::jar(db)?;
260+
let function = jar.type_store.get_function(func_id);
261+
assert_eq!(function.name(), "f");
262+
263+
Ok(())
264+
}
225265
}

0 commit comments

Comments
 (0)