Skip to content

Commit eaff261

Browse files
authored
Macro annotation (part 1) (#16392)
#### Add basic support for macro annotations * Introduce experimental `scala.annotations.MacroAnnotation` * Macro annotations can analyze or modify definitions * Macro annotation can add definition around the annotated definition * Added members are not visible while typing * Added members are not visible to other macro annotations * Added definition must have the same owner * Implement macro annotation expansion * Implemented at `Inlining` phase * Can use macro annotations in staged expressions (expanded when at stage 0) * Can use staged expression to implement macro annotations * Can insert calls to inline methods in macro annotations * Current limitations (to be loosened in following PRs) * Can only be used on `def`, `val`, `lazy val` and `var` * Can only add `def`, `val`, `lazy val` and `var` definitions #### Example ```scala class memoize extends MacroAnnotation: def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = import quotes.reflect._ tree match case DefDef(name, TermParamClause(param :: Nil) :: Nil, tpt, Some(rhsTree)) => (Ref(param.symbol).asExpr, rhsTree.asExpr) match case ('{ $paramRefExpr: t }, '{ $rhsExpr: u }) => val cacheTpe = TypeRepr.of[Map[t, u]] val cacheSymbol = Symbol.newVal(tree.symbol.owner, name + "Cache", cacheTpe, Flags.Private, Symbol.noSymbol) val cacheRhs = '{ Map.empty[t, u] }.asTerm val cacheVal = ValDef(cacheSymbol, Some(cacheRhs)) val cacheRefExpr = Ref(cacheSymbol).asExprOf[Map[t, u]] val newRhs = '{ $cacheRefExpr.getOrElseUpdate($paramRefExpr, $rhsExpr) }.asTerm val newTree = DefDef.copy(tree)(name, TermParamClause(param :: Nil) :: Nil, tpt, Some(newRhs)) List(cacheVal, newTree) case _ => report.error("Annotation only supported on `def` with a single argument are supported") List(tree) ``` with this macro annotation a user can write ```scala @memoize def fib(n: Int): Int = println(s"compute fib of $n") if n <= 1 then n else fib(n - 1) + fib(n - 2) ``` and the macro will modify the definition to create ```scala val fibCache = mutable.Map.empty[Int, Int] def fib(n: Int): Int = fibCache.getOrElseUpdate( n, { println(s"compute fib of $n") if n <= 1 then n else fib(n - 1) + fib(n - 2) } ) ``` #### Based on * #15626 * https://infoscience.epfl.ch/record/294615?ln=en #### Followed by * #16454
2 parents 281fd99 + 19e37c8 commit eaff261

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+1022
-20
lines changed

compiler/src/dotty/tools/dotc/CompilationUnit.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import core.Decorators._
1616
import config.{SourceVersion, Feature}
1717
import StdNames.nme
1818
import scala.annotation.internal.sharable
19+
import transform.MacroAnnotations
1920

2021
class CompilationUnit protected (val source: SourceFile) {
2122

@@ -45,6 +46,8 @@ class CompilationUnit protected (val source: SourceFile) {
4546
*/
4647
var needsInlining: Boolean = false
4748

49+
var hasMacroAnnotations: Boolean = false
50+
4851
/** Set to `true` if inliner added anonymous mirrors that need to be completed */
4952
var needsMirrorSupport: Boolean = false
5053

@@ -119,6 +122,7 @@ object CompilationUnit {
119122
force.traverse(unit1.tpdTree)
120123
unit1.needsStaging = force.containsQuote
121124
unit1.needsInlining = force.containsInline
125+
unit1.hasMacroAnnotations = force.containsMacroAnnotation
122126
}
123127
unit1
124128
}
@@ -147,6 +151,7 @@ object CompilationUnit {
147151
var containsQuote = false
148152
var containsInline = false
149153
var containsCaptureChecking = false
154+
var containsMacroAnnotation = false
150155
def traverse(tree: Tree)(using Context): Unit = {
151156
if (tree.symbol.isQuote)
152157
containsQuote = true
@@ -160,6 +165,9 @@ object CompilationUnit {
160165
Feature.handleGlobalLanguageImport(prefix, imported)
161166
case _ =>
162167
case _ =>
168+
for annot <- tree.symbol.annotations do
169+
if MacroAnnotations.isMacroAnnotation(annot) then
170+
ctx.compilationUnit.hasMacroAnnotations = true
163171
traverseChildren(tree)
164172
}
165173
}

compiler/src/dotty/tools/dotc/config/Printers.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ object Printers {
3232
val init = noPrinter
3333
val inlining = noPrinter
3434
val interactiv = noPrinter
35+
val macroAnnot = noPrinter
3536
val matchTypes = noPrinter
3637
val nullables = noPrinter
3738
val overload = noPrinter

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,8 @@ class Definitions {
895895
@tu lazy val QuotedTypeModule: Symbol = QuotedTypeClass.companionModule
896896
@tu lazy val QuotedTypeModule_of: Symbol = QuotedTypeModule.requiredMethod("of")
897897

898+
@tu lazy val MacroAnnotationClass: ClassSymbol = requiredClass("scala.annotation.MacroAnnotation")
899+
898900
@tu lazy val CanEqualClass: ClassSymbol = getClassIfDefined("scala.Eql").orElse(requiredClass("scala.CanEqual")).asClass
899901
def CanEqual_canEqualAny(using Context): TermSymbol =
900902
val methodName = if CanEqualClass.name == tpnme.Eql then nme.eqlAny else nme.canEqualAny

compiler/src/dotty/tools/dotc/quoted/Interpreter.scala

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import dotty.tools.dotc.reporting.Message
3232
import dotty.tools.repl.AbstractFileClassLoader
3333

3434
/** Tree interpreter for metaprogramming constructs */
35-
abstract class Interpreter(pos: SrcPos, classLoader: ClassLoader)(using Context):
35+
class Interpreter(pos: SrcPos, classLoader: ClassLoader)(using Context):
3636
import Interpreter._
3737
import tpd._
3838

@@ -68,7 +68,7 @@ abstract class Interpreter(pos: SrcPos, classLoader: ClassLoader)(using Context)
6868

6969
// TODO disallow interpreted method calls as arguments
7070
case Call(fn, args) =>
71-
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package))
71+
if (fn.symbol.isConstructor)
7272
interpretNew(fn.symbol, args.flatten.map(interpretTree))
7373
else if (fn.symbol.is(Module))
7474
interpretModuleAccess(fn.symbol)
@@ -185,8 +185,9 @@ abstract class Interpreter(pos: SrcPos, classLoader: ClassLoader)(using Context)
185185
private def interpretModuleAccess(fn: Symbol): Object =
186186
loadModule(fn.moduleClass)
187187

188-
private def interpretNew(fn: Symbol, args: => List[Object]): Object = {
189-
val clazz = loadClass(fn.owner.fullName.toString)
188+
private def interpretNew(fn: Symbol, args: List[Object]): Object = {
189+
val className = fn.owner.fullName.mangledString.replaceAll("\\$\\.", "\\$")
190+
val clazz = loadClass(className)
190191
val constr = clazz.getConstructor(paramsSig(fn): _*)
191192
constr.newInstance(args: _*).asInstanceOf[Object]
192193
}
@@ -214,10 +215,6 @@ abstract class Interpreter(pos: SrcPos, classLoader: ClassLoader)(using Context)
214215
private def loadClass(name: String): Class[?] =
215216
try classLoader.loadClass(name)
216217
catch {
217-
case _: ClassNotFoundException if ctx.compilationUnit.isSuspendable =>
218-
if (ctx.settings.XprintSuspension.value)
219-
report.echo(i"suspension triggered by a dependency on $name", pos)
220-
ctx.compilationUnit.suspend()
221218
case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable =>
222219
if (ctx.settings.XprintSuspension.value)
223220
report.echo(i"suspension triggered by a dependency on $sym", pos)
@@ -272,13 +269,15 @@ abstract class Interpreter(pos: SrcPos, classLoader: ClassLoader)(using Context)
272269
}
273270

274271
private object MissingClassDefinedInCurrentRun {
275-
def unapply(targetException: NoClassDefFoundError)(using Context): Option[Symbol] = {
276-
val className = targetException.getMessage
277-
if (className eq null) None
278-
else {
279-
val sym = staticRef(className.toTypeName).symbol
280-
if (sym.isDefinedInCurrentRun) Some(sym) else None
281-
}
272+
def unapply(targetException: Throwable)(using Context): Option[Symbol] = {
273+
targetException match
274+
case _: NoClassDefFoundError | _: ClassNotFoundException =>
275+
val className = targetException.getMessage
276+
if className eq null then None
277+
else
278+
val sym = staticRef(className.toTypeName).symbol
279+
if (sym.isDefinedInCurrentRun) Some(sym) else None
280+
case _ => None
282281
}
283282
}
284283

compiler/src/dotty/tools/dotc/transform/Inlining.scala

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,18 @@ import Contexts._
77
import Symbols._
88
import SymUtils._
99
import dotty.tools.dotc.ast.tpd
10-
10+
import dotty.tools.dotc.ast.Trees._
11+
import dotty.tools.dotc.quoted._
1112
import dotty.tools.dotc.core.StagingContext._
1213
import dotty.tools.dotc.inlines.Inlines
1314
import dotty.tools.dotc.ast.TreeMapWithImplicits
15+
import dotty.tools.dotc.core.DenotTransformers.IdentityDenotTransformer
1416

1517

1618
/** Inlines all calls to inline methods that are not in an inline method or a quote */
17-
class Inlining extends MacroTransform {
19+
class Inlining extends MacroTransform with IdentityDenotTransformer {
20+
thisPhase =>
21+
1822
import tpd._
1923

2024
override def phaseName: String = Inlining.name
@@ -23,8 +27,10 @@ class Inlining extends MacroTransform {
2327

2428
override def allowsImplicitSearch: Boolean = true
2529

30+
override def changesMembers: Boolean = true
31+
2632
override def run(using Context): Unit =
27-
if ctx.compilationUnit.needsInlining then
33+
if ctx.compilationUnit.needsInlining || ctx.compilationUnit.hasMacroAnnotations then
2834
try super.run
2935
catch case _: CompilationUnit.SuspendException => ()
3036

@@ -59,8 +65,16 @@ class Inlining extends MacroTransform {
5965
private class InliningTreeMap extends TreeMapWithImplicits {
6066
override def transform(tree: Tree)(using Context): Tree = {
6167
tree match
62-
case tree: DefTree =>
68+
case tree: MemberDef =>
6369
if tree.symbol.is(Inline) then tree
70+
else if tree.symbol.is(Param) then super.transform(tree)
71+
else if
72+
!tree.symbol.isPrimaryConstructor
73+
&& StagingContext.level == 0
74+
&& MacroAnnotations.hasMacroAnnotation(tree.symbol)
75+
then
76+
val trees = new MacroAnnotations(thisPhase).expandAnnotations(tree)
77+
flatTree(trees.map(super.transform))
6478
else super.transform(tree)
6579
case _: Typed | _: Block =>
6680
super.transform(tree)
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import scala.language.unsafeNulls
5+
6+
import dotty.tools.dotc.ast.tpd
7+
import dotty.tools.dotc.ast.Trees.*
8+
import dotty.tools.dotc.config.Printers.{macroAnnot => debug}
9+
import dotty.tools.dotc.core.Annotations.*
10+
import dotty.tools.dotc.core.Contexts.*
11+
import dotty.tools.dotc.core.Decorators.*
12+
import dotty.tools.dotc.core.DenotTransformers.DenotTransformer
13+
import dotty.tools.dotc.core.Flags.*
14+
import dotty.tools.dotc.core.MacroClassLoader
15+
import dotty.tools.dotc.core.Symbols.*
16+
import dotty.tools.dotc.quoted.*
17+
import dotty.tools.dotc.util.SrcPos
18+
import scala.quoted.runtime.impl.{QuotesImpl, SpliceScope}
19+
20+
import scala.quoted.Quotes
21+
22+
class MacroAnnotations(thisPhase: DenotTransformer):
23+
import tpd.*
24+
import MacroAnnotations.*
25+
26+
/** Expands every macro annotation that is on this tree.
27+
* Returns a list with transformed definition and any added definitions.
28+
*/
29+
def expandAnnotations(tree: MemberDef)(using Context): List[DefTree] =
30+
if !hasMacroAnnotation(tree.symbol) then
31+
List(tree)
32+
else if tree.symbol.is(Module) then
33+
if tree.symbol.isClass then // error only reported on module class
34+
report.error("macro annotations are not supported on object", tree)
35+
List(tree)
36+
else if tree.symbol.isClass then
37+
report.error("macro annotations are not supported on class", tree)
38+
List(tree)
39+
else if tree.symbol.isType then
40+
report.error("macro annotations are not supported on type", tree)
41+
List(tree)
42+
else
43+
debug.println(i"Expanding macro annotations of:\n$tree")
44+
45+
val macroInterpreter = new Interpreter(tree.srcPos, MacroClassLoader.fromContext)
46+
47+
val allTrees = List.newBuilder[DefTree]
48+
var insertedAfter: List[List[DefTree]] = Nil
49+
50+
// Apply all macro annotation to `tree` and collect new definitions in order
51+
val transformedTree: DefTree = tree.symbol.annotations.foldLeft(tree) { (tree, annot) =>
52+
if isMacroAnnotation(annot) then
53+
debug.println(i"Expanding macro annotation: ${annot}")
54+
55+
// Interpret call to `new myAnnot(..).transform(using <Quotes>)(<tree>)`
56+
val transformedTrees = callMacro(macroInterpreter, tree, annot)
57+
transformedTrees.span(_.symbol != tree.symbol) match
58+
case (prefixed, newTree :: suffixed) =>
59+
allTrees ++= prefixed
60+
insertedAfter = suffixed :: insertedAfter
61+
prefixed.foreach(checkAndEnter(_, tree.symbol, annot))
62+
suffixed.foreach(checkAndEnter(_, tree.symbol, annot))
63+
newTree
64+
case (Nil, Nil) =>
65+
report.error(i"Unexpected `Nil` returned by `(${annot.tree}).transform(..)` during macro expansion", annot.tree.srcPos)
66+
tree
67+
case (_, Nil) =>
68+
report.error(i"Transformed tree for ${tree} was not return by `(${annot.tree}).transform(..)` during macro expansion", annot.tree.srcPos)
69+
tree
70+
else
71+
tree
72+
}
73+
74+
allTrees += transformedTree
75+
insertedAfter.foreach(allTrees.++=)
76+
77+
val result = allTrees.result()
78+
debug.println(result.map(_.show).mkString("expanded to:\n", "\n", ""))
79+
result
80+
81+
/** Interpret the code `new annot(..).transform(using <Quotes(ctx)>)(<tree>)` */
82+
private def callMacro(interpreter: Interpreter, tree: MemberDef, annot: Annotation)(using Context): List[MemberDef] =
83+
// TODO: Remove when scala.annaotaion.MacroAnnotation is no longer experimental
84+
import scala.reflect.Selectable.reflectiveSelectable
85+
type MacroAnnotation = {
86+
def transform(using Quotes)(tree: Object/*Erased type of quotes.refelct.Definition*/): List[MemberDef /*quotes.refelct.Definition known to be MemberDef in QuotesImpl*/]
87+
}
88+
89+
// Interpret macro annotation instantiation `new myAnnot(..)`
90+
val annotInstance = interpreter.interpret[MacroAnnotation](annot.tree).get
91+
// TODO: Remove when scala.annaotaion.MacroAnnotation is no longer experimental
92+
assert(annotInstance.getClass.getClassLoader.loadClass("scala.annotation.MacroAnnotation").isInstance(annotInstance))
93+
94+
val quotes = QuotesImpl()(using SpliceScope.contextWithNewSpliceScope(tree.symbol.sourcePos)(using MacroExpansion.context(tree)).withOwner(tree.symbol))
95+
annotInstance.transform(using quotes)(tree.asInstanceOf[quotes.reflect.Definition])
96+
97+
/** Check that this tree can be added by the macro annotation and enter it if needed */
98+
private def checkAndEnter(newTree: Tree, annotated: Symbol, annot: Annotation)(using Context) =
99+
val sym = newTree.symbol
100+
if sym.isClass then
101+
report.error("Generating classes is not supported", annot.tree)
102+
else if sym.isType then
103+
report.error("Generating type is not supported", annot.tree)
104+
else if sym.owner != annotated.owner then
105+
report.error(i"macro annotation $annot added $sym with an inconsistent owner. Expected it to be owned by ${annotated.owner} but was owned by ${sym.owner}.", annot.tree)
106+
else
107+
sym.enteredAfter(thisPhase)
108+
109+
object MacroAnnotations:
110+
111+
/** Is this an annotation that implements `scala.annation.MacroAnnotation` */
112+
def isMacroAnnotation(annot: Annotation)(using Context): Boolean =
113+
annot.tree.symbol.maybeOwner.derivesFrom(defn.MacroAnnotationClass)
114+
115+
/** Is this symbol annotated with an annotation that implements `scala.annation.MacroAnnotation` */
116+
def hasMacroAnnotation(sym: Symbol)(using Context): Boolean =
117+
sym.getAnnotation(defn.MacroAnnotationClass).isDefined

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,21 +375,25 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
375375
)
376376
}
377377
case tree: ValDef =>
378+
registerIfHasMacroAnnotations(tree)
378379
checkErasedDef(tree)
379380
val tree1 = cpy.ValDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
380381
if tree1.removeAttachment(desugar.UntupledParam).isDefined then
381382
checkStableSelection(tree.rhs)
382383
processValOrDefDef(super.transform(tree1))
383384
case tree: DefDef =>
385+
registerIfHasMacroAnnotations(tree)
384386
checkErasedDef(tree)
385387
annotateContextResults(tree)
386388
val tree1 = cpy.DefDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
387389
processValOrDefDef(superAcc.wrapDefDef(tree1)(super.transform(tree1).asInstanceOf[DefDef]))
388390
case tree: TypeDef =>
391+
registerIfHasMacroAnnotations(tree)
389392
val sym = tree.symbol
390393
if (sym.isClass)
391394
VarianceChecker.check(tree)
392395
annotateExperimental(sym)
396+
checkMacroAnnotation(sym)
393397
tree.rhs match
394398
case impl: Template =>
395399
for parent <- impl.parents do
@@ -483,6 +487,16 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
483487
private def normalizeErasedRhs(rhs: Tree, sym: Symbol)(using Context) =
484488
if (sym.isEffectivelyErased) dropInlines.transform(rhs) else rhs
485489

490+
/** Check if the definition has macro annotation and sets `compilationUnit.hasMacroAnnotations` if needed. */
491+
private def registerIfHasMacroAnnotations(tree: DefTree)(using Context) =
492+
if !Inlines.inInlineMethod && MacroAnnotations.hasMacroAnnotation(tree.symbol) then
493+
ctx.compilationUnit.hasMacroAnnotations = true
494+
495+
/** Check macro annotations implementations */
496+
private def checkMacroAnnotation(sym: Symbol)(using Context) =
497+
if sym.derivesFrom(defn.MacroAnnotationClass) && !sym.isStatic then
498+
report.error("classes that extend MacroAnnotation must not be inner/local classes", sym.srcPos)
499+
486500
private def checkErasedDef(tree: ValOrDefDef)(using Context): Unit =
487501
if tree.symbol.is(Erased, butNot = Macro) then
488502
val tpe = tree.rhs.tpe

compiler/src/dotty/tools/dotc/transform/YCheckPositions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class YCheckPositions extends Phase {
6161

6262
private def isMacro(call: Tree)(using Context) =
6363
call.symbol.is(Macro) ||
64+
(call.symbol.isClass && call.tpe.derivesFrom(defn.MacroAnnotationClass)) ||
6465
// The call of a macro after typer is encoded as a Select while other inlines are Ident
6566
// TODO remove this distinction once Inline nodes of expanded macros can be trusted (also in Inliner.inlineCallTrace)
6667
(!(ctx.phase <= postTyperPhase) && call.isInstanceOf[Select])

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2485,8 +2485,14 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
24852485
newMethod(owner, name, tpe, Flags.EmptyFlags, noSymbol)
24862486
def newMethod(owner: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol =
24872487
dotc.core.Symbols.newSymbol(owner, name.toTermName, flags | dotc.core.Flags.Method, tpe, privateWithin)
2488+
def newUniqueMethod(owner: Symbol, namePrefix: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol =
2489+
val name = NameKinds.UniqueName.fresh(namePrefix.toTermName)
2490+
dotc.core.Symbols.newSymbol(owner, name, dotc.core.Flags.PrivateMethod | flags, tpe, privateWithin)
24882491
def newVal(owner: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol =
24892492
dotc.core.Symbols.newSymbol(owner, name.toTermName, flags, tpe, privateWithin)
2493+
def newUniqueVal(owner: Symbol, namePrefix: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol =
2494+
val name = NameKinds.UniqueName.fresh(namePrefix.toTermName)
2495+
dotc.core.Symbols.newSymbol(owner, name, flags, tpe, privateWithin)
24902496
def newBind(owner: Symbol, name: String, flags: Flags, tpe: TypeRepr): Symbol =
24912497
dotc.core.Symbols.newSymbol(owner, name.toTermName, flags | Case, tpe)
24922498
def noSymbol: Symbol = dotc.core.Symbols.NoSymbol

0 commit comments

Comments
 (0)