Skip to content

Inlined non-eta-expanded trees message #11537

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 39 additions & 22 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
object DefDef extends DefDefModule:
def apply(symbol: Symbol, rhsFn: List[List[Tree]] => Option[Term]): DefDef =
withDefaultPos(tpd.DefDef(symbol.asTerm, prefss =>
yCheckedOwners(rhsFn(prefss), symbol).getOrElse(tpd.EmptyTree)
yCheckedOwners(yCheckValidExpr(rhsFn(prefss)), symbol).getOrElse(tpd.EmptyTree)
))
def copy(original: Tree)(name: String, paramss: List[ParamClause], tpt: TypeTree, rhs: Option[Term]): DefDef =
tpd.cpy.DefDef(original)(name.toTermName, paramss, tpt, yCheckedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
Expand Down Expand Up @@ -293,9 +293,9 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler

object ValDef extends ValDefModule:
def apply(symbol: Symbol, rhs: Option[Term]): ValDef =
tpd.ValDef(symbol.asTerm, yCheckedOwners(rhs, symbol).getOrElse(tpd.EmptyTree))
tpd.ValDef(symbol.asTerm, yCheckedOwners(yCheckValidExpr(rhs), symbol).getOrElse(tpd.EmptyTree))
def copy(original: Tree)(name: String, tpt: TypeTree, rhs: Option[Term]): ValDef =
tpd.cpy.ValDef(original)(name.toTermName, tpt, yCheckedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
tpd.cpy.ValDef(original)(name.toTermName, tpt, yCheckedOwners(yCheckValidExpr(rhs), original.symbol).getOrElse(tpd.EmptyTree))
def unapply(vdef: ValDef): (String, TypeTree, Option[Term]) =
(vdef.name.toString, vdef.tpt, optional(vdef.rhs))

Expand Down Expand Up @@ -568,9 +568,9 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler

object NamedArg extends NamedArgModule:
def apply(name: String, arg: Term): NamedArg =
withDefaultPos(tpd.NamedArg(name.toTermName, arg))
withDefaultPos(tpd.NamedArg(name.toTermName, yCheckValidExpr(arg)))
def copy(original: Tree)(name: String, arg: Term): NamedArg =
tpd.cpy.NamedArg(original)(name.toTermName, arg)
tpd.cpy.NamedArg(original)(name.toTermName, yCheckValidExpr(arg))
def unapply(x: NamedArg): (String, Term) =
(x.name.toString, x.value)
end NamedArg
Expand All @@ -592,8 +592,10 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler

object Apply extends ApplyModule:
def apply(fun: Term, args: List[Term]): Apply =
yCheckValidExprs(args)
withDefaultPos(tpd.Apply(fun, args))
def copy(original: Tree)(fun: Term, args: List[Term]): Apply =
yCheckValidExprs(args)
tpd.cpy.Apply(original)(fun, args)
def unapply(x: Apply): (Term, List[Term]) =
(x.fun, x.args)
Expand Down Expand Up @@ -665,9 +667,9 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler

object Typed extends TypedModule:
def apply(expr: Term, tpt: TypeTree): Typed =
withDefaultPos(tpd.Typed(expr, tpt))
withDefaultPos(tpd.Typed(yCheckValidExpr(expr), tpt))
def copy(original: Tree)(expr: Term, tpt: TypeTree): Typed =
tpd.cpy.Typed(original)(expr, tpt)
tpd.cpy.Typed(original)(yCheckValidExpr(expr), tpt)
def unapply(x: Typed): (Term, TypeTree) =
(x.expr, x.tpt)
end Typed
Expand All @@ -689,9 +691,9 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler

object Assign extends AssignModule:
def apply(lhs: Term, rhs: Term): Assign =
withDefaultPos(tpd.Assign(lhs, rhs))
withDefaultPos(tpd.Assign(lhs, yCheckValidExpr(rhs)))
def copy(original: Tree)(lhs: Term, rhs: Term): Assign =
tpd.cpy.Assign(original)(lhs, rhs)
tpd.cpy.Assign(original)(lhs, yCheckValidExpr(rhs))
def unapply(x: Assign): (Term, Term) =
(x.lhs, x.rhs)
end Assign
Expand Down Expand Up @@ -754,7 +756,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
object Lambda extends LambdaModule:
def apply(owner: Symbol, tpe: MethodType, rhsFn: (Symbol, List[Tree]) => Tree): Block =
val meth = dotc.core.Symbols.newSymbol(owner, nme.ANON_FUN, Synthetic | Method, tpe)
tpd.Closure(meth, tss => yCheckedOwners(rhsFn(meth, tss.head.map(withDefaultPos)), meth))
tpd.Closure(meth, tss => yCheckedOwners(yCheckValidExpr(rhsFn(meth, tss.head.map(withDefaultPos))), meth))

def unapply(tree: Block): Option[(List[ValDef], Term)] = tree match {
case Block((ddef @ DefDef(_, TermParamClause(params) :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
Expand All @@ -774,9 +776,9 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler

object If extends IfModule:
def apply(cond: Term, thenp: Term, elsep: Term): If =
withDefaultPos(tpd.If(cond, thenp, elsep))
withDefaultPos(tpd.If(yCheckValidExpr(cond), yCheckValidExpr(thenp), yCheckValidExpr(elsep)))
def copy(original: Tree)(cond: Term, thenp: Term, elsep: Term): If =
tpd.cpy.If(original)(cond, thenp, elsep)
tpd.cpy.If(original)(yCheckValidExpr(cond), yCheckValidExpr(thenp), yCheckValidExpr(elsep))
def unapply(tree: If): (Term, Term, Term) =
(tree.cond, tree.thenp, tree.elsep)
end If
Expand All @@ -800,10 +802,10 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler

object Match extends MatchModule:
def apply(selector: Term, cases: List[CaseDef]): Match =
withDefaultPos(tpd.Match(selector, cases))
withDefaultPos(tpd.Match(yCheckValidExpr(selector), cases))

def copy(original: Tree)(selector: Term, cases: List[CaseDef]): Match =
tpd.cpy.Match(original)(selector, cases)
tpd.cpy.Match(original)(yCheckValidExpr(selector), cases)

def unapply(x: Match): (Term, List[CaseDef]) =
(x.scrutinee, x.cases)
Expand Down Expand Up @@ -850,9 +852,9 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler

object Try extends TryModule:
def apply(expr: Term, cases: List[CaseDef], finalizer: Option[Term]): Try =
withDefaultPos(tpd.Try(expr, cases, finalizer.getOrElse(tpd.EmptyTree)))
withDefaultPos(tpd.Try(yCheckValidExpr(expr), cases, finalizer.getOrElse(tpd.EmptyTree)))
def copy(original: Tree)(expr: Term, cases: List[CaseDef], finalizer: Option[Term]): Try =
tpd.cpy.Try(original)(expr, cases, finalizer.getOrElse(tpd.EmptyTree))
tpd.cpy.Try(original)(yCheckValidExpr(expr), cases, finalizer.getOrElse(tpd.EmptyTree))
def unapply(x: Try): (Term, List[CaseDef], Option[Term]) =
(x.body, x.cases, optional(x.finalizer))
end Try
Expand All @@ -875,9 +877,9 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler

object Return extends ReturnModule:
def apply(expr: Term, from: Symbol): Return =
withDefaultPos(tpd.Return(expr, from))
withDefaultPos(tpd.Return(yCheckValidExpr(expr), from))
def copy(original: Tree)(expr: Term, from: Symbol): Return =
tpd.cpy.Return(original)(expr, tpd.ref(from))
tpd.cpy.Return(original)(yCheckValidExpr(expr), tpd.ref(from))
def unapply(x: Return): (Term, Symbol) =
(x.expr, x.from.symbol)
end Return
Expand All @@ -899,8 +901,10 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler

object Repeated extends RepeatedModule:
def apply(elems: List[Term], elemtpt: TypeTree): Repeated =
yCheckValidExprs(elems)
withDefaultPos(tpd.SeqLiteral(elems, elemtpt))
def copy(original: Tree)(elems: List[Term], elemtpt: TypeTree): Repeated =
yCheckValidExprs(elems)
tpd.cpy.SeqLiteral(original)(elems, elemtpt)
def unapply(x: Repeated): (List[Term], TypeTree) =
(x.elems, x.elemtpt)
Expand All @@ -923,9 +927,9 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler

object Inlined extends InlinedModule:
def apply(call: Option[Tree], bindings: List[Definition], expansion: Term): Inlined =
withDefaultPos(tpd.Inlined(call.getOrElse(tpd.EmptyTree), bindings.map { case b: tpd.MemberDef => b }, expansion))
withDefaultPos(tpd.Inlined(call.getOrElse(tpd.EmptyTree), bindings.map { case b: tpd.MemberDef => b }, yCheckValidExpr(expansion)))
def copy(original: Tree)(call: Option[Tree], bindings: List[Definition], expansion: Term): Inlined =
tpd.cpy.Inlined(original)(call.getOrElse(tpd.EmptyTree), bindings.asInstanceOf[List[tpd.MemberDef]], expansion)
tpd.cpy.Inlined(original)(call.getOrElse(tpd.EmptyTree), bindings.asInstanceOf[List[tpd.MemberDef]], yCheckValidExpr(expansion))
def unapply(x: Inlined): (Option[Tree /* Term | TypeTree */], List[Definition], Term) =
(optional(x.call), x.bindings, x.body)
end Inlined
Expand Down Expand Up @@ -978,9 +982,9 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler

object While extends WhileModule:
def apply(cond: Term, body: Term): While =
withDefaultPos(tpd.WhileDo(cond, body))
withDefaultPos(tpd.WhileDo(yCheckValidExpr(cond), yCheckValidExpr(body)))
def copy(original: Tree)(cond: Term, body: Term): While =
tpd.cpy.WhileDo(original)(cond, body)
tpd.cpy.WhileDo(original)(yCheckValidExpr(cond), yCheckValidExpr(body))
def unapply(x: While): (Term, Term) =
(x.cond, x.body)
end While
Expand Down Expand Up @@ -2830,6 +2834,18 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
case _ => traverseChildren(t)
}.traverse(tree)

private def yCheckValidExprs(terms: List[Term]): terms.type =
if yCheck then terms.foreach(yCheckValidExpr)
terms
private def yCheckValidExpr(termOpt: Option[Term]): termOpt.type =
if yCheck then termOpt.foreach(yCheckValidExpr)
termOpt
private def yCheckValidExpr(term: Term): term.type =
if yCheck then
assert(!term.tpe.widenDealias.isInstanceOf[dotc.core.Types.MethodicType],
"Reference to a method must be eta-expanded before it is used as an expression: " + term.show)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For later: maybe provide a reflection API to perform eta-expansion?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems such well-formedness checks are very useful in macros development.

Is it possible to centralize the check for the expanded tree? I guess for debugging, the current approach is more helpful, as it shows the stacktrace.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For later: maybe provide a reflection API to perform eta-expansion?

We already have term.etaExpand

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to centralize the check for the expanded tree?

Centralize how? Like in normal yCheck?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Centralize how? Like in normal yCheck?

Yes, something like a TreeMap or ReTyper. But as I mentioned above, this might not be a good idea, as it cannot show the stacktrace how an ill-formed tree is created.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That did not work in this case because the crash happened before we finished the transformation in the Inlining phase.

term

object Printer extends PrinterModule:

lazy val TreeCode: Printer[Tree] = new Printer[Tree]:
Expand Down Expand Up @@ -2873,6 +2889,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
Extractors.showConstant(using QuotesImpl.this)(const)

end Printer

end reflect

def unpickleExpr[T](pickled: String | List[String], typeHole: (Int, Seq[Any]) => scala.quoted.Type[?], termHole: (Int, Seq[Any], scala.quoted.Quotes) => scala.quoted.Expr[?]): scala.quoted.Expr[T] =
Expand Down
59 changes: 59 additions & 0 deletions tests/neg-macros/i11483/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package x

import scala.annotation._
import scala.quoted._
import scala.compiletime._


trait CpsMonad[F[_]]:

def pure[T](x:T):F[T]

def impure[T](x:F[T]):T

def map[A,B](x:F[A])(f: A=>B):F[B]


@compileTimeOnly("await should be inside async block")
def await[F[_],T](f:F[T])(using am:CpsMonad[F]):T = ???

inline given conversion[F[_],T](using CpsMonad[F]): Conversion[F[T],T] =
x => await(x)


object X {

inline def process[F[_], T](inline t:T)(using m: CpsMonad[F]):F[T] =
${ processImpl[F,T]('t, 'm) }


def processImpl[F[_]:Type, T:Type](t:Expr[T], m:Expr[CpsMonad[F]])(using Quotes):Expr[F[T]] =
import quotes.reflect._
val r = processTree[F,T](t.asTerm, m.asTerm)
r.asExprOf[F[T]]


def processTree[F[_]:Type, T:Type](using Quotes)(t: quotes.reflect.Term, m: quotes.reflect.Term):quotes.reflect.Term =
import quotes.reflect._
val r: Term = t match
case Inlined(_,List(),body) => processTree[F,T](body, m)
case Inlined(d,bindings,body) =>
Inlined(d,bindings,processTree[F,T](body, m))
case Block(stats,expr) => Block(stats,processTree(expr, m))
case Apply(Apply(TypeApply(Ident("await"),targs),List(body)),List(m)) => body
case Apply(f,List(arg)) =>
val nArg = processTree[F,String](arg, m)
Apply(Apply(TypeApply(Select.unique(m,"map"),
List(Inferred(arg.tpe.widen),Inferred(t.tpe.widen))
),
List(nArg)),
List(f)
)
case Apply(f,List()) =>
Apply(TypeApply(Select.unique(m,"pure"),List(Inferred(t.tpe.widen))),List(t))
case Typed(x,tp) => Typed(processTree(x,m), Inferred(TypeRepr.of[F].appliedTo(tp.tpe)) )
case _ => throw new RuntimeException(s"tree not recoginized: $t")
r


}
26 changes: 26 additions & 0 deletions tests/neg-macros/i11483/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package x

import scala.language.implicitConversions
import scala.concurrent.Future

given FutureAsyncMonad: CpsMonad[Future] with
def pure[T](t:T): Future[T] = ???
def impure[T](t:Future[T]): T = ???
def map[A,B](x:Future[A])(f: A=>B): Future[B] = ???


object Api:

def doSomething():Future[String] =
Future.successful("doSomething")

def println(x:String):Unit =
Console.println(x)


object Main:

def main(args: Array[String]): Unit =
X.process[Future,Unit]{ // error
Api.println(Api.doSomething())
}