Skip to content

Fix inline parameters lifted for 0 to 1 #4777

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,19 @@ class Definitions {
lazy val QuotedType_applyR = QuotedTypeModule.requiredMethodRef(nme.apply)
def QuotedType_apply(implicit ctx: Context) = QuotedType_applyR.symbol

lazy val QuotedLiftableModule = ctx.requiredModule("scala.quoted.Liftable")
def QuotedLiftableModuleClass(implicit ctx: Context) = QuotedLiftableModule.asClass

def QuotedLiftable_BooleanIsLiftable = QuotedLiftableModule.requiredMethodRef("BooleanIsLiftable")
def QuotedLiftable_ByteIsLiftable = QuotedLiftableModule.requiredMethodRef("ByteIsLiftable")
def QuotedLiftable_CharIsLiftable = QuotedLiftableModule.requiredMethodRef("CharIsLiftable")
def QuotedLiftable_ShortIsLiftable = QuotedLiftableModule.requiredMethodRef("ShortIsLiftable")
def QuotedLiftable_IntIsLiftable = QuotedLiftableModule.requiredMethodRef("IntIsLiftable")
def QuotedLiftable_LongIsLiftable = QuotedLiftableModule.requiredMethodRef("LongIsLiftable")
def QuotedLiftable_FloatIsLiftable = QuotedLiftableModule.requiredMethodRef("FloatIsLiftable")
def QuotedLiftable_DoubleIsLiftable = QuotedLiftableModule.requiredMethodRef("DoubleIsLiftable")
def QuotedLiftable_StringIsLiftable = QuotedLiftableModule.requiredMethodRef("StringIsLiftable")

lazy val QuotedLiftableType = ctx.requiredClassRef("scala.quoted.Liftable")
def QuotedLiftableClass(implicit ctx: Context) = QuotedLiftableType.symbol.asClass

Expand Down
33 changes: 18 additions & 15 deletions compiler/src/dotty/tools/dotc/transform/ReifyQuotes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ class ReifyQuotes extends MacroTransformWithImplicits with InfoTransformer {
if (isStage0Value(body.symbol)) {
// Optimization: avoid the full conversion when capturing inlined `x`
// in '{ x } to '{ x$1.toExpr.unary_~ } and go directly to `x$1.toExpr`
liftValue(capturers(body.symbol)(body))
liftInlineParamValue(capturers(body.symbol)(body))
} else {
// Optimization: avoid the full conversion when capturing `x`
// in '{ x } to '{ x$1.unary_~ } and go directly to `x$1`
Expand Down Expand Up @@ -577,7 +577,7 @@ class ReifyQuotes extends MacroTransformWithImplicits with InfoTransformer {
splice(t.select(if (tree.isTerm) nme.UNARY_~ else tpnme.UNARY_~))
if (!isStage0Value(tree.symbol)) captureAndSplice(capturer(tree))
else if (level == 0) capturer(tree)
else captureAndSplice(liftValue(capturer(tree)))
else captureAndSplice(liftInlineParamValue(capturer(tree)))
case Block(stats, _) =>
val last = enteredSyms
stats.foreach(markDef)
Expand Down Expand Up @@ -632,19 +632,22 @@ class ReifyQuotes extends MacroTransformWithImplicits with InfoTransformer {
}
}

private def liftValue(tree: Tree)(implicit ctx: Context): Tree = {
val reqType = defn.QuotedLiftableType.appliedTo(tree.tpe.widen)
val liftable = ctx.typer.inferImplicitArg(reqType, tree.pos)
liftable.tpe match {
case fail: SearchFailureType =>
ctx.error(i"""
|
| The access would be accepted with the right Liftable, but
| ${ctx.typer.missingArgMsg(liftable, reqType, "")}""")
EmptyTree
case _ =>
liftable.select("toExpr".toTermName).appliedTo(tree)
}
/** Takes a reference to an inline parameter `tree` and lifts it to an Expr */
private def liftInlineParamValue(tree: Tree)(implicit ctx: Context): Tree = {
val tpSym = tree.tpe.widenDealias.classSymbol

val lifter =
if (tpSym eq defn.BooleanClass) defn.QuotedLiftable_BooleanIsLiftable
else if (tpSym eq defn.ByteClass) defn.QuotedLiftable_ByteIsLiftable
else if (tpSym eq defn.CharClass) defn.QuotedLiftable_CharIsLiftable
else if (tpSym eq defn.ShortClass) defn.QuotedLiftable_ShortIsLiftable
else if (tpSym eq defn.IntClass) defn.QuotedLiftable_IntIsLiftable
else if (tpSym eq defn.LongClass) defn.QuotedLiftable_LongIsLiftable
else if (tpSym eq defn.FloatClass) defn.QuotedLiftable_FloatIsLiftable
else if (tpSym eq defn.DoubleClass) defn.QuotedLiftable_DoubleIsLiftable
else defn.QuotedLiftable_StringIsLiftable

ref(lifter).select("toExpr".toTermName).appliedTo(tree)
}

private def isStage0Value(sym: Symbol)(implicit ctx: Context): Boolean =
Expand Down
2 changes: 1 addition & 1 deletion library/src/scala/quoted/Liftable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ abstract class Liftable[T] {
*/
object Liftable {
implicit def BooleanIsLiftable: Liftable[Boolean] = (x: Boolean) => liftedExpr(x)
implicit def ByteLiftable: Liftable[Byte] = (x: Byte) => liftedExpr(x)
implicit def ByteIsLiftable: Liftable[Byte] = (x: Byte) => liftedExpr(x)
implicit def CharIsLiftable: Liftable[Char] = (x: Char) => liftedExpr(x)
implicit def ShortIsLiftable: Liftable[Short] = (x: Short) => liftedExpr(x)
implicit def IntIsLiftable: Liftable[Int] = (x: Int) => liftedExpr(x)
Expand Down
7 changes: 7 additions & 0 deletions tests/pos/quote-lift-inline-params-b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import scala.quoted.Expr
object Macro {
inline def foo(inline n: Int): Int = ~{
import quoted.Liftable.{IntIsLiftable => _}
'(n)
}
}
4 changes: 4 additions & 0 deletions tests/pos/quote-lift-inline-params/App_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

object App {
Macro.foo(3)
}
7 changes: 7 additions & 0 deletions tests/pos/quote-lift-inline-params/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import scala.quoted.Expr
object Macro {
inline def foo(inline n: Int): Int = ~{
import quoted.Liftable.{IntIsLiftable => _}
'(n)
}
}