Skip to content

Commit 0afb5c6

Browse files
committed
Make memo work in traits
... and also make it work on several nested levels together.
1 parent e6856b0 commit 0afb5c6

File tree

8 files changed

+65
-34
lines changed

8 files changed

+65
-34
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -144,21 +144,17 @@ object desugar {
144144

145145
// ----- Desugar methods -------------------------------------------------
146146

147+
def setterNeeded(flags: FlagSet, owner: Symbol) given Context =
148+
flags.is(Mutable) && owner.isClass && (!flags.isAllOf(PrivateLocal) || owner.is(Trait))
149+
147150
/** var x: Int = expr
148151
* ==>
149152
* def x: Int = expr
150153
* def x_=($1: <TypeTree()>): Unit = ()
151154
*/
152155
def valDef(vdef0: ValDef)(implicit ctx: Context): Tree = {
153156
val vdef @ ValDef(name, tpt, rhs) = transformQuotedPatternName(vdef0)
154-
val mods = vdef.mods
155-
val setterNeeded =
156-
mods.is(Mutable) && ctx.owner.isClass && (!mods.isAllOf(PrivateLocal) || ctx.owner.is(Trait))
157-
if (setterNeeded) {
158-
// TODO: copy of vdef as getter needed?
159-
// val getter = ValDef(mods, name, tpt, rhs) withPos vdef.pos?
160-
// right now vdef maps via expandedTree to a thicket which concerns itself.
161-
// I don't see a problem with that but if there is one we can avoid it by making a copy here.
157+
if (setterNeeded(vdef.mods.flags, ctx.owner)) {
162158
val setterParam = makeSyntheticParameter(tpt = SetterParamTree().watching(vdef))
163159
// The rhs gets filled in later, when field is generated and getter has parameters (see Memoize miniphase)
164160
val setterRhs = if (vdef.rhs.isEmpty) EmptyTree else unitLiteral
@@ -168,7 +164,7 @@ object desugar {
168164
vparamss = (setterParam :: Nil) :: Nil,
169165
tpt = TypeTree(defn.UnitType),
170166
rhs = setterRhs
171-
).withMods((mods | Accessor) &~ (CaseAccessor | GivenOrImplicit | Lazy))
167+
).withMods((vdef.mods | Accessor) &~ (CaseAccessor | GivenOrImplicit | Lazy))
172168
Thicket(vdef, setter)
173169
}
174170
else vdef

compiler/src/dotty/tools/dotc/core/NameKinds.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ object NameKinds {
299299
val UniqueInlineName: UniqueNameKind = new UniqueNameKind("$i")
300300
val InlineScrutineeName: UniqueNameKind = new UniqueNameKind("$scrutinee")
301301
val InlineBinderName: UniqueNameKind = new UniqueNameKind("$elem")
302-
val MemoCacheName: UniqueNameKind = new UniqueNameKind("memo$")
302+
val MemoCacheName: UniqueNameKind = new UniqueNameKind("$cache")
303303

304304
/** A kind of unique extension methods; Unlike other unique names, these can be
305305
* unmangled.

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ object StdNames {
482482
val materializeClassTag: N = "materializeClassTag"
483483
val materializeWeakTypeTag: N = "materializeWeakTypeTag"
484484
val materializeTypeTag: N = "materializeTypeTag"
485+
val memo: N = "memo"
485486
val mirror : N = "mirror"
486487
val moduleClass : N = "moduleClass"
487488
val name: N = "name"

compiler/src/dotty/tools/dotc/core/Symbols.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -553,9 +553,10 @@ object Symbols {
553553

554554
/** This symbol entered into owner's scope (owner must be a class). */
555555
final def entered(implicit ctx: Context): this.type = {
556-
assert(this.owner.isClass, s"symbol ($this) entered the scope of non-class owner ${this.owner}") // !!! DEBUG
557-
this.owner.asClass.enter(this)
558-
if (this.is(Module)) this.owner.asClass.enter(this.moduleClass)
556+
if (this.owner.isClass) {
557+
this.owner.asClass.enter(this)
558+
if (this.is(Module)) this.owner.asClass.enter(this.moduleClass)
559+
}
559560
this
560561
}
561562

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -193,14 +193,22 @@ object Inliner {
193193
/** For every occurrence of a memo cache symbol `memo$N` of type `T_N` in `tree`,
194194
* an assignment `val memo$N: T_N = null`
195195
*/
196-
def memoCacheDefs(tree: Tree) given Context: Set[ValDef] = {
197-
val memoCacheSyms = tree.deepFold[Set[TermSymbol]](Set.empty) {
198-
(syms, t) => t match {
199-
case Assign(lhs, _) if lhs.symbol.name.is(MemoCacheName) => syms + lhs.symbol.asTerm
200-
case _ => syms
196+
def memoCacheDefs(tree: Tree) given Context: List[ValOrDefDef] = {
197+
object memoRefs extends TreeTraverser {
198+
val syms = new mutable.LinkedHashSet[TermSymbol]
199+
def traverse(tree: Tree) given Context = tree match {
200+
case tree: RefTree if tree.symbol.name.is(MemoCacheName) =>
201+
syms += tree.symbol.asTerm
202+
case _: DefDef =>
203+
// don't traverse deeper; nested memo caches go next to nested method
204+
case _ =>
205+
traverseChildren(tree)
201206
}
202207
}
203-
memoCacheSyms.map(sym => ValDef(sym, Literal(Constant(null))).withSpan(sym.span))
208+
memoRefs.traverse(tree)
209+
for sym <- memoRefs.syms.toList yield
210+
(if (sym.isSetter) DefDef(sym, _ => Literal(Constant(())))
211+
else ValDef(sym, Literal(Constant(null)))).withSpan(sym.span)
204212
}
205213
}
206214

@@ -420,19 +428,22 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
420428
val argType = callTypeArgs.head.tpe
421429
val memoVar = ctx.newSymbol(
422430
owner = cacheOwner,
423-
name = MemoCacheName.fresh(),
431+
name = MemoCacheName.fresh(nme.memo),
424432
flags =
425433
if (cacheOwner.isTerm) Synthetic | Mutable
426434
else Synthetic | Mutable | Private | Local,
427435
info = OrType(argType, defn.NullType),
428-
coord = call.span)
429-
val memoSetter = ctx.newSymbol(
430-
owner = cacheOwner,
431-
name = memoVar.name.setterName,
432-
flags = memoVar.flags | Method | Accessor,
433-
info = MethodType(argType :: Nil, defn.UnitType),
434-
coord = call.span
435-
)
436+
coord = call.span).entered
437+
val memoSetter =
438+
if (desugar.setterNeeded(memoVar.flags, cacheOwner))
439+
ctx.newSymbol(
440+
owner = cacheOwner,
441+
name = memoVar.name.setterName,
442+
flags = memoVar.flags | Method | Accessor,
443+
info = MethodType(argType :: Nil, defn.UnitType),
444+
coord = call.span
445+
).entered
446+
else memoVar
436447
val memoRef = ref(memoVar).withSpan(call.span)
437448
val cond = If(
438449
memoRef.select(defn.Any_==).appliedTo(Literal(Constant(null))),

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2122,19 +2122,18 @@ class Typer extends Namer
21222122
case Some(xtree) =>
21232123
traverse(xtree :: rest)
21242124
case none =>
2125-
val memoCacheCount = MemoCacheName.currentCount()
2125+
val memoCacheCount = MemoCacheName.currentCount(nme.memo)
21262126
typed(mdef) match {
21272127
case mdef1: DefDef if Inliner.hasBodyToInline(mdef1.symbol) =>
21282128
buf += inlineExpansion(mdef1)
21292129
// replace body with expansion, because it will be used as inlined body
21302130
// from separately compiled files - the original BodyAnnotation is not kept.
21312131
case mdef1 =>
2132-
import untpd.modsDeco
2133-
mdef match {
2134-
case mdef: untpd.TypeDef if mdef.mods.isEnumClass =>
2132+
mdef1 match {
2133+
case mdef1: TypeDef if mdef1.symbol.flags.is(Enum, butNot = Case) =>
21352134
enumContexts(mdef1.symbol) = ctx
2136-
case _: untpd.DefDef if MemoCacheName.currentCount() != memoCacheCount =>
2137-
buf ++= Inliner.memoCacheDefs(mdef1)
2135+
case mdef1: DefDef if MemoCacheName.currentCount(nme.memo) != memoCacheCount =>
2136+
buf ++= Inliner.memoCacheDefs(mdef1.rhs)
21382137
case _ =>
21392138
}
21402139
if (!mdef1.isEmpty) // clashing synthetic case methods are converted to empty trees

tests/run/memoTest.check

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
computing inner
2+
computing f
3+
1
4+
1
5+
computing f
6+
1
7+
1

tests/run/memoTest.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,22 @@ object Test extends App {
1313
assert(foo(1) + foo(2) == 4)
1414
assert(bar(1) + bar(2) == 4)
1515

16+
trait T {
17+
def x: Int
18+
def y: Int = memo {
19+
def inner = memo {
20+
println("computing inner");
21+
x * x
22+
}
23+
inner + inner
24+
}
25+
}
26+
val t = new T {
27+
def x = 3
28+
assert(y == 18)
29+
}
30+
assert(t.y == 18)
31+
1632
class Context(val n: Int)
1733
def f(c: Context): Context = {
1834
println("computing f")

0 commit comments

Comments
 (0)