Skip to content

Commit 915f4e8

Browse files
authored
Merge pull request #14523 from dotty-staging/optimize-mapblock
Thread context through block in transforms correctly and efficiently
2 parents 334b37a + 3ef8e2f commit 915f4e8

File tree

9 files changed

+141
-36
lines changed

9 files changed

+141
-36
lines changed

compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,7 @@ class TreeMapWithImplicits extends tpd.TreeMapWithPreciseStatContexts {
4848
override def transform(tree: Tree)(using Context): Tree = {
4949
try tree match {
5050
case Block(stats, expr) =>
51-
inContext(nestedScopeCtx(stats)) {
52-
if stats.exists(_.isInstanceOf[Import]) then
53-
// need to transform stats and expr together to account for import visibility
54-
val stats1 = transformStats(stats :+ expr, ctx.owner)
55-
cpy.Block(tree)(stats1.init, stats1.last)
56-
else super.transform(tree)
57-
}
51+
super.transform(tree)(using nestedScopeCtx(stats))
5852
case tree: DefDef =>
5953
inContext(localCtx(tree)) {
6054
cpy.DefDef(tree)(

compiler/src/dotty/tools/dotc/ast/Trees.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1399,8 +1399,8 @@ object Trees {
13991399
cpy.NamedArg(tree)(name, transform(arg))
14001400
case Assign(lhs, rhs) =>
14011401
cpy.Assign(tree)(transform(lhs), transform(rhs))
1402-
case Block(stats, expr) =>
1403-
cpy.Block(tree)(transformStats(stats, ctx.owner), transform(expr))
1402+
case blk: Block =>
1403+
transformBlock(blk)
14041404
case If(cond, thenp, elsep) =>
14051405
cpy.If(tree)(transform(cond), transform(thenp), transform(elsep))
14061406
case Closure(env, meth, tpt) =>
@@ -1489,6 +1489,8 @@ object Trees {
14891489

14901490
def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] =
14911491
transform(trees)
1492+
def transformBlock(blk: Block)(using Context): Block =
1493+
cpy.Block(blk)(transformStats(blk.stats, ctx.owner), transform(blk.expr))
14921494
def transform(trees: List[Tree])(using Context): List[Tree] =
14931495
flatten(trees mapConserve (transform(_)))
14941496
def transformSub[Tr <: Tree](tree: Tr)(using Context): Tr =

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,9 +1157,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
11571157
* - be tail-recursive where possible
11581158
* - don't re-allocate trees where nothing has changed
11591159
*/
1160-
inline def mapStatements(exprOwner: Symbol, inline op: Tree => Context ?=> Tree)(using Context): List[Tree] =
1160+
inline def mapStatements[T](
1161+
exprOwner: Symbol,
1162+
inline op: Tree => Context ?=> Tree,
1163+
inline wrapResult: List[Tree] => Context ?=> T)(using Context): T =
11611164
@tailrec
1162-
def loop(mapped: mutable.ListBuffer[Tree] | Null, unchanged: List[Tree], pending: List[Tree])(using Context): List[Tree] =
1165+
def loop(mapped: mutable.ListBuffer[Tree] | Null, unchanged: List[Tree], pending: List[Tree])(using Context): T =
11631166
pending match
11641167
case stat :: rest =>
11651168
val statCtx = stat match
@@ -1182,8 +1185,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
11821185
case _ => buf += stat1
11831186
loop(buf, rest, rest)(using restCtx)
11841187
case nil =>
1185-
if mapped == null then unchanged
1186-
else mapped.prependToList(unchanged)
1188+
wrapResult(
1189+
if mapped == null then unchanged
1190+
else mapped.prependToList(unchanged))
11871191

11881192
loop(null, trees, trees)
11891193
end mapStatements
@@ -1195,8 +1199,15 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
11951199
* - imports are reflected in the contexts of subsequent statements
11961200
*/
11971201
class TreeMapWithPreciseStatContexts(cpy: TreeCopier = tpd.cpy) extends TreeMap(cpy):
1198-
override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] =
1199-
trees.mapStatements(exprOwner, transform(_))
1202+
def transformStats[T](trees: List[Tree], exprOwner: Symbol, wrapResult: List[Tree] => Context ?=> T)(using Context): T =
1203+
trees.mapStatements(exprOwner, transform(_), wrapResult)
1204+
final override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] =
1205+
transformStats(trees, exprOwner, sameStats)
1206+
override def transformBlock(blk: Block)(using Context) =
1207+
transformStats(blk.stats, ctx.owner,
1208+
stats1 => ctx ?=> cpy.Block(blk)(stats1, transform(blk.expr)))
1209+
1210+
val sameStats: List[Tree] => Context ?=> List[Tree] = stats => stats
12001211

12011212
/** Map Inlined nodes, NamedArgs, Blocks with no statements and local references to underlying arguments.
12021213
* Also drops Inline and Block with no statements.

compiler/src/dotty/tools/dotc/transform/ForwardDepChecks.scala

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ object ForwardDepChecks:
2525
}
2626

2727
/** A class to help in forward reference checking */
28-
class LevelInfo(outerLevelAndIndex: LevelAndIndex, stats: List[Tree])(using Context)
28+
class LevelInfo(val outer: OptLevelInfo, val owner: Symbol, stats: List[Tree])(using Context)
2929
extends OptLevelInfo {
3030
override val levelAndIndex: LevelAndIndex =
31-
stats.foldLeft(outerLevelAndIndex, 0) {(mi, stat) =>
31+
stats.foldLeft(outer.levelAndIndex, 0) {(mi, stat) =>
3232
val (m, idx) = mi
3333
val m1 = stat match {
3434
case stat: MemberDef => m.updated(stat.symbol, (this, idx))
@@ -71,7 +71,7 @@ class ForwardDepChecks extends MiniPhase:
7171

7272
override def prepareForStats(trees: List[Tree])(using Context): Context =
7373
if (ctx.owner.isTerm)
74-
ctx.fresh.updateStore(LevelInfo, new LevelInfo(currentLevel.levelAndIndex, trees))
74+
ctx.fresh.updateStore(LevelInfo, new LevelInfo(currentLevel, ctx.owner, trees))
7575
else ctx
7676

7777
override def transformValDef(tree: ValDef)(using Context): ValDef =
@@ -89,19 +89,39 @@ class ForwardDepChecks extends MiniPhase:
8989
tree
9090
}
9191

92-
override def transformApply(tree: Apply)(using Context): Apply = {
93-
if (isSelfConstrCall(tree)) {
94-
assert(currentLevel.isInstanceOf[LevelInfo], s"${ctx.owner}/" + i"$tree")
95-
val level = currentLevel.asInstanceOf[LevelInfo]
96-
if (level.maxIndex > 0) {
97-
// An implementation restriction to avoid VerifyErrors and lazyvals mishaps; see SI-4717
98-
report.debuglog("refsym = " + level.refSym)
99-
report.error("forward reference not allowed from self constructor invocation",
100-
ctx.source.atSpan(level.refSpan))
101-
}
102-
}
92+
/** Check that self constructor call does not contain references to vals or defs
93+
* defined later in the secondary constructor's right hand side. This is tricky
94+
* since the complete self constructor might itself be a block that originated from
95+
* expanding named and default parameters. In that case we have to go outwards
96+
* and find the enclosing expression that consists of that block. Test cases in
97+
* {pos,neg}/complex-self-call.scala.
98+
*/
99+
private def checkSelfConstructorCall()(using Context): Unit =
100+
// Find level info corresponding to constructor's RHS. This is the info of the
101+
// outermost LevelInfo that has the constructor as owner.
102+
def rhsLevelInfo(l: OptLevelInfo): OptLevelInfo = l match
103+
case l: LevelInfo if l.owner == ctx.owner =>
104+
rhsLevelInfo(l.outer) match
105+
case l1: LevelInfo => l1
106+
case _ => l
107+
case _ =>
108+
NoLevelInfo
109+
110+
rhsLevelInfo(currentLevel) match
111+
case level: LevelInfo =>
112+
if level.maxIndex > 0 then
113+
report.debuglog("refsym = " + level.refSym.showLocated)
114+
report.error("forward reference not allowed from self constructor invocation",
115+
ctx.source.atSpan(level.refSpan))
116+
case _ =>
117+
assert(false, s"${ctx.owner.showLocated}")
118+
end checkSelfConstructorCall
119+
120+
override def transformApply(tree: Apply)(using Context): Apply =
121+
if (isSelfConstrCall(tree))
122+
assert(ctx.owner.isConstructor)
123+
checkSelfConstructorCall()
103124
tree
104-
}
105125

106126
override def transformNew(tree: New)(using Context): New = {
107127
currentLevel.enterReference(tree.tpe.typeSymbol, tree.span)

compiler/src/dotty/tools/dotc/transform/MegaPhase.scala

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,7 @@ class MegaPhase(val miniPhases: Array[MiniPhase]) extends Phase {
296296
}
297297
case tree: Block =>
298298
inContext(prepBlock(tree, start)(using outerCtx)) {
299-
val stats = transformStats(tree.stats, ctx.owner, start)
300-
val expr = transformTree(tree.expr, start)
301-
goBlock(cpy.Block(tree)(stats, expr), start)
299+
transformBlock(tree, start)
302300
}
303301
case tree: TypeApply =>
304302
inContext(prepTypeApply(tree, start)(using outerCtx)) {
@@ -434,9 +432,20 @@ class MegaPhase(val miniPhases: Array[MiniPhase]) extends Phase {
434432

435433
def transformStats(trees: List[Tree], exprOwner: Symbol, start: Int)(using Context): List[Tree] =
436434
val nestedCtx = prepStats(trees, start)
437-
val trees1 = trees.mapStatements(exprOwner, transformTree(_, start))(using nestedCtx)
435+
val trees1 = trees.mapStatements(exprOwner, transformTree(_, start), stats1 => stats1)(using nestedCtx)
438436
goStats(trees1, start)(using nestedCtx)
439437

438+
def transformBlock(tree: Block, start: Int)(using Context): Tree =
439+
val nestedCtx = prepStats(tree.stats, start)
440+
val block1 = tree.stats.mapStatements(ctx.owner,
441+
transformTree(_, start),
442+
stats1 => ctx ?=> {
443+
val stats2 = goStats(stats1, start)(using nestedCtx)
444+
val expr2 = transformTree(tree.expr, start)
445+
cpy.Block(tree)(stats2, expr2)
446+
})(using nestedCtx)
447+
goBlock(block1, start)
448+
440449
def transformUnit(tree: Tree)(using Context): Tree = {
441450
val nestedCtx = prepUnit(tree, 0)
442451
val tree1 = transformTree(tree, 0)(using nestedCtx)

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,8 +458,8 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
458458
throw ex
459459
}
460460

461-
override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] =
462-
try super.transformStats(trees, exprOwner)
461+
override def transformStats[T](trees: List[Tree], exprOwner: Symbol, wrapResult: List[Tree] => Context ?=> T)(using Context): T =
462+
try super.transformStats(trees, exprOwner, wrapResult)
463463
finally Checking.checkExperimentalImports(trees)
464464

465465
/** Transforms the rhs tree into a its default tree if it is in an `erased` val/def.
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import java.nio.file.FileSystems
2+
import java.util.ArrayList
3+
4+
def directorySeparator: String =
5+
import scala.language.unsafeNulls
6+
FileSystems.getDefault().getSeparator()
7+
8+
def getFirstOfFirst(xs: ArrayList[ArrayList[ArrayList[String]]]): String =
9+
import scala.language.unsafeNulls
10+
xs.get(0).get(0).get(0)

tests/neg/complex-self-call.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// An example extracted from akka that demonstrated a spurious
2+
// "forward reference not allowed from self constructor invocation" error.
3+
class Resizer
4+
class SupervisorStrategy
5+
class Pool
6+
object Pool:
7+
def defaultSupervisorStrategy: SupervisorStrategy = ???
8+
object Dispatchers:
9+
def DefaultDispatcherId = ???
10+
object Resizer:
11+
def fromConfig(config: Config): Option[Resizer] = ???
12+
13+
class Config:
14+
def getInt(str: String): Int = ???
15+
def hasPath(str: String): Boolean = ???
16+
17+
final case class BroadcastPool(
18+
nrOfInstances: Int,
19+
val resizer: Option[Resizer] = None,
20+
val supervisorStrategy: SupervisorStrategy = Pool.defaultSupervisorStrategy,
21+
val routerDispatcher: String = Dispatchers.DefaultDispatcherId,
22+
val usePoolDispatcher: Boolean = false)
23+
extends Pool:
24+
25+
def this(config: Config) =
26+
this(
27+
nrOfInstances = config.getInt("nr-of-instances"),
28+
resizer = resiz, // error: forward reference not allowed from self constructor invocation
29+
usePoolDispatcher = config.hasPath("pool-dispatcher"))
30+
def resiz = Resizer.fromConfig(config)

tests/pos/complex-self-call.scala

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// An example extracted from akka that demonstrated a spurious
2+
// "forward reference not allowed from self constructor invocation" error.
3+
class Resizer
4+
class SupervisorStrategy
5+
class Pool
6+
object Pool:
7+
def defaultSupervisorStrategy: SupervisorStrategy = ???
8+
object Dispatchers:
9+
def DefaultDispatcherId = ???
10+
object Resizer:
11+
def fromConfig(config: Config): Option[Resizer] = ???
12+
13+
class Config:
14+
def getInt(str: String): Int = ???
15+
def hasPath(str: String): Boolean = ???
16+
17+
final case class BroadcastPool(
18+
nrOfInstances: Int,
19+
val resizer: Option[Resizer] = None,
20+
val supervisorStrategy: SupervisorStrategy = Pool.defaultSupervisorStrategy,
21+
val routerDispatcher: String = Dispatchers.DefaultDispatcherId,
22+
val usePoolDispatcher: Boolean = false)
23+
extends Pool:
24+
25+
def this(config: Config) =
26+
this(
27+
nrOfInstances = config.getInt("nr-of-instances"),
28+
resizer = Resizer.fromConfig(config),
29+
usePoolDispatcher = config.hasPath("pool-dispatcher"))

0 commit comments

Comments
 (0)