Skip to content

Implement memoization #6887

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -144,21 +144,17 @@ object desugar {

// ----- Desugar methods -------------------------------------------------

def setterNeeded(flags: FlagSet, owner: Symbol) given Context =
flags.is(Mutable) && owner.isClass && (!flags.isAllOf(PrivateLocal) || owner.is(Trait))

/** var x: Int = expr
* ==>
* def x: Int = expr
* def x_=($1: <TypeTree()>): Unit = ()
*/
def valDef(vdef0: ValDef)(implicit ctx: Context): Tree = {
val vdef @ ValDef(name, tpt, rhs) = transformQuotedPatternName(vdef0)
val mods = vdef.mods
val setterNeeded =
mods.is(Mutable) && ctx.owner.isClass && (!mods.isAllOf(PrivateLocal) || ctx.owner.is(Trait))
if (setterNeeded) {
// TODO: copy of vdef as getter needed?
// val getter = ValDef(mods, name, tpt, rhs) withPos vdef.pos?
// right now vdef maps via expandedTree to a thicket which concerns itself.
// I don't see a problem with that but if there is one we can avoid it by making a copy here.
if (setterNeeded(vdef.mods.flags, ctx.owner)) {
val setterParam = makeSyntheticParameter(tpt = SetterParamTree().watching(vdef))
// The rhs gets filled in later, when field is generated and getter has parameters (see Memoize miniphase)
val setterRhs = if (vdef.rhs.isEmpty) EmptyTree else unitLiteral
Expand All @@ -168,7 +164,7 @@ object desugar {
vparamss = (setterParam :: Nil) :: Nil,
tpt = TypeTree(defn.UnitType),
rhs = setterRhs
).withMods((mods | Accessor) &~ (CaseAccessor | GivenOrImplicit | Lazy))
).withMods((vdef.mods | Accessor) &~ (CaseAccessor | GivenOrImplicit | Lazy))
Thicket(vdef, setter)
}
else vdef
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ class Definitions {
@threadUnsafe lazy val Compiletime_constValue : SymbolPerRun = perRunSym(CompiletimePackageObject.requiredMethodRef("constValue"))
@threadUnsafe lazy val Compiletime_constValueOpt: SymbolPerRun = perRunSym(CompiletimePackageObject.requiredMethodRef("constValueOpt"))
@threadUnsafe lazy val Compiletime_code : SymbolPerRun = perRunSym(CompiletimePackageObject.requiredMethodRef("code"))
@threadUnsafe lazy val Compiletime_memo : SymbolPerRun = perRunSym(CompiletimePackageObject.requiredMethodRef("memo"))

/** The `scalaShadowing` package is used to safely modify classes and
* objects in scala so that they can be used from dotty. They will
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/NameKinds.scala
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ object NameKinds {
safePrefix + info.num
}

def currentCount(prefix: TermName = EmptyTermName) given (ctx: Context): Int =
ctx.freshNames.currentCount(prefix, this)

/** Generate fresh unique term name of this kind with given prefix name */
def fresh(prefix: TermName = EmptyTermName)(implicit ctx: Context): TermName =
ctx.freshNames.newName(prefix, this)
Expand Down Expand Up @@ -296,6 +299,7 @@ object NameKinds {
val UniqueInlineName: UniqueNameKind = new UniqueNameKind("$i")
val InlineScrutineeName: UniqueNameKind = new UniqueNameKind("$scrutinee")
val InlineBinderName: UniqueNameKind = new UniqueNameKind("$elem")
val MemoCacheName: UniqueNameKind = new UniqueNameKind("$cache")

/** A kind of unique extension methods; Unlike other unique names, these can be
* unmangled.
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ object StdNames {
val materializeClassTag: N = "materializeClassTag"
val materializeWeakTypeTag: N = "materializeWeakTypeTag"
val materializeTypeTag: N = "materializeTypeTag"
val memo: N = "memo"
val mirror : N = "mirror"
val moduleClass : N = "moduleClass"
val name: N = "name"
Expand Down
8 changes: 6 additions & 2 deletions compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,12 @@ trait SymDenotations { this: Context =>
}
}

/** Configurable: Accept stale symbol with warning if in IDE */
def staleOK: Boolean = Config.ignoreStaleInIDE && mode.is(Mode.Interactive)
/** Configurable: Accept stale symbol with warning if in IDE
* Always accept stale symbols when testing pickling.
*/
def staleOK: Boolean =
Config.ignoreStaleInIDE && mode.is(Mode.Interactive) ||
settings.YtestPickler.value

/** Possibly accept stale symbol with warning if in IDE */
def acceptStale(denot: SingleDenotation): Boolean =
Expand Down
25 changes: 14 additions & 11 deletions compiler/src/dotty/tools/dotc/core/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -553,9 +553,10 @@ object Symbols {

/** This symbol entered into owner's scope (owner must be a class). */
final def entered(implicit ctx: Context): this.type = {
assert(this.owner.isClass, s"symbol ($this) entered the scope of non-class owner ${this.owner}") // !!! DEBUG
this.owner.asClass.enter(this)
if (this.is(Module)) this.owner.asClass.enter(this.moduleClass)
if (this.owner.isClass) {
this.owner.asClass.enter(this)
if (this.is(Module)) this.owner.asClass.enter(this.moduleClass)
}
this
}

Expand All @@ -566,14 +567,16 @@ object Symbols {
*/
def enteredAfter(phase: DenotTransformer)(implicit ctx: Context): this.type =
if (ctx.phaseId != phase.next.id) enteredAfter(phase)(ctx.withPhase(phase.next))
else {
if (this.owner.is(Package)) {
denot.validFor |= InitialPeriod
if (this.is(Module)) this.moduleClass.validFor |= InitialPeriod
}
else this.owner.asClass.ensureFreshScopeAfter(phase)
assert(isPrivate || phase.changesMembers, i"$this entered in ${this.owner} at undeclared phase $phase")
entered
else this.owner match {
case owner: ClassSymbol =>
if (owner.is(Package)) {
denot.validFor |= InitialPeriod
if (this.is(Module)) this.moduleClass.validFor |= InitialPeriod
}
else owner.ensureFreshScopeAfter(phase)
assert(isPrivate || phase.changesMembers, i"$this entered in $owner at undeclared phase $phase")
entered
case _ => this
}

/** Remove symbol from scope of owning class */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class CacheAliasImplicits extends MiniPhase with IdentityDenotTransformer { this
val cacheFlags = if (ctx.owner.isClass) Private | Local | Mutable else Mutable
val cacheSym =
ctx.newSymbol(ctx.owner, CacheName(tree.name), cacheFlags, rhsType, coord = sym.coord)
if (ctx.owner.isClass) cacheSym.enteredAfter(thisPhase)
.enteredAfter(thisPhase)
val cacheDef = ValDef(cacheSym, tpd.defaultValue(rhsType))
val cachingDef = cpy.DefDef(tree)(rhs =
Block(
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/transform/HoistSuperArgs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,13 @@ class HoistSuperArgs extends MiniPhase with IdentityDenotTransformer { thisPhase
val argTypeWrtConstr = argType.subst(origParams, allParamRefs(constr.info))
// argType with references to paramRefs of the primary constructor instead of
// local parameter accessors
val meth = ctx.newSymbol(
ctx.newSymbol(
owner = methOwner,
name = SuperArgName.fresh(cls.name.toTermName),
flags = Synthetic | Private | Method | staticFlag,
info = replaceResult(constr.info, argTypeWrtConstr),
coord = constr.coord)
if (methOwner.isClass) meth.enteredAfter(thisPhase) else meth
coord = constr.coord
).enteredAfter(thisPhase)
}

/** Type of a reference implies that it needs to be hoisted */
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/LambdaLift.scala
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,9 @@ object LambdaLift {
proxyMap(owner) = {
for (fv <- freeValues.toList) yield {
val proxyName = newName(fv)
val proxy = ctx.newSymbol(owner, proxyName.asTermName, newFlags, fv.info, coord = fv.coord)
if (owner.isClass) proxy.enteredAfter(thisPhase)
val proxy =
ctx.newSymbol(owner, proxyName.asTermName, newFlags, fv.info, coord = fv.coord)
.enteredAfter(thisPhase)
(fv, proxy)
}
}.toMap
Expand Down
67 changes: 66 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Inliner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ import StdNames._
import transform.SymUtils._
import Contexts.Context
import Names.{Name, TermName}
import NameKinds.{InlineAccessorName, InlineBinderName, InlineScrutineeName}
import NameKinds.{InlineAccessorName, InlineBinderName, InlineScrutineeName, MemoCacheName}
import NameOps._
import ProtoTypes.selectionProto
import SymDenotations.SymDenotation
import Inferencing.fullyDefinedType
Expand Down Expand Up @@ -188,6 +189,27 @@ object Inliner {
if (callSym.is(Macro)) ref(callSym.topLevelClass.owner).select(callSym.topLevelClass.name).withSpan(pos.span)
else Ident(callSym.topLevelClass.typeRef).withSpan(pos.span)
}

/** For every occurrence of a memo cache symbol `memo$N` of type `T_N` in `tree`,
* an assignment `val memo$N: T_N = null`
*/
def memoCacheDefs(tree: Tree) given Context: List[ValOrDefDef] = {
object memoRefs extends TreeTraverser {
val syms = new mutable.LinkedHashSet[TermSymbol]
def traverse(tree: Tree) given Context = tree match {
case tree: RefTree if tree.symbol.name.is(MemoCacheName) =>
syms += tree.symbol.asTerm
case _: DefDef =>
// don't traverse deeper; nested memo caches go next to nested method
case _ =>
traverseChildren(tree)
}
}
memoRefs.traverse(tree)
for sym <- memoRefs.syms.toList yield
(if (sym.isSetter) DefDef(sym, _ => Literal(Constant(())))
else ValDef(sym, Literal(Constant(null)))).withSpan(sym.span)
}
}

/** Produces an inlined version of `call` via its `inlined` method.
Expand Down Expand Up @@ -392,6 +414,47 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
case _ => EmptyTree
}

/** The expansion of `memo(op)` where `op: T` is:
*
* { if (memo$N == null) memo$N_=(op); $memo.asInstanceOf[T] }
*
* This creates as a side effect a memo cache symbol $memo$N` of type `T | Null`.
* TODO: Restrict this to non-null types, once nullability checking is in.
*/
def memoized: Tree = {
val currentOwner = ctx.owner.skipWeakOwner
if (currentOwner.isRealMethod) {
val cacheOwner = currentOwner.owner
val argType = callTypeArgs.head.tpe
val memoVar = ctx.newSymbol(
owner = cacheOwner,
name = MemoCacheName.fresh(nme.memo),
flags =
if (cacheOwner.isTerm) Synthetic | Mutable
else Synthetic | Mutable | Private | Local,
info = OrType(argType, defn.NullType),
coord = call.span).entered
val memoSetter =
if (desugar.setterNeeded(memoVar.flags, cacheOwner))
ctx.newSymbol(
owner = cacheOwner,
name = memoVar.name.setterName,
flags = memoVar.flags | Method | Accessor,
info = MethodType(argType :: Nil, defn.UnitType),
coord = call.span
).entered
else memoVar
val memoRef = ref(memoVar).withSpan(call.span)
val cond = If(
memoRef.select(defn.Any_==).appliedTo(Literal(Constant(null))),
ref(memoSetter).withSpan(call.span).becomes(callValueArgss.head.head),
Literal(Constant(())))
val expr = memoRef.cast(argType)
Block(cond :: Nil, expr)
}
else errorTree(call, em"""memo(...) outside method""")
}

/** The Inlined node representing the inlined call */
def inlined(sourcePos: SourcePosition): Tree = {

Expand All @@ -408,6 +471,8 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
else New(defn.SomeClass.typeRef.appliedTo(constVal.tpe), constVal :: Nil)
)
}
else if (inlinedMethod == defn.Compiletime_memo)
return memoized

// Compute bindings for all parameters, appending them to bindingsBuf
computeParamBindings(inlinedMethod.info, callTypeArgs, callValueArgss)
Expand Down
17 changes: 9 additions & 8 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2122,20 +2122,21 @@ class Typer extends Namer
case Some(xtree) =>
traverse(xtree :: rest)
case none =>
val memoCacheCount = MemoCacheName.currentCount(nme.memo)
typed(mdef) match {
case mdef1: DefDef if Inliner.hasBodyToInline(mdef1.symbol) =>
buf += inlineExpansion(mdef1)
// replace body with expansion, because it will be used as inlined body
// from separately compiled files - the original BodyAnnotation is not kept.
case mdef1: DefDef if MemoCacheName.currentCount(nme.memo) != memoCacheCount =>
buf ++= Inliner.memoCacheDefs(mdef1.rhs) += mdef1
case mdef1: TypeDef if mdef1.symbol.is(Enum, butNot = Case) =>
enumContexts(mdef1.symbol) = ctx
buf += mdef1
case EmptyTree =>
// clashing synthetic case methods are converted to empty trees, drop them here
case mdef1 =>
import untpd.modsDeco
mdef match {
case mdef: untpd.TypeDef if mdef.mods.isEnumClass =>
enumContexts(mdef1.symbol) = ctx
case _ =>
}
if (!mdef1.isEmpty) // clashing synthetic case methods are converted to empty trees
buf += mdef1
buf += mdef1
}
traverse(rest)
}
Expand Down
11 changes: 9 additions & 2 deletions compiler/src/dotty/tools/dotc/util/FreshNameCreator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,27 @@ import core.StdNames.str

abstract class FreshNameCreator {
def newName(prefix: TermName, unique: UniqueNameKind): TermName
def currentCount(prefix: TermName, unique: UniqueNameKind): Int
}

object FreshNameCreator {
class Default extends FreshNameCreator {
protected var counter: Int = 0
protected val counters: mutable.Map[String, Int] = mutable.AnyRefMap() withDefaultValue 0

private def keyFor(prefix: TermName, unique: UniqueNameKind) =
str.sanitize(prefix.toString) + unique.separator

/** The current counter for the given combination of `prefix` and `unique` */
def currentCount(prefix: TermName, unique: UniqueNameKind): Int =
counters(keyFor(prefix, unique))

/**
* Create a fresh name with the given prefix. It is guaranteed
* that the returned name has never been returned by a previous
* call to this function (provided the prefix does not end in a digit).
*/
def newName(prefix: TermName, unique: UniqueNameKind): TermName = {
val key = str.sanitize(prefix.toString) + unique.separator
val key = keyFor(prefix, unique)
counters(key) += 1
prefix.derived(unique.NumberedInfo(counters(key)))
}
Expand Down
67 changes: 67 additions & 0 deletions docs/docs/reference/metaprogramming/inline.md
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,73 @@ inline def fail(p1: => Any) = {
fail(indentity("foo")) // error: failed on: indentity("foo")
```

#### `memo`

The `memo` method is used to avoid repeated evaluation of subcomputations.
Example:
```
type T = ...
class C(x: T) {
def costly(x: T): Int = ???
def f(y: Int) = memo(costly(x)) * y
}
```
Let's assume that `costly` is a pure function that is expensive to compute. If `f` was defined
like this:
```
def f(y: Int) = costly(x) * y
```
the `costly(x)` subexpression would be recomputed each time `f` was called, even though
its result is the same each time. With the addition of `memo(...)` the subexpression
in the parentheses is computed only the first time and is cached for subsequent recalculuations.
The memoized program expands to the following code:
```
class C(x: T) {
Copy link
Contributor

@soronpo soronpo Jul 21, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing [T] (and in all other examples)

def costly(x: T): Int = ???
private[this] var memo$1: T | Null = null
def f(y: Int) = {
if (memo$1 == null) memo$1 = costly(x)
memo$1.asInstanceOf[T]
} * y
}
```
The fine-print behind this expansion is:

- The caching variable is placed next to the enclosing method (`f` in this case).
- Its type is the union of the type of the cached expression and `Null`.
- Its inital value is `null`.
- A `memo(op)` call is expanded to code that tests whether the cached variable is
null, in which case it reassignes the variable with the result of evaluating `op`.
The value of `memo(op)` is the value of the cached variable after this conditional assignment.

In simple scenarios the call to `memo` is equivalent to using `lazy val`. For instance
the example program above could be simulated like this:
```
class C(x: T) {
def costly(x: T): Int = ???
@threadunsafe private[this] lazy val cached = costly(x)
def f(y: Int) = cached * y
}
```
The advantage of using `memo` over lazy vals is that it's more concise. But `memo` could also be
used in scenarios where lazy vals are not suitable. For instance, let's assume
that the methods in class `C` above also need a given `Context` parameter.
```
class C(x: T) {
def costly(x: T) given Context: Int = ???
def f(y: Int) given (c: Context) = memo(costly(x) given c) * y
}
```
Now, we cannot simply pull out the computation `costly(x) given c` into a lazy val since
it depends on the parameter `c` which is only available inside `f`. On the other hand,
it's much harder to argue that the `memo` solution is correct. One possible scenario
is that we fully intend to capture and reuse only the first computation of `costly(x)`.
Another possible scenario is that we do want `memo` to be semantically invisible, used
for optimization only, but that we convince ourselves that `costly(x) given c` would return
the same value no matter what context `c` is passed to `f`. That's a much harder argument
to make, but sometimes we can derive this from the global architecture of the system we are
dealing with.

## Implicit Matches

It is foreseen that many areas of typelevel programming can be done with rewrite
Expand Down
2 changes: 2 additions & 0 deletions library/src/scala/compiletime/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,6 @@ package object compiletime {
inline def constValue[T]: T = ???

type S[X <: Int] <: Int

inline def memo[T](op: => T): T = ???
}
Loading