diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 17873cf62ecb..52bc778d8952 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -144,6 +144,9 @@ object desugar { // ----- Desugar methods ------------------------------------------------- + def setterNeeded(flags: FlagSet, owner: Symbol) given Context = + flags.is(Mutable) && owner.isClass && (!flags.isAllOf(PrivateLocal) || owner.is(Trait)) + /** var x: Int = expr * ==> * def x: Int = expr @@ -151,14 +154,7 @@ object desugar { */ def valDef(vdef0: ValDef)(implicit ctx: Context): Tree = { val vdef @ ValDef(name, tpt, rhs) = transformQuotedPatternName(vdef0) - val mods = vdef.mods - val setterNeeded = - mods.is(Mutable) && ctx.owner.isClass && (!mods.isAllOf(PrivateLocal) || ctx.owner.is(Trait)) - if (setterNeeded) { - // TODO: copy of vdef as getter needed? - // val getter = ValDef(mods, name, tpt, rhs) withPos vdef.pos? - // right now vdef maps via expandedTree to a thicket which concerns itself. - // I don't see a problem with that but if there is one we can avoid it by making a copy here. + if (setterNeeded(vdef.mods.flags, ctx.owner)) { val setterParam = makeSyntheticParameter(tpt = SetterParamTree().watching(vdef)) // The rhs gets filled in later, when field is generated and getter has parameters (see Memoize miniphase) val setterRhs = if (vdef.rhs.isEmpty) EmptyTree else unitLiteral @@ -168,7 +164,7 @@ object desugar { vparamss = (setterParam :: Nil) :: Nil, tpt = TypeTree(defn.UnitType), rhs = setterRhs - ).withMods((mods | Accessor) &~ (CaseAccessor | GivenOrImplicit | Lazy)) + ).withMods((vdef.mods | Accessor) &~ (CaseAccessor | GivenOrImplicit | Lazy)) Thicket(vdef, setter) } else vdef diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 33af87223344..5694667e4ad2 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -237,6 +237,7 @@ class Definitions { @threadUnsafe lazy val Compiletime_constValue : SymbolPerRun = perRunSym(CompiletimePackageObject.requiredMethodRef("constValue")) @threadUnsafe lazy val Compiletime_constValueOpt: SymbolPerRun = perRunSym(CompiletimePackageObject.requiredMethodRef("constValueOpt")) @threadUnsafe lazy val Compiletime_code : SymbolPerRun = perRunSym(CompiletimePackageObject.requiredMethodRef("code")) + @threadUnsafe lazy val Compiletime_memo : SymbolPerRun = perRunSym(CompiletimePackageObject.requiredMethodRef("memo")) /** The `scalaShadowing` package is used to safely modify classes and * objects in scala so that they can be used from dotty. They will diff --git a/compiler/src/dotty/tools/dotc/core/NameKinds.scala b/compiler/src/dotty/tools/dotc/core/NameKinds.scala index ec0a6be801b3..fca24d032fdb 100644 --- a/compiler/src/dotty/tools/dotc/core/NameKinds.scala +++ b/compiler/src/dotty/tools/dotc/core/NameKinds.scala @@ -213,6 +213,9 @@ object NameKinds { safePrefix + info.num } + def currentCount(prefix: TermName = EmptyTermName) given (ctx: Context): Int = + ctx.freshNames.currentCount(prefix, this) + /** Generate fresh unique term name of this kind with given prefix name */ def fresh(prefix: TermName = EmptyTermName)(implicit ctx: Context): TermName = ctx.freshNames.newName(prefix, this) @@ -296,6 +299,7 @@ object NameKinds { val UniqueInlineName: UniqueNameKind = new UniqueNameKind("$i") val InlineScrutineeName: UniqueNameKind = new UniqueNameKind("$scrutinee") val InlineBinderName: UniqueNameKind = new UniqueNameKind("$elem") + val MemoCacheName: UniqueNameKind = new UniqueNameKind("$cache") /** A kind of unique extension methods; Unlike other unique names, these can be * unmangled. diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index e331c3a84dc1..f7dd2da26ace 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -482,6 +482,7 @@ object StdNames { val materializeClassTag: N = "materializeClassTag" val materializeWeakTypeTag: N = "materializeWeakTypeTag" val materializeTypeTag: N = "materializeTypeTag" + val memo: N = "memo" val mirror : N = "mirror" val moduleClass : N = "moduleClass" val name: N = "name" diff --git a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala index b40468716616..58854adf8ec5 100644 --- a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala +++ b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala @@ -102,8 +102,12 @@ trait SymDenotations { this: Context => } } - /** Configurable: Accept stale symbol with warning if in IDE */ - def staleOK: Boolean = Config.ignoreStaleInIDE && mode.is(Mode.Interactive) + /** Configurable: Accept stale symbol with warning if in IDE + * Always accept stale symbols when testing pickling. + */ + def staleOK: Boolean = + Config.ignoreStaleInIDE && mode.is(Mode.Interactive) || + settings.YtestPickler.value /** Possibly accept stale symbol with warning if in IDE */ def acceptStale(denot: SingleDenotation): Boolean = diff --git a/compiler/src/dotty/tools/dotc/core/Symbols.scala b/compiler/src/dotty/tools/dotc/core/Symbols.scala index 24d2c188a419..3aaf8cb80838 100644 --- a/compiler/src/dotty/tools/dotc/core/Symbols.scala +++ b/compiler/src/dotty/tools/dotc/core/Symbols.scala @@ -553,9 +553,10 @@ object Symbols { /** This symbol entered into owner's scope (owner must be a class). */ final def entered(implicit ctx: Context): this.type = { - assert(this.owner.isClass, s"symbol ($this) entered the scope of non-class owner ${this.owner}") // !!! DEBUG - this.owner.asClass.enter(this) - if (this.is(Module)) this.owner.asClass.enter(this.moduleClass) + if (this.owner.isClass) { + this.owner.asClass.enter(this) + if (this.is(Module)) this.owner.asClass.enter(this.moduleClass) + } this } @@ -566,14 +567,16 @@ object Symbols { */ def enteredAfter(phase: DenotTransformer)(implicit ctx: Context): this.type = if (ctx.phaseId != phase.next.id) enteredAfter(phase)(ctx.withPhase(phase.next)) - else { - if (this.owner.is(Package)) { - denot.validFor |= InitialPeriod - if (this.is(Module)) this.moduleClass.validFor |= InitialPeriod - } - else this.owner.asClass.ensureFreshScopeAfter(phase) - assert(isPrivate || phase.changesMembers, i"$this entered in ${this.owner} at undeclared phase $phase") - entered + else this.owner match { + case owner: ClassSymbol => + if (owner.is(Package)) { + denot.validFor |= InitialPeriod + if (this.is(Module)) this.moduleClass.validFor |= InitialPeriod + } + else owner.ensureFreshScopeAfter(phase) + assert(isPrivate || phase.changesMembers, i"$this entered in $owner at undeclared phase $phase") + entered + case _ => this } /** Remove symbol from scope of owning class */ diff --git a/compiler/src/dotty/tools/dotc/transform/CacheAliasImplicits.scala b/compiler/src/dotty/tools/dotc/transform/CacheAliasImplicits.scala index 29eae5107ec7..9cf5fd216388 100644 --- a/compiler/src/dotty/tools/dotc/transform/CacheAliasImplicits.scala +++ b/compiler/src/dotty/tools/dotc/transform/CacheAliasImplicits.scala @@ -82,7 +82,7 @@ class CacheAliasImplicits extends MiniPhase with IdentityDenotTransformer { this val cacheFlags = if (ctx.owner.isClass) Private | Local | Mutable else Mutable val cacheSym = ctx.newSymbol(ctx.owner, CacheName(tree.name), cacheFlags, rhsType, coord = sym.coord) - if (ctx.owner.isClass) cacheSym.enteredAfter(thisPhase) + .enteredAfter(thisPhase) val cacheDef = ValDef(cacheSym, tpd.defaultValue(rhsType)) val cachingDef = cpy.DefDef(tree)(rhs = Block( diff --git a/compiler/src/dotty/tools/dotc/transform/HoistSuperArgs.scala b/compiler/src/dotty/tools/dotc/transform/HoistSuperArgs.scala index 302d2330d056..d72f0895ca49 100644 --- a/compiler/src/dotty/tools/dotc/transform/HoistSuperArgs.scala +++ b/compiler/src/dotty/tools/dotc/transform/HoistSuperArgs.scala @@ -91,13 +91,13 @@ class HoistSuperArgs extends MiniPhase with IdentityDenotTransformer { thisPhase val argTypeWrtConstr = argType.subst(origParams, allParamRefs(constr.info)) // argType with references to paramRefs of the primary constructor instead of // local parameter accessors - val meth = ctx.newSymbol( + ctx.newSymbol( owner = methOwner, name = SuperArgName.fresh(cls.name.toTermName), flags = Synthetic | Private | Method | staticFlag, info = replaceResult(constr.info, argTypeWrtConstr), - coord = constr.coord) - if (methOwner.isClass) meth.enteredAfter(thisPhase) else meth + coord = constr.coord + ).enteredAfter(thisPhase) } /** Type of a reference implies that it needs to be hoisted */ diff --git a/compiler/src/dotty/tools/dotc/transform/LambdaLift.scala b/compiler/src/dotty/tools/dotc/transform/LambdaLift.scala index 4bd73238797b..548f594865b1 100644 --- a/compiler/src/dotty/tools/dotc/transform/LambdaLift.scala +++ b/compiler/src/dotty/tools/dotc/transform/LambdaLift.scala @@ -298,8 +298,9 @@ object LambdaLift { proxyMap(owner) = { for (fv <- freeValues.toList) yield { val proxyName = newName(fv) - val proxy = ctx.newSymbol(owner, proxyName.asTermName, newFlags, fv.info, coord = fv.coord) - if (owner.isClass) proxy.enteredAfter(thisPhase) + val proxy = + ctx.newSymbol(owner, proxyName.asTermName, newFlags, fv.info, coord = fv.coord) + .enteredAfter(thisPhase) (fv, proxy) } }.toMap diff --git a/compiler/src/dotty/tools/dotc/typer/Inliner.scala b/compiler/src/dotty/tools/dotc/typer/Inliner.scala index bf282476389a..6dafc0fa6932 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inliner.scala @@ -15,7 +15,8 @@ import StdNames._ import transform.SymUtils._ import Contexts.Context import Names.{Name, TermName} -import NameKinds.{InlineAccessorName, InlineBinderName, InlineScrutineeName} +import NameKinds.{InlineAccessorName, InlineBinderName, InlineScrutineeName, MemoCacheName} +import NameOps._ import ProtoTypes.selectionProto import SymDenotations.SymDenotation import Inferencing.fullyDefinedType @@ -188,6 +189,27 @@ object Inliner { if (callSym.is(Macro)) ref(callSym.topLevelClass.owner).select(callSym.topLevelClass.name).withSpan(pos.span) else Ident(callSym.topLevelClass.typeRef).withSpan(pos.span) } + + /** For every occurrence of a memo cache symbol `memo$N` of type `T_N` in `tree`, + * an assignment `val memo$N: T_N = null` + */ + def memoCacheDefs(tree: Tree) given Context: List[ValOrDefDef] = { + object memoRefs extends TreeTraverser { + val syms = new mutable.LinkedHashSet[TermSymbol] + def traverse(tree: Tree) given Context = tree match { + case tree: RefTree if tree.symbol.name.is(MemoCacheName) => + syms += tree.symbol.asTerm + case _: DefDef => + // don't traverse deeper; nested memo caches go next to nested method + case _ => + traverseChildren(tree) + } + } + memoRefs.traverse(tree) + for sym <- memoRefs.syms.toList yield + (if (sym.isSetter) DefDef(sym, _ => Literal(Constant(()))) + else ValDef(sym, Literal(Constant(null)))).withSpan(sym.span) + } } /** Produces an inlined version of `call` via its `inlined` method. @@ -392,6 +414,47 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { case _ => EmptyTree } + /** The expansion of `memo(op)` where `op: T` is: + * + * { if (memo$N == null) memo$N_=(op); $memo.asInstanceOf[T] } + * + * This creates as a side effect a memo cache symbol $memo$N` of type `T | Null`. + * TODO: Restrict this to non-null types, once nullability checking is in. + */ + def memoized: Tree = { + val currentOwner = ctx.owner.skipWeakOwner + if (currentOwner.isRealMethod) { + val cacheOwner = currentOwner.owner + val argType = callTypeArgs.head.tpe + val memoVar = ctx.newSymbol( + owner = cacheOwner, + name = MemoCacheName.fresh(nme.memo), + flags = + if (cacheOwner.isTerm) Synthetic | Mutable + else Synthetic | Mutable | Private | Local, + info = OrType(argType, defn.NullType), + coord = call.span).entered + val memoSetter = + if (desugar.setterNeeded(memoVar.flags, cacheOwner)) + ctx.newSymbol( + owner = cacheOwner, + name = memoVar.name.setterName, + flags = memoVar.flags | Method | Accessor, + info = MethodType(argType :: Nil, defn.UnitType), + coord = call.span + ).entered + else memoVar + val memoRef = ref(memoVar).withSpan(call.span) + val cond = If( + memoRef.select(defn.Any_==).appliedTo(Literal(Constant(null))), + ref(memoSetter).withSpan(call.span).becomes(callValueArgss.head.head), + Literal(Constant(()))) + val expr = memoRef.cast(argType) + Block(cond :: Nil, expr) + } + else errorTree(call, em"""memo(...) outside method""") + } + /** The Inlined node representing the inlined call */ def inlined(sourcePos: SourcePosition): Tree = { @@ -408,6 +471,8 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { else New(defn.SomeClass.typeRef.appliedTo(constVal.tpe), constVal :: Nil) ) } + else if (inlinedMethod == defn.Compiletime_memo) + return memoized // Compute bindings for all parameters, appending them to bindingsBuf computeParamBindings(inlinedMethod.info, callTypeArgs, callValueArgss) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 00289794ee73..325409d02372 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2122,20 +2122,21 @@ class Typer extends Namer case Some(xtree) => traverse(xtree :: rest) case none => + val memoCacheCount = MemoCacheName.currentCount(nme.memo) typed(mdef) match { case mdef1: DefDef if Inliner.hasBodyToInline(mdef1.symbol) => buf += inlineExpansion(mdef1) // replace body with expansion, because it will be used as inlined body // from separately compiled files - the original BodyAnnotation is not kept. + case mdef1: DefDef if MemoCacheName.currentCount(nme.memo) != memoCacheCount => + buf ++= Inliner.memoCacheDefs(mdef1.rhs) += mdef1 + case mdef1: TypeDef if mdef1.symbol.is(Enum, butNot = Case) => + enumContexts(mdef1.symbol) = ctx + buf += mdef1 + case EmptyTree => + // clashing synthetic case methods are converted to empty trees, drop them here case mdef1 => - import untpd.modsDeco - mdef match { - case mdef: untpd.TypeDef if mdef.mods.isEnumClass => - enumContexts(mdef1.symbol) = ctx - case _ => - } - if (!mdef1.isEmpty) // clashing synthetic case methods are converted to empty trees - buf += mdef1 + buf += mdef1 } traverse(rest) } diff --git a/compiler/src/dotty/tools/dotc/util/FreshNameCreator.scala b/compiler/src/dotty/tools/dotc/util/FreshNameCreator.scala index f3375028c95f..0c79f926fcc6 100644 --- a/compiler/src/dotty/tools/dotc/util/FreshNameCreator.scala +++ b/compiler/src/dotty/tools/dotc/util/FreshNameCreator.scala @@ -9,20 +9,27 @@ import core.StdNames.str abstract class FreshNameCreator { def newName(prefix: TermName, unique: UniqueNameKind): TermName + def currentCount(prefix: TermName, unique: UniqueNameKind): Int } object FreshNameCreator { class Default extends FreshNameCreator { - protected var counter: Int = 0 protected val counters: mutable.Map[String, Int] = mutable.AnyRefMap() withDefaultValue 0 + private def keyFor(prefix: TermName, unique: UniqueNameKind) = + str.sanitize(prefix.toString) + unique.separator + + /** The current counter for the given combination of `prefix` and `unique` */ + def currentCount(prefix: TermName, unique: UniqueNameKind): Int = + counters(keyFor(prefix, unique)) + /** * Create a fresh name with the given prefix. It is guaranteed * that the returned name has never been returned by a previous * call to this function (provided the prefix does not end in a digit). */ def newName(prefix: TermName, unique: UniqueNameKind): TermName = { - val key = str.sanitize(prefix.toString) + unique.separator + val key = keyFor(prefix, unique) counters(key) += 1 prefix.derived(unique.NumberedInfo(counters(key))) } diff --git a/docs/docs/reference/metaprogramming/inline.md b/docs/docs/reference/metaprogramming/inline.md index 9b12b9e02d30..e61c477ddc99 100644 --- a/docs/docs/reference/metaprogramming/inline.md +++ b/docs/docs/reference/metaprogramming/inline.md @@ -406,6 +406,73 @@ inline def fail(p1: => Any) = { fail(indentity("foo")) // error: failed on: indentity("foo") ``` +#### `memo` + +The `memo` method is used to avoid repeated evaluation of subcomputations. +Example: +``` +type T = ... +class C(x: T) { + def costly(x: T): Int = ??? + def f(y: Int) = memo(costly(x)) * y +} +``` +Let's assume that `costly` is a pure function that is expensive to compute. If `f` was defined +like this: +``` + def f(y: Int) = costly(x) * y +``` +the `costly(x)` subexpression would be recomputed each time `f` was called, even though +its result is the same each time. With the addition of `memo(...)` the subexpression +in the parentheses is computed only the first time and is cached for subsequent recalculuations. +The memoized program expands to the following code: +``` +class C(x: T) { + def costly(x: T): Int = ??? + private[this] var memo$1: T | Null = null + def f(y: Int) = { + if (memo$1 == null) memo$1 = costly(x) + memo$1.asInstanceOf[T] + } * y +} +``` +The fine-print behind this expansion is: + + - The caching variable is placed next to the enclosing method (`f` in this case). + - Its type is the union of the type of the cached expression and `Null`. + - Its inital value is `null`. + - A `memo(op)` call is expanded to code that tests whether the cached variable is + null, in which case it reassignes the variable with the result of evaluating `op`. + The value of `memo(op)` is the value of the cached variable after this conditional assignment. + +In simple scenarios the call to `memo` is equivalent to using `lazy val`. For instance +the example program above could be simulated like this: +``` +class C(x: T) { + def costly(x: T): Int = ??? + @threadunsafe private[this] lazy val cached = costly(x) + def f(y: Int) = cached * y +} +``` +The advantage of using `memo` over lazy vals is that it's more concise. But `memo` could also be +used in scenarios where lazy vals are not suitable. For instance, let's assume +that the methods in class `C` above also need a given `Context` parameter. +``` +class C(x: T) { + def costly(x: T) given Context: Int = ??? + def f(y: Int) given (c: Context) = memo(costly(x) given c) * y +} +``` +Now, we cannot simply pull out the computation `costly(x) given c` into a lazy val since +it depends on the parameter `c` which is only available inside `f`. On the other hand, +it's much harder to argue that the `memo` solution is correct. One possible scenario +is that we fully intend to capture and reuse only the first computation of `costly(x)`. +Another possible scenario is that we do want `memo` to be semantically invisible, used +for optimization only, but that we convince ourselves that `costly(x) given c` would return +the same value no matter what context `c` is passed to `f`. That's a much harder argument +to make, but sometimes we can derive this from the global architecture of the system we are +dealing with. + ## Implicit Matches It is foreseen that many areas of typelevel programming can be done with rewrite diff --git a/library/src/scala/compiletime/package.scala b/library/src/scala/compiletime/package.scala index 7152c2b110bd..cfe87d785218 100644 --- a/library/src/scala/compiletime/package.scala +++ b/library/src/scala/compiletime/package.scala @@ -38,4 +38,6 @@ package object compiletime { inline def constValue[T]: T = ??? type S[X <: Int] <: Int + + inline def memo[T](op: => T): T = ??? } diff --git a/tests/neg/memoTest.scala b/tests/neg/memoTest.scala new file mode 100644 index 000000000000..952767b051e8 --- /dev/null +++ b/tests/neg/memoTest.scala @@ -0,0 +1,4 @@ +object Test { + import compiletime.memo + val a = memo(1) // error: memo(...) outside method +} \ No newline at end of file diff --git a/tests/run/memoTest.check b/tests/run/memoTest.check new file mode 100644 index 000000000000..bcc6c67fe293 --- /dev/null +++ b/tests/run/memoTest.check @@ -0,0 +1,7 @@ +computing inner +computing f +1 +1 +computing f +1 +1 diff --git a/tests/run/memoTest.scala b/tests/run/memoTest.scala index 460991eeaa59..a265a816525d 100644 --- a/tests/run/memoTest.scala +++ b/tests/run/memoTest.scala @@ -1,4 +1,5 @@ object Test extends App { + import compiletime.memo var opCache: Int | Null = null @@ -7,6 +8,44 @@ object Test extends App { opCache.asInstanceOf[Int] + 1 } + def bar(x: Int) = memo(x * x) + 1 + assert(foo(1) + foo(2) == 4) + assert(bar(1) + bar(2) == 4) + trait T { + def x: Int + def y: Int = memo { + def inner = memo { + println("computing inner"); + x * x + } + inner + inner + } + } + val t = new T { + def x = 3 + assert(y == 18) + } + assert(t.y == 18) + + class Context(val n: Int) + def f(c: Context): Context = { + println("computing f") + Context(c.n + 1) + } + given as Context(0) + + locally { + given as Context given (c: Context) = memo(f(c)) + println(the[Context].n) + println(the[Context].n) + } + + val ctx = f(the[Context]) + locally { + given as Context = ctx + println(the[Context].n) + println(the[Context].n) + } } \ No newline at end of file