Skip to content

Commit d056d09

Browse files
authored
[red-knot] add if-statement support to FlowGraph (#11673)
## Summary Add if-statement support to FlowGraph. This introduces branches and joins in the graph for the first time. ## Test Plan Added tests.
1 parent 1645be0 commit d056d09

File tree

2 files changed

+240
-40
lines changed

2 files changed

+240
-40
lines changed

crates/red_knot/src/symbols.rs

Lines changed: 153 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ pub(crate) enum Definition {
184184
FunctionDef(TypedNodeKey<ast::StmtFunctionDef>),
185185
Assignment(TypedNodeKey<ast::StmtAssign>),
186186
AnnotatedAssignment(TypedNodeKey<ast::StmtAnnAssign>),
187+
None,
187188
// TODO with statements, except handlers, function args...
188189
}
189190

@@ -288,8 +289,8 @@ impl SymbolTable {
288289
let flow_node_id = self.flow_graph.ast_to_flow[&node_key];
289290
ReachableDefinitionsIterator {
290291
table: self,
291-
flow_node_id,
292292
symbol_id,
293+
pending: vec![flow_node_id],
293294
}
294295
}
295296

@@ -545,25 +546,30 @@ where
545546
#[derive(Debug)]
546547
pub(crate) struct ReachableDefinitionsIterator<'a> {
547548
table: &'a SymbolTable,
548-
flow_node_id: FlowNodeId,
549549
symbol_id: SymbolId,
550+
pending: Vec<FlowNodeId>,
550551
}
551552

552553
impl<'a> Iterator for ReachableDefinitionsIterator<'a> {
553554
type Item = Definition;
554555

555556
fn next(&mut self) -> Option<Self::Item> {
556557
loop {
557-
match &self.table.flow_graph.flow_nodes_by_id[self.flow_node_id] {
558-
FlowNode::Start => return None,
558+
let flow_node_id = self.pending.pop()?;
559+
match &self.table.flow_graph.flow_nodes_by_id[flow_node_id] {
560+
FlowNode::Start => return Some(Definition::None),
559561
FlowNode::Definition(def_node) => {
560562
if def_node.symbol_id == self.symbol_id {
561-
// we found a definition; previous definitions along this path are not
562-
// reachable
563-
self.flow_node_id = FlowGraph::start();
564563
return Some(def_node.definition.clone());
565564
}
566-
self.flow_node_id = def_node.predecessor;
565+
self.pending.push(def_node.predecessor);
566+
}
567+
FlowNode::Branch(branch_node) => {
568+
self.pending.push(branch_node.predecessor);
569+
}
570+
FlowNode::Phi(phi_node) => {
571+
self.pending.push(phi_node.first_predecessor);
572+
self.pending.push(phi_node.second_predecessor);
567573
}
568574
}
569575
}
@@ -579,15 +585,31 @@ struct FlowNodeId;
579585
enum FlowNode {
580586
Start,
581587
Definition(DefinitionFlowNode),
588+
Branch(BranchFlowNode),
589+
Phi(PhiFlowNode),
582590
}
583591

592+
/// A Definition node represents a point in control flow where a symbol is defined
584593
#[derive(Debug)]
585594
struct DefinitionFlowNode {
586595
symbol_id: SymbolId,
587596
definition: Definition,
588597
predecessor: FlowNodeId,
589598
}
590599

600+
/// A Branch node represents a branch in control flow
601+
#[derive(Debug)]
602+
struct BranchFlowNode {
603+
predecessor: FlowNodeId,
604+
}
605+
606+
/// A Phi node represents a join point where control flow paths come together
607+
#[derive(Debug)]
608+
struct PhiFlowNode {
609+
first_predecessor: FlowNodeId,
610+
second_predecessor: FlowNodeId,
611+
}
612+
591613
#[derive(Debug, Default)]
592614
struct FlowGraph {
593615
flow_nodes_by_id: IndexVec<FlowNodeId, FlowNode>,
@@ -636,6 +658,10 @@ impl SymbolTableBuilder {
636658
.add_or_update_symbol(self.cur_scope(), identifier, flags)
637659
}
638660

661+
fn new_flow_node(&mut self, node: FlowNode) -> FlowNodeId {
662+
self.table.flow_graph.flow_nodes_by_id.push(node)
663+
}
664+
639665
fn add_or_update_symbol_with_def(
640666
&mut self,
641667
identifier: &str,
@@ -647,15 +673,11 @@ impl SymbolTableBuilder {
647673
.entry(symbol_id)
648674
.or_default()
649675
.push(definition.clone());
650-
let new_flow_node_id = self
651-
.table
652-
.flow_graph
653-
.flow_nodes_by_id
654-
.push(FlowNode::Definition(DefinitionFlowNode {
655-
definition,
656-
symbol_id,
657-
predecessor: self.current_flow_node(),
658-
}));
676+
let new_flow_node_id = self.new_flow_node(FlowNode::Definition(DefinitionFlowNode {
677+
definition,
678+
symbol_id,
679+
predecessor: self.current_flow_node(),
680+
}));
659681
self.set_current_flow_node(new_flow_node_id);
660682
symbol_id
661683
}
@@ -871,13 +893,127 @@ impl PreorderVisitor<'_> for SymbolTableBuilder {
871893
ast::visitor::preorder::walk_stmt(self, stmt);
872894
self.current_definition = None;
873895
}
896+
ast::Stmt::If(node) => {
897+
// we visit the if "test" condition first regardless
898+
self.visit_expr(&node.test);
899+
900+
// create branch node: does the if test pass or not?
901+
let if_branch = self.new_flow_node(FlowNode::Branch(BranchFlowNode {
902+
predecessor: self.current_flow_node(),
903+
}));
904+
905+
// visit the body of the `if` clause
906+
self.set_current_flow_node(if_branch);
907+
self.visit_body(&node.body);
908+
909+
// Flow node for the last if/elif condition branch; represents the "no branch
910+
// taken yet" possibility (where "taking a branch" means that the condition in an
911+
// if or elif evaluated to true and control flow went into that clause).
912+
let mut prior_branch = if_branch;
913+
914+
// Flow node for the state after the prior if/elif/else clause; represents "we have
915+
// taken one of the branches up to this point." Initially set to the post-if-clause
916+
// state, later will be set to the phi node joining that possible path with the
917+
// possibility that we took a later if/elif/else clause instead.
918+
let mut post_prior_clause = self.current_flow_node();
919+
920+
// Flag to mark if the final clause is an "else" -- if so, that means the "match no
921+
// clauses" path is not possible, we have to go through one of the clauses.
922+
let mut last_branch_is_else = false;
923+
924+
for clause in &node.elif_else_clauses {
925+
if clause.test.is_some() {
926+
// This is an elif clause. Create a new branch node. Its predecessor is the
927+
// previous branch node, because we can only take one branch in an entire
928+
// if/elif/else chain, so if we take this branch, it can only be because we
929+
// didn't take the previous one.
930+
prior_branch = self.new_flow_node(FlowNode::Branch(BranchFlowNode {
931+
predecessor: prior_branch,
932+
}));
933+
self.set_current_flow_node(prior_branch);
934+
} else {
935+
// This is an else clause. No need to create a branch node; there's no
936+
// branch here, if we haven't taken any previous branch, we definitely go
937+
// into the "else" clause.
938+
self.set_current_flow_node(prior_branch);
939+
last_branch_is_else = true;
940+
}
941+
self.visit_elif_else_clause(clause);
942+
// Update `post_prior_clause` to a new phi node joining the possibility that we
943+
// took any of the previous branches with the possibility that we took the one
944+
// just visited.
945+
post_prior_clause = self.new_flow_node(FlowNode::Phi(PhiFlowNode {
946+
first_predecessor: self.current_flow_node(),
947+
second_predecessor: post_prior_clause,
948+
}));
949+
}
950+
951+
if !last_branch_is_else {
952+
// Final branch was not an "else", which means it's possible we took zero
953+
// branches in the entire if/elif chain, so we need one more phi node to join
954+
// the "no branches taken" possibility.
955+
post_prior_clause = self.new_flow_node(FlowNode::Phi(PhiFlowNode {
956+
first_predecessor: post_prior_clause,
957+
second_predecessor: prior_branch,
958+
}));
959+
}
960+
961+
// Onward, with current flow node set to our final Phi node.
962+
self.set_current_flow_node(post_prior_clause);
963+
}
874964
_ => {
875965
ast::visitor::preorder::walk_stmt(self, stmt);
876966
}
877967
}
878968
}
879969
}
880970

971+
impl std::fmt::Display for FlowGraph {
972+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
973+
writeln!(f, "flowchart TD")?;
974+
for (id, node) in self.flow_nodes_by_id.iter_enumerated() {
975+
write!(f, " id{}", id.as_u32())?;
976+
match node {
977+
FlowNode::Start => writeln!(f, r"[\Start/]")?,
978+
FlowNode::Definition(def_node) => {
979+
writeln!(f, r"(Define symbol {})", def_node.symbol_id.as_u32())?;
980+
writeln!(
981+
f,
982+
r" id{}-->id{}",
983+
def_node.predecessor.as_u32(),
984+
id.as_u32()
985+
)?;
986+
}
987+
FlowNode::Branch(branch_node) => {
988+
writeln!(f, r"{{Branch}}")?;
989+
writeln!(
990+
f,
991+
r" id{}-->id{}",
992+
branch_node.predecessor.as_u32(),
993+
id.as_u32()
994+
)?;
995+
}
996+
FlowNode::Phi(phi_node) => {
997+
writeln!(f, r"((Phi))")?;
998+
writeln!(
999+
f,
1000+
r" id{}-->id{}",
1001+
phi_node.second_predecessor.as_u32(),
1002+
id.as_u32()
1003+
)?;
1004+
writeln!(
1005+
f,
1006+
r" id{}-->id{}",
1007+
phi_node.first_predecessor.as_u32(),
1008+
id.as_u32()
1009+
)?;
1010+
}
1011+
}
1012+
}
1013+
Ok(())
1014+
}
1015+
}
1016+
8811017
#[derive(Debug, Default)]
8821018
pub struct SymbolTablesStorage(KeyValueCache<FileId, Arc<SymbolTable>>);
8831019

crates/red_knot/src/types/infer.rs

Lines changed: 87 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ pub fn infer_definition_type(
7979
let file_id = symbol.file_id;
8080

8181
match definition {
82+
Definition::None => Ok(Type::Unbound),
8283
Definition::Import(ImportDefinition {
8384
module: module_name,
8485
}) => {
@@ -223,7 +224,7 @@ mod tests {
223224
use crate::module::{
224225
resolve_module, set_module_search_paths, ModuleName, ModuleSearchPath, ModuleSearchPathKind,
225226
};
226-
use crate::symbols::{resolve_global_symbol, symbol_table, GlobalSymbolId};
227+
use crate::symbols::resolve_global_symbol;
227228
use crate::types::{infer_symbol_public_type, Type};
228229
use crate::Name;
229230

@@ -399,30 +400,93 @@ mod tests {
399400
#[test]
400401
fn resolve_visible_def() -> anyhow::Result<()> {
401402
let case = create_test()?;
402-
let db = &case.db;
403403

404-
let path = case.src.path().join("a.py");
405-
std::fs::write(path, "y = 1; y = 2; x = y")?;
406-
let file = resolve_module(db, ModuleName::new("a"))?
407-
.expect("module should be found")
408-
.path(db)?
409-
.file();
410-
let symbols = symbol_table(db, file)?;
411-
let x_sym = symbols
412-
.root_symbol_id_by_name("x")
413-
.expect("x symbol should be found");
414-
415-
let ty = infer_symbol_public_type(
416-
db,
417-
GlobalSymbolId {
418-
file_id: file,
419-
symbol_id: x_sym,
420-
},
404+
write_to_path(&case, "a.py", "y = 1; y = 2; x = y")?;
405+
406+
assert_public_type(&case, "a", "x", "Literal[2]")
407+
}
408+
409+
#[test]
410+
fn join_paths() -> anyhow::Result<()> {
411+
let case = create_test()?;
412+
413+
write_to_path(
414+
&case,
415+
"a.py",
416+
"
417+
y = 1
418+
y = 2
419+
if flag:
420+
y = 3
421+
x = y
422+
",
421423
)?;
422424

423-
let jar = HasJar::<SemanticJar>::jar(db)?;
424-
assert!(matches!(ty, Type::IntLiteral(_)));
425-
assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[2]");
426-
Ok(())
425+
assert_public_type(&case, "a", "x", "(Literal[2] | Literal[3])")
426+
}
427+
428+
#[test]
429+
fn maybe_unbound() -> anyhow::Result<()> {
430+
let case = create_test()?;
431+
432+
write_to_path(
433+
&case,
434+
"a.py",
435+
"
436+
if flag:
437+
y = 1
438+
x = y
439+
",
440+
)?;
441+
442+
assert_public_type(&case, "a", "x", "(Unbound | Literal[1])")
443+
}
444+
445+
#[test]
446+
fn if_elif_else() -> anyhow::Result<()> {
447+
let case = create_test()?;
448+
449+
write_to_path(
450+
&case,
451+
"a.py",
452+
"
453+
y = 1
454+
y = 2
455+
if flag:
456+
y = 3
457+
elif flag2:
458+
y = 4
459+
else:
460+
r = y
461+
y = 5
462+
s = y
463+
x = y
464+
",
465+
)?;
466+
467+
assert_public_type(&case, "a", "x", "(Literal[3] | Literal[4] | Literal[5])")?;
468+
assert_public_type(&case, "a", "r", "Literal[2]")?;
469+
assert_public_type(&case, "a", "s", "Literal[5]")
470+
}
471+
472+
#[test]
473+
fn if_elif() -> anyhow::Result<()> {
474+
let case = create_test()?;
475+
476+
write_to_path(
477+
&case,
478+
"a.py",
479+
"
480+
y = 1
481+
y = 2
482+
if flag:
483+
y = 3
484+
elif flag2:
485+
y = 4
486+
x = y
487+
",
488+
)?;
489+
490+
assert_public_type(&case, "a", "x", "(Literal[2] | Literal[3] | Literal[4])")
427491
}
428492
}

0 commit comments

Comments
 (0)