Skip to content

Commit 4f39236

Browse files
Allow macro annotation to transform companion (#19677)
### Allow MacroAnnotations to update the companion of a definition We extend the MacroAnnotation api to allow to modify the companion of a class or an object. ### Specification 1. Order of expansion - We expand the definitions in program order. - We expand the annotations of the outer scope first, then we expand the inner definitions. - Annotations are expanded from the outer annotation to the inner annotation. In the following example, we expand the annotations in this order: `a1`, `a2`, `a3`. ```scala @A1 @a2 class Foo: @A3 def foo = ??? ``` 2. Expansion of the companion We always expand the latest available tree. If an annotation defined on `class Foo` changes its companion (`object Foo`) and the `class` is defined before `object`, the expansion of the annotations on the `object` will be performed on the result of the expansion of `class`. 3. The program order is maintained We maintain the program order in the definitions that were expanded. 4. Backtrack and reprocess Example: ```scala @A1 class Foo @a2 object Foo ``` If the `@a2` annotation changes the definitions in `class Foo`, we will rerun the algorithm on the result of this new expansion. Please note that we don't allow to generate code with MacroAnnotations, the reason for rerunning the algorithm is to expand and inline possible macros that we generated. --- Closes #19676
2 parents e2c456f + 4694b3b commit 4f39236

File tree

75 files changed

+631
-350
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

75 files changed

+631
-350
lines changed

compiler/src/dotty/tools/dotc/CompilationUnit.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import config.{SourceVersion, Feature}
1717
import StdNames.nme
1818
import scala.annotation.internal.sharable
1919
import scala.util.control.NoStackTrace
20-
import transform.MacroAnnotations
20+
import transform.MacroAnnotations.isMacroAnnotation
2121

2222
class CompilationUnit protected (val source: SourceFile, val info: CompilationUnitInfo | Null) {
2323

@@ -197,7 +197,7 @@ object CompilationUnit {
197197
case _ =>
198198
case _ =>
199199
for annot <- tree.symbol.annotations do
200-
if MacroAnnotations.isMacroAnnotation(annot) then
200+
if annot.isMacroAnnotation then
201201
ctx.compilationUnit.hasMacroAnnotations = true
202202
traverseChildren(tree)
203203
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package dotty.tools.dotc
2+
package ast
3+
4+
import tpd.*
5+
import core.Contexts.*
6+
import core.Symbols.*
7+
import util.Property
8+
9+
import scala.collection.mutable
10+
11+
/**
12+
* It is safe to assume that the companion of a tree is in the same scope.
13+
* Therefore, when expanding MacroAnnotations, we will only keep track of
14+
* the trees in the same scope as the current transformed tree
15+
*/
16+
abstract class TreeMapWithTrackedStats extends TreeMapWithImplicits:
17+
18+
import TreeMapWithTrackedStats.*
19+
20+
/** Fetch the corresponding tracked tree for a given symbol */
21+
protected final def getTracked(sym: Symbol)(using Context): Option[MemberDef] =
22+
for trees <- ctx.property(TrackedTrees)
23+
tree <- trees.get(sym)
24+
yield tree
25+
26+
/** Update the tracked trees */
27+
protected final def updateTracked(tree: Tree)(using Context): Tree =
28+
tree match
29+
case tree: MemberDef =>
30+
trackedTrees.update(tree.symbol, tree)
31+
tree
32+
case _ => tree
33+
end updateTracked
34+
35+
/** Process a list of trees and give the priority to trakced trees */
36+
private final def withUpdatedTrackedTrees(stats: List[Tree])(using Context) =
37+
val trackedTrees = TreeMapWithTrackedStats.trackedTrees
38+
stats.mapConserve:
39+
case tree: MemberDef if trackedTrees.contains(tree.symbol) =>
40+
trackedTrees(tree.symbol)
41+
case stat => stat
42+
43+
override def transform(tree: Tree)(using Context): Tree =
44+
tree match
45+
case PackageDef(_, stats) =>
46+
inContext(trackedDefinitionsCtx(stats)): // Step I: Collect and memoize all the definition trees
47+
// Step II: Transform the tree
48+
val pkg@PackageDef(pid, stats) = super.transform(tree): @unchecked
49+
// Step III: Reconcile between the symbols in syms and the tree
50+
cpy.PackageDef(pkg)(pid = pid, stats = withUpdatedTrackedTrees(stats))
51+
case block: Block =>
52+
inContext(trackedDefinitionsCtx(block.stats)): // Step I: Collect all the member definitions in the block
53+
// Step II: Transform the tree
54+
val b@Block(stats, expr) = super.transform(tree): @unchecked
55+
// Step III: Reconcile between the symbols in syms and the tree
56+
cpy.Block(b)(expr = expr, stats = withUpdatedTrackedTrees(stats))
57+
case TypeDef(_, impl: Template) =>
58+
inContext(trackedDefinitionsCtx(impl.body)): // Step I: Collect and memoize all the stats
59+
// Step II: Transform the tree
60+
val newTree@TypeDef(name, impl: Template) = super.transform(tree): @unchecked
61+
// Step III: Reconcile between the symbols in syms and the tree
62+
cpy.TypeDef(newTree)(rhs = cpy.Template(impl)(body = withUpdatedTrackedTrees(impl.body)))
63+
case _ => super.transform(tree)
64+
65+
end TreeMapWithTrackedStats
66+
67+
object TreeMapWithTrackedStats:
68+
private val TrackedTrees = new Property.Key[mutable.Map[Symbol, tpd.MemberDef]]
69+
70+
/** Fetch the tracked trees in the cuurent context */
71+
private def trackedTrees(using Context): mutable.Map[Symbol, MemberDef] =
72+
ctx.property(TrackedTrees).get
73+
74+
/** Build a context and track the provided MemberDef trees */
75+
private def trackedDefinitionsCtx(stats: List[Tree])(using Context): Context =
76+
val treesToTrack = stats.collect { case m: MemberDef => (m.symbol, m) }
77+
ctx.fresh.setProperty(TrackedTrees, mutable.Map(treesToTrack*))

compiler/src/dotty/tools/dotc/transform/Inlining.scala

Lines changed: 60 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,24 @@
11
package dotty.tools.dotc
22
package transform
33

4+
import ast.tpd
5+
import ast.Trees.*
6+
import ast.TreeMapWithTrackedStats
47
import core.*
58
import Flags.*
9+
import Decorators.*
610
import Contexts.*
711
import Symbols.*
12+
import Decorators.*
13+
import config.Printers.inlining
14+
import DenotTransformers.IdentityDenotTransformer
15+
import MacroAnnotations.hasMacroAnnotation
16+
import inlines.Inlines
17+
import quoted.*
18+
import staging.StagingLevel
19+
import util.Property
820

9-
import dotty.tools.dotc.ast.tpd
10-
import dotty.tools.dotc.ast.Trees.*
11-
import dotty.tools.dotc.quoted.*
12-
import dotty.tools.dotc.inlines.Inlines
13-
import dotty.tools.dotc.ast.TreeMapWithImplicits
14-
import dotty.tools.dotc.core.DenotTransformers.IdentityDenotTransformer
15-
import dotty.tools.dotc.staging.StagingLevel
16-
17-
import scala.collection.mutable.ListBuffer
21+
import scala.collection.mutable
1822

1923
/** Inlines all calls to inline methods that are not in an inline method or a quote */
2024
class Inlining extends MacroTransform, IdentityDenotTransformer {
@@ -56,38 +60,21 @@ class Inlining extends MacroTransform, IdentityDenotTransformer {
5660

5761
def newTransformer(using Context): Transformer = new Transformer {
5862
override def transform(tree: tpd.Tree)(using Context): tpd.Tree =
59-
new InliningTreeMap().transform(tree)
63+
InliningTreeMap().transform(tree)
6064
}
6165

62-
private class InliningTreeMap extends TreeMapWithImplicits {
66+
private class InliningTreeMap extends TreeMapWithTrackedStats {
6367

6468
/** List of top level classes added by macro annotation in a package object.
6569
* These are added to the PackageDef that owns this particular package object.
6670
*/
67-
private val newTopClasses = MutableSymbolMap[ListBuffer[Tree]]()
71+
private val newTopClasses = MutableSymbolMap[mutable.ListBuffer[Tree]]()
6872

6973
override def transform(tree: Tree)(using Context): Tree = {
7074
tree match
7175
case tree: MemberDef =>
72-
if tree.symbol.is(Inline) then tree
73-
else if tree.symbol.is(Param) then super.transform(tree)
74-
else if
75-
!tree.symbol.isPrimaryConstructor
76-
&& StagingLevel.level == 0
77-
&& MacroAnnotations.hasMacroAnnotation(tree.symbol)
78-
then
79-
val trees = (new MacroAnnotations(self)).expandAnnotations(tree)
80-
val trees1 = trees.map(super.transform)
81-
82-
// Find classes added to the top level from a package object
83-
val (topClasses, trees2) =
84-
if ctx.owner.isPackageObject then trees1.partition(_.symbol.owner == ctx.owner.owner)
85-
else (Nil, trees1)
86-
if topClasses.nonEmpty then
87-
newTopClasses.getOrElseUpdate(ctx.owner.owner, new ListBuffer) ++= topClasses
88-
89-
flatTree(trees2)
90-
else super.transform(tree)
76+
// Fetch the latest tracked tree (It might have already been transformed by its companion)
77+
transformMemberDef(getTracked(tree.symbol).getOrElse(tree))
9178
case _: Typed | _: Block =>
9279
super.transform(tree)
9380
case _: PackageDef =>
@@ -113,7 +100,49 @@ class Inlining extends MacroTransform, IdentityDenotTransformer {
113100
else Inlines.inlineCall(tree1)
114101
else super.transform(tree)
115102
}
103+
104+
private def transformMemberDef(tree: MemberDef)(using Context) : Tree =
105+
if tree.symbol.is(Inline) then tree
106+
else if tree.symbol.is(Param) then
107+
super.transform(tree)
108+
else if
109+
!tree.symbol.isPrimaryConstructor
110+
&& StagingLevel.level == 0
111+
&& tree.symbol.hasMacroAnnotation
112+
then
113+
// Fetch the companion's tree
114+
val companionSym =
115+
if tree.symbol.is(ModuleClass) then tree.symbol.companionClass
116+
else if tree.symbol.is(ModuleVal) then NoSymbol
117+
else tree.symbol.companionModule.moduleClass
118+
119+
// Expand and process MacroAnnotations
120+
val companion = getTracked(companionSym)
121+
val (trees, newCompanion) = MacroAnnotations.expandAnnotations(tree, companion)
122+
123+
// Enter the new symbols & Update the tracked trees
124+
(newCompanion.toList ::: trees).foreach: tree =>
125+
MacroAnnotations.enterMissingSymbols(tree, self)
126+
127+
// Perform inlining on the expansion of the annotations
128+
val trees1 = trees.map(super.transform)
129+
trees1.foreach(updateTracked)
130+
if newCompanion ne companion then
131+
newCompanion.map(super.transform).foreach(updateTracked)
132+
133+
// Find classes added to the top level from a package object
134+
val (topClasses, trees2) =
135+
if ctx.owner.isPackageObject then trees1.partition(_.symbol.owner == ctx.owner.owner)
136+
else (Nil, trees1)
137+
if topClasses.nonEmpty then
138+
newTopClasses.getOrElseUpdate(ctx.owner.owner, new mutable.ListBuffer) ++= topClasses
139+
flatTree(trees2)
140+
else
141+
updateTracked(super.transform(tree))
142+
end transformMemberDef
143+
116144
}
145+
117146
}
118147

119148
object Inlining:

0 commit comments

Comments
 (0)