Skip to content

Commit 3614108

Browse files
committed
Make cache part of state
1 parent 7bc5f89 commit 3614108

File tree

1 file changed

+43
-21
lines changed

1 file changed

+43
-21
lines changed

compiler/src/dotty/tools/dotc/transform/init/Semantic.scala

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -267,8 +267,31 @@ object Semantic {
267267
* statement body. Macros may also create incorrect locations.
268268
*
269269
*/
270-
type Cache = mutable.Map[Value, EqHashMap[Tree, Value]]
271-
val cache: Cache = mutable.Map.empty[Value, EqHashMap[Tree, Value]]
270+
271+
class Cache(val in: Cache.CacheIn, var out: Cache.CacheOut) {
272+
var changed: Boolean = false
273+
}
274+
275+
object Cache {
276+
opaque type CacheIn = mutable.Map[Value, EqHashMap[Tree, Value]]
277+
opaque type CacheOut = mutable.Map[Value, EqHashMap[Tree, Value]]
278+
279+
val empty: Cache = new Cache(mutable.Map.empty, mutable.Map.empty)
280+
281+
extension (cache: CacheIn | CacheOut)
282+
def contains(value: Value, expr: Tree): Boolean = cache.contains(value) && cache(value).contains(expr)
283+
def get(value: Value, expr: Tree): Value = cache(value)(expr)
284+
def put(value: Value, expr: Tree, result: Value): Unit = {
285+
val innerMap = cache.getOrElseUpdate(value, new EqHashMap[Tree, Value])
286+
innerMap(expr) = result
287+
}
288+
end extension
289+
}
290+
291+
import Cache._
292+
293+
inline def cache(using c: Cache): Cache = c
294+
272295

273296
/** Result of abstract interpretation */
274297
case class Result(value: Value, errors: Seq[Error]) {
@@ -293,9 +316,10 @@ object Semantic {
293316

294317
// ----- State --------------------------------------------
295318
/** Global state of the checker */
296-
class State(val heap: Heap, val workList: WorkList)
319+
class State(val cache: Cache, val heap: Heap, val workList: WorkList)
297320

298321
given (using s: State): Heap = s.heap
322+
given (using s: State): Cache = s.cache
299323
given (using s: State): WorkList = s.workList
300324

301325
/** The state that threads through the interpreter */
@@ -368,7 +392,7 @@ object Semantic {
368392
if target.is(Flags.Lazy) then
369393
given Trace = trace1
370394
val rhs = target.defTree.asInstanceOf[ValDef].rhs
371-
eval(rhs, ref, target.owner.asClass, cacheResult = true)
395+
eval(rhs, ref, target.owner.asClass)
372396
else
373397
val obj = ref.objekt
374398
if obj.hasField(target) then
@@ -382,7 +406,7 @@ object Semantic {
382406
Result(Hot, Nil)
383407
else if target.hasSource then
384408
val rhs = target.defTree.asInstanceOf[ValOrDefDef].rhs
385-
eval(rhs, ref, target.owner.asClass, cacheResult = true)
409+
eval(rhs, ref, target.owner.asClass)
386410
else
387411
val error = CallUnknown(field, source, trace.toVector)
388412
Result(Hot, error :: Nil)
@@ -434,7 +458,7 @@ object Semantic {
434458
val env2 = Env(ddef, args.map(_.value).widenArgs)
435459
// normal method call
436460
withEnv(if isLocal then env else Env.empty) {
437-
eval(ddef.rhs, ref, cls, cacheResult = true) ++ checkArgs
461+
eval(ddef.rhs, ref, cls) ++ checkArgs
438462
}
439463
else if ref.canIgnoreMethodCall(target) then
440464
Result(Hot, Nil)
@@ -455,7 +479,7 @@ object Semantic {
455479
if meth.name.toString == "tupled" then Result(value, Nil) // a call like `fun.tupled`
456480
else
457481
withEnv(env) {
458-
eval(body, thisV, klass, cacheResult = true) ++ checkArgs
482+
eval(body, thisV, klass) ++ checkArgs
459483
}
460484

461485
case RefSet(refs) =>
@@ -482,11 +506,11 @@ object Semantic {
482506
if ctor.isPrimaryConstructor then
483507
given Env = env2
484508
val tpl = cls.defTree.asInstanceOf[TypeDef].rhs.asInstanceOf[Template]
485-
val res = withTrace(trace.add(cls.defTree)) { eval(tpl, ref, cls, cacheResult = true) }
509+
val res = withTrace(trace.add(cls.defTree)) { eval(tpl, ref, cls) }
486510
Result(ref, res.errors)
487511
else
488512
given Env = env2
489-
eval(ddef.rhs, ref, cls, cacheResult = true)
513+
eval(ddef.rhs, ref, cls)
490514
else if ref.canIgnoreMethodCall(ctor) then
491515
Result(Hot, Nil)
492516
else
@@ -792,7 +816,7 @@ object Semantic {
792816
* }
793817
*/
794818
def withInitialState[T](work: State ?=> T): T = {
795-
val initialState = State(Heap.empty, new WorkList)
819+
val initialState = State(Cache.empty, Heap.empty, new WorkList)
796820
work(using initialState)
797821
}
798822

@@ -818,17 +842,15 @@ object Semantic {
818842
*
819843
* This method only handles cache logic and delegates the work to `cases`.
820844
*/
821-
def eval(expr: Tree, thisV: Ref, klass: ClassSymbol, cacheResult: Boolean = false): Contextual[Result] = log("evaluating " + expr.show + ", this = " + thisV.show, printer, res => res.asInstanceOf[Result].show) {
822-
val innerMap = cache.getOrElseUpdate(thisV, new EqHashMap[Tree, Value])
823-
if (innerMap.contains(expr)) Result(innerMap(expr), Errors.empty)
845+
def eval(expr: Tree, thisV: Ref, klass: ClassSymbol): Contextual[Result] = log("evaluating " + expr.show + ", this = " + thisV.show, printer, res => res.asInstanceOf[Result].show) {
846+
if (cache.out.contains(thisV, expr)) Result(cache.out.get(thisV, expr), Errors.empty)
824847
else {
825-
// no need to compute fix-point, because
826-
// 1. the result is decided by `cfg` for a legal program
827-
// (heap change is irrelevant thanks to monotonicity)
828-
// 2. errors will have been reported for an illegal program
829-
innerMap(expr) = Hot
848+
val assumeValue = if (cache.in.contains(thisV, expr)) cache.in.get(thisV, expr) else Hot
849+
cache.out.put(thisV, expr, assumeValue)
830850
val res = cases(expr, thisV, klass)
831-
if cacheResult then innerMap(expr) = res.value else innerMap.remove(expr)
851+
if res.value != assumeValue then
852+
cache.changed = true
853+
cache.out.put(thisV, expr, res.value)
832854
res
833855
}
834856
}
@@ -1005,7 +1027,7 @@ object Semantic {
10051027
case vdef : ValDef =>
10061028
// local val definition
10071029
// TODO: support explicit @cold annotation for local definitions
1008-
eval(vdef.rhs, thisV, klass, cacheResult = true)
1030+
eval(vdef.rhs, thisV, klass)
10091031

10101032
case ddef : DefDef =>
10111033
// local method
@@ -1225,7 +1247,7 @@ object Semantic {
12251247
tpl.body.foreach {
12261248
case vdef : ValDef if !vdef.symbol.is(Flags.Lazy) && !vdef.rhs.isEmpty =>
12271249
given Env = Env.empty
1228-
val res = eval(vdef.rhs, thisV, klass, cacheResult = true)
1250+
val res = eval(vdef.rhs, thisV, klass)
12291251
errorBuffer ++= res.errors
12301252
thisV.updateField(vdef.symbol, res.value)
12311253
fieldsChanged = true

0 commit comments

Comments
 (0)