Skip to content

Commit c274971

Browse files
msullivanJukkaL
andauthored
[mypyc] Support yields while values are live (python#16305)
Also support await while temporary values are live. --------- Co-authored-by: Jukka Lehtosalo <[email protected]>
1 parent 057f8ad commit c274971

File tree

11 files changed

+392
-19
lines changed

11 files changed

+392
-19
lines changed

mypyc/analysis/dataflow.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Cast,
1818
ComparisonOp,
1919
ControlOp,
20+
DecRef,
2021
Extend,
2122
Float,
2223
FloatComparisonOp,
@@ -25,6 +26,7 @@
2526
GetAttr,
2627
GetElementPtr,
2728
Goto,
29+
IncRef,
2830
InitStatic,
2931
Integer,
3032
IntOp,
@@ -77,12 +79,11 @@ def __str__(self) -> str:
7779
return f"exits: {exits}\nsucc: {self.succ}\npred: {self.pred}"
7880

7981

80-
def get_cfg(blocks: list[BasicBlock]) -> CFG:
82+
def get_cfg(blocks: list[BasicBlock], *, use_yields: bool = False) -> CFG:
8183
"""Calculate basic block control-flow graph.
8284
83-
The result is a dictionary like this:
84-
85-
basic block index -> (successors blocks, predecesssor blocks)
85+
If use_yields is set, then we treat returns inserted by yields as gotos
86+
instead of exits.
8687
"""
8788
succ_map = {}
8889
pred_map: dict[BasicBlock, list[BasicBlock]] = {}
@@ -92,7 +93,10 @@ def get_cfg(blocks: list[BasicBlock]) -> CFG:
9293
isinstance(op, ControlOp) for op in block.ops[:-1]
9394
), "Control-flow ops must be at the end of blocks"
9495

95-
succ = list(block.terminator.targets())
96+
if use_yields and isinstance(block.terminator, Return) and block.terminator.yield_target:
97+
succ = [block.terminator.yield_target]
98+
else:
99+
succ = list(block.terminator.targets())
96100
if not succ:
97101
exits.add(block)
98102

@@ -474,6 +478,12 @@ def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]:
474478
def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]:
475479
return non_trivial_sources(op), set()
476480

481+
def visit_inc_ref(self, op: IncRef) -> GenAndKill[Value]:
482+
return set(), set()
483+
484+
def visit_dec_ref(self, op: DecRef) -> GenAndKill[Value]:
485+
return set(), set()
486+
477487

478488
def analyze_live_regs(blocks: list[BasicBlock], cfg: CFG) -> AnalysisResult[Value]:
479489
"""Calculate live registers at each CFG location.

mypyc/codegen/emitmodule.py

+11
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from mypyc.transform.flag_elimination import do_flag_elimination
6262
from mypyc.transform.lower import lower_ir
6363
from mypyc.transform.refcount import insert_ref_count_opcodes
64+
from mypyc.transform.spill import insert_spills
6465
from mypyc.transform.uninit import insert_uninit_checks
6566

6667
# All of the modules being compiled are divided into "groups". A group
@@ -228,6 +229,12 @@ def compile_scc_to_ir(
228229
if errors.num_errors > 0:
229230
return modules
230231

232+
env_user_functions = {}
233+
for module in modules.values():
234+
for cls in module.classes:
235+
if cls.env_user_function:
236+
env_user_functions[cls.env_user_function] = cls
237+
231238
for module in modules.values():
232239
for fn in module.functions:
233240
# Insert uninit checks.
@@ -236,6 +243,10 @@ def compile_scc_to_ir(
236243
insert_exception_handling(fn)
237244
# Insert refcount handling.
238245
insert_ref_count_opcodes(fn)
246+
247+
if fn in env_user_functions:
248+
insert_spills(fn, env_user_functions[fn])
249+
239250
# Switch to lower abstraction level IR.
240251
lower_ir(fn, compiler_options)
241252
# Perform optimizations.

mypyc/ir/class_ir.py

+7
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,9 @@ def __init__(
196196
# value of an attribute is the same as the error value.
197197
self.bitmap_attrs: list[str] = []
198198

199+
# If this is a generator environment class, what is the actual method for it
200+
self.env_user_function: FuncIR | None = None
201+
199202
def __repr__(self) -> str:
200203
return (
201204
"ClassIR("
@@ -394,6 +397,7 @@ def serialize(self) -> JsonDict:
394397
"_always_initialized_attrs": sorted(self._always_initialized_attrs),
395398
"_sometimes_initialized_attrs": sorted(self._sometimes_initialized_attrs),
396399
"init_self_leak": self.init_self_leak,
400+
"env_user_function": self.env_user_function.id if self.env_user_function else None,
397401
}
398402

399403
@classmethod
@@ -446,6 +450,9 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ClassIR:
446450
ir._always_initialized_attrs = set(data["_always_initialized_attrs"])
447451
ir._sometimes_initialized_attrs = set(data["_sometimes_initialized_attrs"])
448452
ir.init_self_leak = data["init_self_leak"]
453+
ir.env_user_function = (
454+
ctx.functions[data["env_user_function"]] if data["env_user_function"] else None
455+
)
449456

450457
return ir
451458

0 commit comments

Comments
 (0)