Skip to content

Commit a6732dc

Browse files
committed
Thread context through block in transforms correctly and efficiently
Fixes #14319
1 parent d09dd2a commit a6732dc

File tree

6 files changed

+47
-21
lines changed

6 files changed

+47
-21
lines changed

compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,7 @@ class TreeMapWithImplicits extends tpd.TreeMapWithPreciseStatContexts {
4848
override def transform(tree: Tree)(using Context): Tree = {
4949
try tree match {
5050
case Block(stats, expr) =>
51-
inContext(nestedScopeCtx(stats)) {
52-
if stats.exists(_.isInstanceOf[Import]) then
53-
// need to transform stats and expr together to account for import visibility
54-
val stats1 = transformStats(stats :+ expr, ctx.owner)
55-
cpy.Block(tree)(stats1.init, stats1.last)
56-
else super.transform(tree)
57-
}
51+
super.transform(tree)(using nestedScopeCtx(stats))
5852
case tree: DefDef =>
5953
inContext(localCtx(tree)) {
6054
cpy.DefDef(tree)(

compiler/src/dotty/tools/dotc/ast/Trees.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1399,8 +1399,8 @@ object Trees {
13991399
cpy.NamedArg(tree)(name, transform(arg))
14001400
case Assign(lhs, rhs) =>
14011401
cpy.Assign(tree)(transform(lhs), transform(rhs))
1402-
case Block(stats, expr) =>
1403-
cpy.Block(tree)(transformStats(stats, ctx.owner), transform(expr))
1402+
case blk: Block =>
1403+
transformBlock(blk)
14041404
case If(cond, thenp, elsep) =>
14051405
cpy.If(tree)(transform(cond), transform(thenp), transform(elsep))
14061406
case Closure(env, meth, tpt) =>
@@ -1489,6 +1489,8 @@ object Trees {
14891489

14901490
def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] =
14911491
transform(trees)
1492+
def transformBlock(blk: Block)(using Context): Block =
1493+
cpy.Block(blk)(transformStats(blk.stats, ctx.owner), transform(blk.expr))
14921494
def transform(trees: List[Tree])(using Context): List[Tree] =
14931495
flatten(trees mapConserve (transform(_)))
14941496
def transformSub[Tr <: Tree](tree: Tr)(using Context): Tr =

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,9 +1157,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
11571157
* - be tail-recursive where possible
11581158
* - don't re-allocate trees where nothing has changed
11591159
*/
1160-
inline def mapStatements(exprOwner: Symbol, inline op: Tree => Context ?=> Tree)(using Context): List[Tree] =
1160+
inline def mapStatements[T](
1161+
exprOwner: Symbol,
1162+
inline op: Tree => Context ?=> Tree,
1163+
inline wrapResult: List[Tree] => Context ?=> T)(using Context): T =
11611164
@tailrec
1162-
def loop(mapped: mutable.ListBuffer[Tree] | Null, unchanged: List[Tree], pending: List[Tree])(using Context): List[Tree] =
1165+
def loop(mapped: mutable.ListBuffer[Tree] | Null, unchanged: List[Tree], pending: List[Tree])(using Context): T =
11631166
pending match
11641167
case stat :: rest =>
11651168
val statCtx = stat match
@@ -1182,8 +1185,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
11821185
case _ => buf += stat1
11831186
loop(buf, rest, rest)(using restCtx)
11841187
case nil =>
1185-
if mapped == null then unchanged
1186-
else mapped.prependToList(unchanged)
1188+
wrapResult(
1189+
if mapped == null then unchanged
1190+
else mapped.prependToList(unchanged))
11871191

11881192
loop(null, trees, trees)
11891193
end mapStatements
@@ -1195,8 +1199,15 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
11951199
* - imports are reflected in the contexts of subsequent statements
11961200
*/
11971201
class TreeMapWithPreciseStatContexts(cpy: TreeCopier = tpd.cpy) extends TreeMap(cpy):
1198-
override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] =
1199-
trees.mapStatements(exprOwner, transform(_))
1202+
def transformStats[T](trees: List[Tree], exprOwner: Symbol, wrapResult: List[Tree] => Context ?=> T)(using Context): T =
1203+
trees.mapStatements(exprOwner, transform(_), wrapResult)
1204+
final override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] =
1205+
transformStats(trees, exprOwner, sameStats)
1206+
override def transformBlock(blk: Block)(using Context) =
1207+
transformStats(blk.stats, ctx.owner,
1208+
stats1 => ctx ?=> cpy.Block(blk)(stats1, transform(blk.expr)))
1209+
1210+
val sameStats: List[Tree] => Context ?=> List[Tree] = stats => stats
12001211

12011212
/** Map Inlined nodes, NamedArgs, Blocks with no statements and local references to underlying arguments.
12021213
* Also drops Inline and Block with no statements.

compiler/src/dotty/tools/dotc/transform/MegaPhase.scala

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,7 @@ class MegaPhase(val miniPhases: Array[MiniPhase]) extends Phase {
296296
}
297297
case tree: Block =>
298298
inContext(prepBlock(tree, start)(using outerCtx)) {
299-
val stats = transformStats(tree.stats, ctx.owner, start)
300-
val expr = transformTree(tree.expr, start)
301-
goBlock(cpy.Block(tree)(stats, expr), start)
299+
transformBlock(tree, start)
302300
}
303301
case tree: TypeApply =>
304302
inContext(prepTypeApply(tree, start)(using outerCtx)) {
@@ -434,9 +432,20 @@ class MegaPhase(val miniPhases: Array[MiniPhase]) extends Phase {
434432

435433
def transformStats(trees: List[Tree], exprOwner: Symbol, start: Int)(using Context): List[Tree] =
436434
val nestedCtx = prepStats(trees, start)
437-
val trees1 = trees.mapStatements(exprOwner, transformTree(_, start))(using nestedCtx)
435+
val trees1 = trees.mapStatements(exprOwner, transformTree(_, start), stats1 => stats1)(using nestedCtx)
438436
goStats(trees1, start)(using nestedCtx)
439437

438+
def transformBlock(tree: Block, start: Int)(using Context): Tree =
439+
val nestedCtx = prepStats(tree.stats, start)
440+
val block1 = tree.stats.mapStatements(ctx.owner,
441+
transformTree(_, start),
442+
stats1 => ctx ?=> {
443+
val stats2 = goStats(stats1, start)(using nestedCtx)
444+
val expr2 = transformTree(tree.expr, start)
445+
cpy.Block(tree)(stats2, expr2)
446+
})(using nestedCtx)
447+
goBlock(block1, start)
448+
440449
def transformUnit(tree: Tree)(using Context): Tree = {
441450
val nestedCtx = prepUnit(tree, 0)
442451
val tree1 = transformTree(tree, 0)(using nestedCtx)

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,8 +458,8 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
458458
throw ex
459459
}
460460

461-
override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] =
462-
try super.transformStats(trees, exprOwner)
461+
override def transformStats[T](trees: List[Tree], exprOwner: Symbol, wrapResult: List[Tree] => Context ?=> T)(using Context): T =
462+
try super.transformStats(trees, exprOwner, wrapResult)
463463
finally Checking.checkExperimentalImports(trees)
464464

465465
/** Transforms the rhs tree into a its default tree if it is in an `erased` val/def.
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import java.nio.file.FileSystems
2+
import java.util.ArrayList
3+
4+
def directorySeparator: String =
5+
import scala.language.unsafeNulls
6+
FileSystems.getDefault().getSeparator()
7+
8+
def getFirstOfFirst(xs: ArrayList[ArrayList[ArrayList[String]]]): String =
9+
import scala.language.unsafeNulls
10+
xs.get(0).get(0).get(0)

0 commit comments

Comments
 (0)