Skip to content

Enable returning classes from MacroAnnotations (part 3) #16534

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jan 12, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/Inlining.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ class Inlining extends MacroTransform with IdentityDenotTransformer {
}

private class InliningTreeMap extends TreeMapWithImplicits {

/** List of top level classes added by macro annotation in a package object.
* These are added the PackageDef that owns this particular package object.
*/
private val topClasses = new collection.mutable.ListBuffer[Tree]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately I think this isn't sufficient because package objects can be nested:

package foo {
  val x = 1
  package bar {
    val y = 2
  }
}

Instead, maybe the MemberDef case of transform should return a Thicket with the top-level classes, and we should add an extra case to transform to handle the package object module class itself, where we should also return a Thicket with the top-level classes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This use case was considered and works. I added tests for it in tests/run-macros/annot-add-global-class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that after after post typer the tree is

package foo {
  package bar {
    val y = 2
  }
  val x = 1
}

This implies that nested classes are processed first and the buffer never overlaps and is emptied just after transforming the nested package.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This implies that nested classes are processed first and the buffer never overlaps and is emptied just after transforming the nested package.

This is subtle, so this precondition should be documented in the code (and ideally checked somewhere, in case it breaks)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found cases where this precondition does not hold. I updated the implementation to handle such cases.


override def transform(tree: Tree)(using Context): Tree = {
tree match
case tree: MemberDef =>
Expand All @@ -74,7 +80,15 @@ class Inlining extends MacroTransform with IdentityDenotTransformer {
&& MacroAnnotations.hasMacroAnnotation(tree.symbol)
then
val trees = new MacroAnnotations(thisPhase).expandAnnotations(tree)
flatTree(trees.map(super.transform))
val trees1 = trees.map(super.transform)

// Find classes added to the top level from a package object
val (topClasses0, trees2) =
if ctx.owner.isPackageObject then trees1.partition(_.symbol.owner == ctx.owner.owner)
else (Nil, trees1)
topClasses ++= topClasses0

flatTree(trees2)
else super.transform(tree)
case _: Typed | _: Block =>
super.transform(tree)
Expand All @@ -86,6 +100,14 @@ class Inlining extends MacroTransform with IdentityDenotTransformer {
super.transform(tree)(using StagingContext.quoteContext)
case _: GenericApply if tree.symbol.isExprSplice =>
super.transform(tree)(using StagingContext.spliceContext)
case _: PackageDef =>
super.transform(tree) match
case tree1: PackageDef if !topClasses.isEmpty =>
topClasses ++= tree1.stats
val newStats = topClasses.result()
topClasses.clear()
cpy.PackageDef(tree1)(tree1.pid, newStats)
case tree1 => tree1
case _ =>
super.transform(tree)
}
Expand Down
10 changes: 4 additions & 6 deletions compiler/src/dotty/tools/dotc/transform/MacroAnnotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -119,17 +119,15 @@ class MacroAnnotations(thisPhase: DenotTransformer):
annotInstance.transform(using quotes)(tree.asInstanceOf[quotes.reflect.Definition])

/** Check that this tree can be added by the macro annotation and enter it if needed */
private def checkAndEnter(newTree: Tree, annotated: Symbol, annot: Annotation)(using Context) =
private def checkAndEnter(newTree: DefTree, annotated: Symbol, annot: Annotation)(using Context) =
val sym = newTree.symbol
if sym.isClass then
report.error(i"macro annotation returning a `class` is not yet supported. $annot tried to add $sym", annot.tree)
else if sym.isType then
if sym.isType && !sym.isClass then
report.error(i"macro annotation cannot return a `type`. $annot tried to add $sym", annot.tree)
else if sym.owner != annotated.owner then
else if sym.owner != annotated.owner && !(annotated.owner.isPackageObject && (sym.isClass || sym.is(Module)) && sym.owner == annotated.owner.owner) then
report.error(i"macro annotation $annot added $sym with an inconsistent owner. Expected it to be owned by ${annotated.owner} but was owned by ${sym.owner}.", annot.tree)
else if annotated.isClass && annotated.owner.is(Package) /*&& !sym.isClass*/ then
report.error(i"macro annotation can not add top-level ${sym.showKind}. $annot tried to add $sym.", annot.tree)
else
else if !sym.is(Module) then // To avoid entering it twice
sym.enteredAfter(thisPhase)

object MacroAnnotations:
Expand Down
26 changes: 21 additions & 5 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@ import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.ast.untpd
import dotty.tools.dotc.core.Annotations
import dotty.tools.dotc.core.Contexts._
import dotty.tools.dotc.core.Types
import dotty.tools.dotc.core.Decorators._
import dotty.tools.dotc.core.Flags._
import dotty.tools.dotc.core.NameKinds
import dotty.tools.dotc.core.NameOps._
import dotty.tools.dotc.core.StdNames._
import dotty.tools.dotc.quoted.reflect._
import dotty.tools.dotc.core.Decorators._
import dotty.tools.dotc.core.Types
import dotty.tools.dotc.NoCompilationUnit

import dotty.tools.dotc.quoted.{MacroExpansion, PickledQuotes}
import dotty.tools.dotc.quoted.MacroExpansion
import dotty.tools.dotc.quoted.PickledQuotes
import dotty.tools.dotc.quoted.reflect._

import scala.quoted.runtime.{QuoteUnpickler, QuoteMatching}
import scala.quoted.runtime.impl.printers._
Expand Down Expand Up @@ -2481,6 +2482,21 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
for sym <- decls(cls) do cls.enter(sym)
cls

def newModule(owner: Symbol, name: String, modFlags: Flags, clsFlags: Flags, parents: List[TypeRepr], decls: Symbol => List[Symbol], privateWithin: Symbol): Symbol =
assert(parents.nonEmpty && !parents.head.typeSymbol.is(dotc.core.Flags.Trait), "First parent must be a class")
val mod = dotc.core.Symbols.newCompleteModuleSymbol(
owner,
name.toTermName,
modFlags | Flags.Final | Flags.Lazy | Flags.Module,
clsFlags | Flags.Final | Flags.Module,
parents.asInstanceOf, // FIXME
dotc.core.Scopes.newScope,
privateWithin)
val cls = mod.moduleClass.asClass
cls.enter(dotc.core.Symbols.newConstructor(cls, dotc.core.Flags.Synthetic, Nil, Nil))
for sym <- decls(cls) do cls.enter(sym)
mod

def newMethod(owner: Symbol, name: String, tpe: TypeRepr): Symbol =
newMethod(owner, name, tpe, Flags.EmptyFlags, noSymbol)
def newMethod(owner: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol =
Expand Down
44 changes: 34 additions & 10 deletions library/src/scala/annotation/MacroAnnotation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,41 @@ package annotation

import scala.quoted._

/** Base trait for macro annotation that will transform a definition */
/** Base trait for macro annotation implementation.
* Macro annotations can transform definitions and add new definitions.
*
* See: `MacroAnnotation.transform`
*
* @syntax markdown
*/
@experimental
trait MacroAnnotation extends StaticAnnotation:

/** Transform the `tree` definition and add other definitions
/** Transform the `tree` definition and add new definitions
*
* This method takes as argument the annotated definition.
* It returns a non-empty list containing the modified version of the annotated definition.
* The new tree for the definition must use the original symbol.
* New definitions can be added to the list before or after the transformed definitions, this order
* will be retained.
* will be retained. New definitions will not be visible from outside the macro expansion.
*
* All definitions in the result must have the same owner. The owner can be recovered from `tree.symbol.owner`.
* #### Restrictions
* - All definitions in the result must have the same owner. The owner can be recovered from `Symbol.spliceOwner`.
* - Special case: an annotated top-level `def`, `val`, `var`, `lazy val` can return a `class`/`object`
definition that is owned by the package or package object.
* - Can not return a `type`.
* - Annotated top-level `class`/`object` can not return top-level `def`, `val`, `var`, `lazy val`.
* - Can not see new definition in user written code.
*
* The result cannot add new `class`, `object` or `type` definition. This limitation will be relaxed in the future.
* #### Good practices
* - Make your new definitions private if you can.
* - New definitions added as class members should use a fresh name (`Symbol.freshName`) to avoid collisions.
* - New top-level definitions should use a fresh name (`Symbol.freshName`) that includes the name of the annotated
* member as a prefix to avoid collisions of definitions added in other files.
*
* IMPORTANT: When developing and testing a macro annotation, you must enable `-Xcheck-macros` and `-Ycheck:all`.
* **IMPORTANT**: When developing and testing a macro annotation, you must enable `-Xcheck-macros` and `-Ycheck:all`.
*
* Example 1:
* #### Example 1
* This example shows how to modify a `def` and add a `val` next to it using a macro annotation.
* ```scala
* import scala.quoted.*
Expand Down Expand Up @@ -54,7 +70,10 @@ trait MacroAnnotation extends StaticAnnotation:
* List(tree)
* ```
* with this macro annotation a user can write
* ```scala sc:nocompile
* ```scala
* //{
* class memoize extends scala.annotation.StaticAnnotation
* //}
* @memoize
* def fib(n: Int): Int =
* println(s"compute fib of $n")
Expand All @@ -74,7 +93,7 @@ trait MacroAnnotation extends StaticAnnotation:
* )
* ```
*
* Example 2:
* #### Example 2
* This example shows how to modify a `class` using a macro annotation.
* It shows how to override inherited members and add new ones.
* ```scala
Expand Down Expand Up @@ -164,7 +183,10 @@ trait MacroAnnotation extends StaticAnnotation:
* }
* ```
* with this macro annotation a user can write
* ```scala sc:nocompile
* ```scala
* //{
* class equals extends scala.annotation.StaticAnnotation
* //}
* @equals class User(val name: String, val id: Int)
* ```
* and the macro will modify the class definition to generate the following code
Expand All @@ -184,5 +206,7 @@ trait MacroAnnotation extends StaticAnnotation:
*
* @param Quotes Implicit instance of Quotes used for tree reflection
* @param tree Tree that will be transformed
*
* @syntax markdown
*/
def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition]
63 changes: 62 additions & 1 deletion library/src/scala/quoted/Quotes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3638,8 +3638,69 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
* @note As a macro can only splice code into the point at which it is expanded, all generated symbols must be
* direct or indirect children of the reflection context's owner.
*/
// TODO: add flags and privateWithin
@experimental def newClass(parent: Symbol, name: String, parents: List[TypeRepr], decls: Symbol => List[Symbol], selfType: Option[TypeRepr]): Symbol

/** Generates a new module symbol with an associated module class symbol.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/** Generates a new module symbol with an associated module class symbol.
/** Generates a new module symbol with an associated module class symbol,
* this is equivalent to an `object` declaration in source code.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this change was not applied.

* This returns the module symbol. The module class can be accessed calling `moduleClass` on this symbol.
*
* Example usage:
* ```scala
* //{
* given Quotes = ???
* import quotes.reflect._
* //}
* val moduleName: String = Symbol.freshName("MyModule")
* val parents = List(TypeTree.of[Object])
* def decls(cls: Symbol): List[Symbol] =
* List(Symbol.newMethod(cls, "run", MethodType(Nil)(_ => Nil, _ => TypeRepr.of[Unit]), Flags.EmptyFlags, Symbol.noSymbol))
*
* val mod = Symbol.newModule(Symbol.spliceOwner, moduleName, Flags.EmptyFlags, Flags.EmptyFlags, parents.map(_.tpe), decls, Symbol.noSymbol)
* val cls = mod.moduleClass
* val runSym = cls.declaredMethod("run").head
*
* val runDef = DefDef(runSym, _ => Some('{ println("run") }.asTerm))
* val clsDef = ClassDef(cls, parents, body = List(runDef))
* val newCls = Apply(Select(New(TypeIdent(cls)), cls.primaryConstructor), Nil)
* val modVal = ValDef(mod, Some(newCls))
* val modDef = List(modVal, clsDef)
*
* val callRun = Apply(Select(Ref(mod), runSym), Nil)
*
* Block(modDef, callRun)
* ```
* constructs the equivalent to
* ```scala
* //{
* given Quotes = ???
* import quotes.reflect._
* //}
* '{
* object MyModule$macro$1 extends Object:
* def run(): Unit = println("run")
* MyModule$macro$1.run()
* }
* ```
*
* @param parent The owner of the class
* @param name The name of the class
* @param modFlags extra flags to with which the module symbol should be constructed
* @param clsFlags extra flags to with which the module class symbol should be constructed
* @param parents The parent classes of the class. The first parent must not be a trait.
* @param decls The member declarations of the module provided the symbol of this class
* @param privateWithin the symbol within which this new method symbol should be private. May be noSymbol.
*
* This symbol starts without an accompanying definition.
* It is the meta-programmer's responsibility to provide exactly one corresponding definition by passing
* this symbol to the ClassDef and ValDef constructor.
*
* @note As a macro can only splice code into the point at which it is expanded, all generated symbols must be
* direct or indirect children of the reflection context's owner.
*
* @syntax markdown
*/
@experimental def newModule(owner: Symbol, name: String, modFlags: Flags, clsFlags: Flags, parents: List[TypeRepr], decls: Symbol => List[Symbol], privateWithin: Symbol): Symbol

/** Generates a new method symbol with the given parent, name and type.
*
* To define a member method of a class, use the `newMethod` within the `decls` function of `newClass`.
Expand Down Expand Up @@ -4217,7 +4278,7 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
// FLAGS //
///////////////

/** FlagSet of a Symbol */
/** Flags of a Symbol */
type Flags

/** Module object of `type Flags` */
Expand Down
13 changes: 13 additions & 0 deletions tests/neg-macros/annot-mod-top-method-add-top-method/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import scala.annotation.{experimental, MacroAnnotation}
import scala.quoted._
import scala.collection.mutable

@experimental
// Assumes annotation is on top level def or val
class addTopLevelMethodOutsidePackageObject extends MacroAnnotation:
def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
import quotes.reflect._
val methType = MethodType(Nil)(_ => Nil, _ => TypeRepr.of[Int])
val methSym = Symbol.newMethod(Symbol.spliceOwner.owner, Symbol.freshName("toLevelMethod"), methType, Flags.EmptyFlags, Symbol.noSymbol)
val methDef = ValDef(methSym, Some(Literal(IntConstant(1))))
List(methDef, tree)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
@addTopLevelMethodOutsidePackageObject // error
def foo = 1

@addTopLevelMethodOutsidePackageObject // error
val bar = 1
12 changes: 12 additions & 0 deletions tests/neg-macros/annot-mod-top-method-add-top-val/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import scala.annotation.{experimental, MacroAnnotation}
import scala.quoted._
import scala.collection.mutable

@experimental
// Assumes annotation is on top level def or val
class addTopLevelValOutsidePackageObject extends MacroAnnotation:
def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
import quotes.reflect._
val valSym = Symbol.newVal(Symbol.spliceOwner.owner, Symbol.freshName("toLevelVal"), TypeRepr.of[Int], Flags.EmptyFlags, Symbol.noSymbol)
val valDef = ValDef(valSym, Some(Literal(IntConstant(1))))
List(valDef, tree)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
@addTopLevelValOutsidePackageObject // error
def foo = 1

@addTopLevelValOutsidePackageObject // error
val bar = 1
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ val experimentalDefinitionInLibrary = Set(
// Need experimental annotation macros to check that design works.
"scala.quoted.Quotes.reflectModule.ClassDefModule.apply",
"scala.quoted.Quotes.reflectModule.SymbolModule.newClass",
"scala.quoted.Quotes.reflectModule.SymbolModule.newModule",
"scala.quoted.Quotes.reflectModule.SymbolModule.freshName",
"scala.quoted.Quotes.reflectModule.SymbolMethods.info",

Expand Down
4 changes: 4 additions & 0 deletions tests/run-macros/annot-add-global-class.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
macro generated main
executed in: Bar$macro$1
macro generated main
executed in: Bar$macro$2
28 changes: 28 additions & 0 deletions tests/run-macros/annot-add-global-class/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import scala.annotation.{experimental, MacroAnnotation}
import scala.quoted._
import scala.collection.mutable

@experimental
class addClass extends MacroAnnotation:
def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
import quotes.reflect._
tree match
case DefDef(name, List(TermParamClause(Nil)), tpt, Some(rhs)) =>
val parents = List(TypeTree.of[Object])
def decls(cls: Symbol): List[Symbol] =
List(Symbol.newMethod(cls, "run", MethodType(Nil)(_ => Nil, _ => TypeRepr.of[Unit]), Flags.EmptyFlags, Symbol.noSymbol))

val newClassName = Symbol.freshName("Bar")
val cls = Symbol.newClass(Symbol.spliceOwner.owner, newClassName, parents = parents.map(_.tpe), decls, selfType = None)
val runSym = cls.declaredMethod("run").head

val runDef = DefDef(runSym, _ => Some(rhs))
val clsDef = ClassDef(cls, parents, body = List(runDef))

val newCls = Apply(Select(New(TypeIdent(cls)), cls.primaryConstructor), Nil)

val newDef = DefDef.copy(tree)(name, List(TermParamClause(Nil)), tpt, Some(Apply(Select(newCls, runSym), Nil)))
List(clsDef, newDef)
case _ =>
report.error("Annotation only supports `def` with one argument")
List(tree)
25 changes: 25 additions & 0 deletions tests/run-macros/annot-add-global-class/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
@addClass def foo(): Unit =
println("macro generated main")
println("executed in: " + (new Throwable().getStackTrace().head.getClassName))
//> class Baz$macro$1 extends Object {
//> def run() =
//> println("macro generated main")
//> println("executed in: " + getClass.getName)
//> }
//> def foo(): Unit =
//> new Baz$macro$1.run

@addClass def bar(): Unit =
println("macro generated main")
println("executed in: " + (new Throwable().getStackTrace().head.getClassName))
//> class Baz$macro$2 extends Object {
//> def run() =
//> println("macro generated main")
//> println("executed in: " + getClass.getName)
//> }
//> def foo(): Unit =
//> new Baz$macro$2.run

@main def Test(): Unit =
foo()
bar()
4 changes: 4 additions & 0 deletions tests/run-macros/annot-add-global-object.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
macro generated main
executed in: Test_2$package$Bar$macro$1$
macro generated main
executed in: Test_2$package$Bar$macro$2$
Loading