Skip to content

Commit 76e8661

Browse files
committed
Macro annotation (part 1)
Add basic support for macro annotations. * Introduce experimental `scala.annotations.MacroAnnotation` * Macro annotations can analyze or modify definitions * Macro annotation can add definition around the annotated definition * Added members are not visible while typing * Added members are not visible to other macro annotations * Implement macro annotation expansion * Implemented at Inlining phase * Can use macro annotations in staged expressions (expanded when at stage 0) * Can use staged expression to implement macro annotations * Can insert calls to inline methods in macro annotations * Current limitations (to be loosened) * Can only be used on `def`, `val`, `lazy val` and `var` * Can only add `def`, `val`, `lazy val` and `var` definitions Based on: * scala#15626 * https://infoscience.epfl.ch/record/294615?ln=en
1 parent b4f8eef commit 76e8661

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+908
-6
lines changed

compiler/src/dotty/tools/dotc/CompilationUnit.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package dotc
33

44
import core._
55
import Contexts._
6-
import SymDenotations.ClassDenotation
6+
import SymDenotations.{ClassDenotation, NoDenotation}
77
import Symbols._
88
import util.{FreshNameCreator, SourceFile, NoSource}
99
import util.Spans.Span
@@ -45,6 +45,8 @@ class CompilationUnit protected (val source: SourceFile) {
4545
*/
4646
var needsInlining: Boolean = false
4747

48+
var hasMacroAnnotations: Boolean = false
49+
4850
/** Set to `true` if inliner added anonymous mirrors that need to be completed */
4951
var needsMirrorSupport: Boolean = false
5052

@@ -119,6 +121,7 @@ object CompilationUnit {
119121
force.traverse(unit1.tpdTree)
120122
unit1.needsStaging = force.containsQuote
121123
unit1.needsInlining = force.containsInline
124+
unit1.hasMacroAnnotations = force.containsMacroAnnotation
122125
}
123126
unit1
124127
}
@@ -147,6 +150,7 @@ object CompilationUnit {
147150
var containsQuote = false
148151
var containsInline = false
149152
var containsCaptureChecking = false
153+
var containsMacroAnnotation = false
150154
def traverse(tree: Tree)(using Context): Unit = {
151155
if (tree.symbol.isQuote)
152156
containsQuote = true
@@ -160,6 +164,9 @@ object CompilationUnit {
160164
Feature.handleGlobalLanguageImport(prefix, imported)
161165
case _ =>
162166
case _ =>
167+
for annot <- tree.symbol.annotations do
168+
if annot.tree.symbol.denot != NoDenotation && annot.tree.symbol.owner.derivesFrom(defn.QuotedMacroAnnotationClass) then
169+
ctx.compilationUnit.hasMacroAnnotations = true
163170
traverseChildren(tree)
164171
}
165172
}

compiler/src/dotty/tools/dotc/config/Printers.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ object Printers {
3232
val init = noPrinter
3333
val inlining = noPrinter
3434
val interactiv = noPrinter
35+
val macroAnnot = noPrinter
3536
val matchTypes = noPrinter
3637
val nullables = noPrinter
3738
val overload = noPrinter

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,9 @@ class Definitions {
891891
@tu lazy val QuotedTypeModule: Symbol = QuotedTypeClass.companionModule
892892
@tu lazy val QuotedTypeModule_of: Symbol = QuotedTypeModule.requiredMethod("of")
893893

894+
@tu lazy val QuotedMacroAnnotationClass: ClassSymbol = requiredClass("scala.annotation.MacroAnnotation")
895+
@tu lazy val QuotedMacroAnnotation_transform: Symbol = QuotedMacroAnnotationClass.requiredMethod("transform")
896+
894897
@tu lazy val CanEqualClass: ClassSymbol = getClassIfDefined("scala.Eql").orElse(requiredClass("scala.CanEqual")).asClass
895898
def CanEqual_canEqualAny(using Context): TermSymbol =
896899
val methodName = if CanEqualClass.name == tpnme.Eql then nme.eqlAny else nme.canEqualAny

compiler/src/dotty/tools/dotc/quoted/Interpreter.scala

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ abstract class Interpreter(pos: SrcPos, classLoader: ClassLoader)(using Context)
6262
case ConstantType(c) => c.value.asInstanceOf[Object]
6363
case _ => throw new StopInterpretation(em"${tree.symbol} could not be inlined", tree.srcPos)
6464

65+
case Apply(Select(New(_), _), args) =>
66+
val interpretedArgs = args.map(interpretTree)
67+
interpretNew(tree.symbol, interpretedArgs)
68+
6569
// TODO disallow interpreted method calls as arguments
6670
case Call(fn, args) =>
6771
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package))
@@ -174,6 +178,12 @@ abstract class Interpreter(pos: SrcPos, classLoader: ClassLoader)(using Context)
174178
(args: List[Object]) => stopIfRuntimeException(method.invoke(inst, args: _*), method)
175179
}
176180

181+
protected def interpretedMethodCall(inst: Object, fn: Symbol)(args: Object*)(implicit env: Env): Object = {
182+
val name = fn.name.asTermName
183+
val method = getMethod(inst.getClass, name, paramsSig(fn))
184+
stopIfRuntimeException(method.invoke(inst, args: _*), method)
185+
}
186+
177187
private def interpretedStaticFieldAccess(sym: Symbol)(implicit env: Env): Object = {
178188
val clazz = loadClass(sym.owner.fullName.toString)
179189
val field = clazz.getField(sym.name.toString)
@@ -184,7 +194,8 @@ abstract class Interpreter(pos: SrcPos, classLoader: ClassLoader)(using Context)
184194
loadModule(fn.moduleClass)
185195

186196
private def interpretNew(fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
187-
val clazz = loadClass(fn.owner.fullName.toString)
197+
val className = fn.owner.fullName.toString.replaceAll("\\$\\.", "\\$")
198+
val clazz = loadClass(className)
188199
val constr = clazz.getConstructor(paramsSig(fn): _*)
189200
constr.newInstance(args: _*).asInstanceOf[Object]
190201
}

compiler/src/dotty/tools/dotc/transform/Inlining.scala

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@ import Contexts._
77
import Symbols._
88
import SymUtils._
99
import dotty.tools.dotc.ast.tpd
10-
10+
import dotty.tools.dotc.ast.Trees._
11+
import dotty.tools.dotc.quoted._
1112
import dotty.tools.dotc.core.StagingContext._
1213
import dotty.tools.dotc.inlines.Inlines
1314
import dotty.tools.dotc.ast.TreeMapWithImplicits
15+
import dotty.tools.dotc.core.DenotTransformers.IdentityDenotTransformer
1416

1517

1618
/** Inlines all calls to inline methods that are not in an inline method or a quote */
17-
class Inlining extends MacroTransform {
19+
class Inlining extends MacroTransform with IdentityDenotTransformer {
1820
import tpd._
1921

2022
override def phaseName: String = Inlining.name
@@ -23,8 +25,10 @@ class Inlining extends MacroTransform {
2325

2426
override def allowsImplicitSearch: Boolean = true
2527

28+
override def changesMembers: Boolean = true
29+
2630
override def run(using Context): Unit =
27-
if ctx.compilationUnit.needsInlining then
31+
if ctx.compilationUnit.needsInlining || ctx.compilationUnit.hasMacroAnnotations then
2832
try super.run
2933
catch case _: CompilationUnit.SuspendException => ()
3034

@@ -61,7 +65,17 @@ class Inlining extends MacroTransform {
6165
tree match
6266
case tree: DefTree =>
6367
if tree.symbol.is(Inline) then tree
64-
else super.transform(tree)
68+
else
69+
tree match
70+
case _: Bind => super.transform(tree)
71+
case tree if tree.symbol.is(Param) => super.transform(tree)
72+
case tree
73+
if !tree.symbol.isPrimaryConstructor
74+
&& StagingContext.level == 0
75+
&& MacroAnnotations.hasMacro(tree.symbol) =>
76+
val trees = new MacroAnnotations(Inlining.this).transform(tree)
77+
flatTree(trees.map(super.transform))
78+
case tree => super.transform(tree)
6579
case _: Typed | _: Block =>
6680
super.transform(tree)
6781
case _ if Inlines.needsInlining(tree) =>
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import scala.language.unsafeNulls
5+
6+
import dotty.tools.dotc.ast.tpd
7+
import dotty.tools.dotc.ast.Trees.*
8+
import dotty.tools.dotc.config.Printers.{macroAnnot => debug}
9+
import dotty.tools.dotc.core.Annotations.*
10+
import dotty.tools.dotc.core.Contexts.*
11+
import dotty.tools.dotc.core.Decorators.*
12+
import dotty.tools.dotc.core.DenotTransformers.DenotTransformer
13+
import dotty.tools.dotc.core.Flags.*
14+
import dotty.tools.dotc.core.MacroClassLoader
15+
import dotty.tools.dotc.core.Symbols.*
16+
import dotty.tools.dotc.core.SymDenotations.NoDenotation
17+
import dotty.tools.dotc.quoted.*
18+
import dotty.tools.dotc.util.SrcPos
19+
import scala.quoted.runtime.impl.{QuotesImpl, SpliceScope}
20+
21+
class MacroAnnotations(thisPhase: DenotTransformer):
22+
import tpd.*
23+
import MacroAnnotations.*
24+
25+
def transform(tree: Tree)(using Context): List[Tree] =
26+
if !hasMacro(tree.symbol) then
27+
List(tree)
28+
else if tree.symbol.is(Module) then
29+
if tree.symbol.isClass then // error only reported on module class
30+
report.error("Macro annotations are not supported on object", tree)
31+
List(tree)
32+
else if tree.symbol.isClass then
33+
report.error("Macro annotations are not supported on class", tree)
34+
List(tree)
35+
else if tree.symbol.isType then
36+
report.error("Macro annotations are not supported on type", tree)
37+
List(tree)
38+
else
39+
debug.println(i"Expanding macro annotations of:\n$tree")
40+
41+
val macroInterpreter = new InterpreterMacroAnnot(tree.srcPos, MacroClassLoader.fromContext)
42+
43+
val allTrees = List.newBuilder[Tree]
44+
var insertedAfter: List[List[Tree]] = Nil
45+
46+
// Apply all macro annotation to `tree` and collect new definitions in order
47+
val transformedTree: Tree = tree.symbol.annotations.foldLeft(tree) { (tree, annot) =>
48+
if isMacroAnnotation(annot) then
49+
debug.println(i"Expanding macro annotation: ${annot}")
50+
51+
// Interpret call to `new myAnnot(..).transform(using <Quotes>)(<tree>)`
52+
val transformedTrees = callMacro(macroInterpreter, tree, annot)
53+
transformedTrees.span(_.symbol != tree.symbol) match
54+
case (prefixed, newTree :: suffixed) =>
55+
allTrees ++= prefixed
56+
insertedAfter = suffixed :: insertedAfter
57+
prefixed.foreach(checkAndEnter(_, annot))
58+
suffixed.foreach(checkAndEnter(_, annot))
59+
newTree
60+
case (Nil, Nil) =>
61+
report.error(i"Unexpected `Nil` returned by `(${annot.tree}).transform(..)` during macro expansion", annot.tree.srcPos)
62+
tree
63+
case (_, Nil) =>
64+
report.error(i"Transformed tree for ${tree} was not return by `(${annot.tree}).transform(..)` during macro expansion", annot.tree.srcPos)
65+
tree
66+
else
67+
tree
68+
}
69+
70+
allTrees += transformedTree
71+
insertedAfter.foreach(allTrees.++=)
72+
73+
val result = allTrees.result()
74+
debug.println(result.map(_.show).mkString("expanded to:\n", "\n", ""))
75+
result
76+
77+
/** Interpret the code `new annot(..).transform(using <Quotes(ctx)>)(<tree>)` */
78+
private def callMacro(interpreter: InterpreterMacroAnnot, tree: Tree, annot: Annotation)(using Context): List[Tree] =
79+
// Interpret macro annotation instantiation `new myAnnot(..)`
80+
// TODO: Interpret as MacroAnnotation when no longer experimental
81+
val annotInstance = interpreter.interpret[Object/*MacroAnnotation*/](annot.tree).get
82+
assert(annotInstance.getClass.getClassLoader.loadClass("scala.annotation.MacroAnnotation").isInstance(annotInstance))
83+
84+
// Call transform `new annot(..).transform(using <Quotes(ctx)>)(<tree>)`
85+
val quotes = QuotesImpl()(using SpliceScope.contextWithNewSpliceScope(tree.symbol.sourcePos)(using MacroExpansion.context(tree)).withOwner(tree.symbol))
86+
// TODO: Call MacroAnnotation.transform directly when no longer experimental
87+
val transformedTrees = interpreter.interpretedMethodCall(annotInstance, defn.QuotedMacroAnnotation_transform)(quotes, tree)(Map.empty)
88+
.asInstanceOf[List[Tree]]
89+
assert(transformedTrees.forall(_.isInstanceOf[Tree]))
90+
transformedTrees
91+
92+
/** Check that this tree can be added by the macro annotation and enter it if needed */
93+
private def checkAndEnter(newTree: Tree, annot: Annotation)(using Context) =
94+
if newTree.symbol.isClass then
95+
report.error("Generating classes is not supported", annot.tree)
96+
else if newTree.symbol.isType then
97+
report.error("Generating type is not supported", annot.tree)
98+
newTree.symbol.enteredAfter(thisPhase)
99+
100+
object MacroAnnotations:
101+
102+
def isMacroAnnotation(annot: Annotation)(using Context): Boolean =
103+
val sym = annot.tree.symbol
104+
sym.denot != NoDenotation && sym.owner.derivesFrom(defn.QuotedMacroAnnotationClass)
105+
106+
def hasMacro(sym: Symbol)(using Context): Boolean =
107+
sym.getAnnotation(defn.QuotedMacroAnnotationClass).isDefined
108+
109+
// TODO: Remove InterpreterMacroAnnot and use Interpreter directly when InterpreterMacroAnnot is no longer experimental
110+
private[MacroAnnotations] class InterpreterMacroAnnot(pos: SrcPos, classLoader: ClassLoader)(using Context) extends Interpreter(pos, classLoader):
111+
override def interpretedMethodCall(inst: Object, fn: Symbol)(args: Object*)(implicit env: Env): Object =
112+
super.interpretedMethodCall(inst, fn)(args*)

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,17 +375,20 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
375375
)
376376
}
377377
case tree: ValDef =>
378+
checkAnnotationMacros(tree)
378379
checkErasedDef(tree)
379380
val tree1 = cpy.ValDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
380381
if tree1.removeAttachment(desugar.UntupledParam).isDefined then
381382
checkStableSelection(tree.rhs)
382383
processValOrDefDef(super.transform(tree1))
383384
case tree: DefDef =>
385+
checkAnnotationMacros(tree)
384386
checkErasedDef(tree)
385387
annotateContextResults(tree)
386388
val tree1 = cpy.DefDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
387389
processValOrDefDef(superAcc.wrapDefDef(tree1)(super.transform(tree1).asInstanceOf[DefDef]))
388390
case tree: TypeDef =>
391+
checkAnnotationMacros(tree)
389392
val sym = tree.symbol
390393
if (sym.isClass)
391394
VarianceChecker.check(tree)
@@ -483,6 +486,11 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
483486
private def normalizeErasedRhs(rhs: Tree, sym: Symbol)(using Context) =
484487
if (sym.isEffectivelyErased) dropInlines.transform(rhs) else rhs
485488

489+
private def checkAnnotationMacros(tree: Tree)(using Context) =
490+
if !ctx.compilationUnit.hasMacroAnnotations then
491+
ctx.compilationUnit.hasMacroAnnotations |=
492+
tree.symbol.annotations.exists(MacroAnnotations.isMacroAnnotation)
493+
486494
private def checkErasedDef(tree: ValOrDefDef)(using Context): Unit =
487495
if tree.symbol.is(Erased, butNot = Macro) then
488496
val tpe = tree.rhs.tpe

compiler/src/dotty/tools/dotc/transform/YCheckPositions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class YCheckPositions extends Phase {
6161

6262
private def isMacro(call: Tree)(using Context) =
6363
call.symbol.is(Macro) ||
64+
(call.symbol.isClass && call.tpe.derivesFrom(defn.QuotedMacroAnnotationClass)) ||
6465
// The call of a macro after typer is encoded as a Select while other inlines are Ident
6566
// TODO remove this distinction once Inline nodes of expanded macros can be trusted (also in Inliner.inlineCallTrace)
6667
(!(ctx.phase <= postTyperPhase) && call.isInstanceOf[Select])
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// TODO in which package should this class be located?
2+
package scala
3+
package annotation
4+
5+
import scala.quoted._
6+
7+
/** Base trait for macro annotation that will transform a definition */
8+
@experimental
9+
trait MacroAnnotation extends StaticAnnotation:
10+
11+
/** Transform the `tree` definition and add other definitions
12+
*
13+
* This method takes as argument the reflected representation of the annotated definition.
14+
* It returns a non-empty list containing the modified version of the annotated definition.
15+
* The new tree for the definition must use the original symbol.
16+
* New definitions can be added to the list before or after the transformed definitions, this order
17+
* will be retained.
18+
*
19+
* The result cannot contain `class`, `object` or `type` definition. This limitation will be relaxed in the future.
20+
*
21+
* Example:
22+
* ```scala
23+
* class memoize extends MacroAnnotation:
24+
* override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
25+
* import quotes.reflect._
26+
* tree match
27+
* case DefDef(name, TermParamClause(param :: Nil) :: Nil, tpt, Some(rhsTree)) =>
28+
* (Ref(param.symbol).asExpr, rhsTree.asExpr) match
29+
* case ('{ $paramRefExpr: t }, '{ $rhsExpr: u }) =>
30+
* val cacheSymbol = Symbol.newVal(tree.symbol.owner, name + "Cache", TypeRepr.of[Map[t, u]], Flags.Private, Symbol.noSymbol)
31+
* val cacheRhs = '{ Map.empty[t, u] }.asTerm
32+
* val cacheVal = ValDef(cacheSymbol, Some(cacheRhs))
33+
* val cacheRefExpr = Ref(cacheSymbol).asExprOf[Map[t, u]]
34+
* val newRhs = '{ $cacheRefExpr.getOrElseUpdate($paramRefExpr, $rhsExpr) }.asTerm
35+
* val newTree = DefDef.copy(tree)(name, TermParamClause(param :: Nil) :: Nil, tpt, Some(newRhs))
36+
* List(cacheVal, newTree)
37+
* case _ =>
38+
* report.error("Annotation only supported on `def` with a single argument are supported")
39+
* List(tree)
40+
* ```
41+
*
42+
* with this macro annotation a user can write
43+
* ```scala
44+
* @memoize
45+
* def fib(n: Int): Int =
46+
* println(s"compute fib of $n")
47+
* if n <= 1 then n else fib(n - 1) + fib(n - 2)
48+
* ```
49+
* and the macro will modify the definition to create
50+
* ```scala
51+
* val fibCache = mutable.Map.empty[Int, Int]
52+
* def fib(n: Int): Int =
53+
* fibCache.getOrElseUpdate(
54+
* n,
55+
* {
56+
* println(s"compute fib of $n")
57+
* if n <= 1 then n else fib(n - 1) + fib(n - 2)
58+
* }
59+
* )
60+
* ```
61+
*
62+
* @param Quotes Implicit instance of Quotes used for tree reflection
63+
* @param tree Tree that will be transformed
64+
*/
65+
def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition]
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
-- [E042] Type Error: tests/neg-macros/annot-MacroAnnotation-direct.scala:3:0 ------------------------------------------
2+
3 |@MacroAnnotation // error
3+
|^^^^^^^^^^^^^^^^
4+
|MacroAnnotation is a trait; it cannot be instantiated
5+
|
6+
| longer explanation available when compiling with `-explain`
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import scala.annotation.MacroAnnotation
2+
3+
@MacroAnnotation // error
4+
def test = ()
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import scala.annotation.experimental
2+
import scala.quoted._
3+
import scala.annotation.MacroAnnotation
4+
import scala.collection.mutable.Map
5+
6+
@experimental
7+
class hello extends MacroAnnotation {
8+
override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
9+
import quotes.reflect._
10+
val helloSymbol = Symbol.newVal(tree.symbol.owner, "hello", TypeRepr.of[String], Flags.EmptyFlags, Symbol.noSymbol)
11+
val helloVal = ValDef(helloSymbol, Some(Literal(StringConstant("Hello, World!"))))
12+
List(helloVal, tree)
13+
}

0 commit comments

Comments
 (0)