diff --git a/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala b/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala index 201d18e558e9..44c3e0637c68 100644 --- a/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala +++ b/compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala @@ -2042,24 +2042,44 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend }} val argVals = argVals0.reverse val argRefs = argRefs0.reverse - def rec(fn: Tree): Tree = fn match { + def rec(fn: Tree, topAscription: Option[TypeTree]): Tree = fn match { + case Typed(expr, tpt) => + // we need to retain any type ascriptions we see and: + // a) if we succeed, ascribe the result type of the ascription to the inlined body + // b) if we fail, re-ascribe the same type to whatever it was we couldn't inline + // note: if you see many nested ascriptions, keep the top one as that's what the enclosing expression expects + rec(expr, topAscription.orElse(Some(tpt))) case Inlined(call, bindings, expansion) => // this case must go before closureDef to avoid dropping the inline node - cpy.Inlined(fn)(call, bindings, rec(expansion)) + cpy.Inlined(fn)(call, bindings, rec(expansion, topAscription)) case closureDef(ddef) => val paramSyms = ddef.vparamss.head.map(param => param.symbol) val paramToVals = paramSyms.zip(argRefs).toMap - new TreeTypeMap( + val result = new TreeTypeMap( oldOwners = ddef.symbol :: Nil, newOwners = ctx.owner :: Nil, treeMap = tree => paramToVals.get(tree.symbol).map(_.withSpan(tree.span)).getOrElse(tree) ).transform(ddef.rhs) + topAscription match { + case Some(tpt) => + // we assume the ascribed type has an apply that has a MethodType with a single param list (there should be no polys) + val methodType = tpt.tpe.member(nme.apply).info.asInstanceOf[MethodType] + // result might contain paramrefs, so we substitute them with arg termrefs + val resultTypeWithSubst = methodType.resultType.substParams(methodType, argRefs.map(_.tpe)) + Typed(result, TypeTree(resultTypeWithSubst).withSpan(fn.span)).withSpan(fn.span) + case None => + result + } case tpd.Block(stats, expr) => - seq(stats, rec(expr)).withSpan(fn.span) + seq(stats, rec(expr, topAscription)).withSpan(fn.span) case _ => - fn.select(nme.apply).appliedToArgs(argRefs).withSpan(fn.span) + val maybeAscribed = topAscription match { + case Some(tpt) => Typed(fn, tpt).withSpan(fn.span) + case None => fn + } + maybeAscribed.select(nme.apply).appliedToArgs(argRefs).withSpan(fn.span) } - seq(argVals, rec(fn)) + seq(argVals, rec(fn, None)) } ///////////// diff --git a/tests/neg-macros/beta-reduce-inline-result.check b/tests/neg-macros/beta-reduce-inline-result.check new file mode 100644 index 000000000000..08672b15d3a8 --- /dev/null +++ b/tests/neg-macros/beta-reduce-inline-result.check @@ -0,0 +1,6 @@ + +-- [E007] Type Mismatch Error: tests/neg-macros/beta-reduce-inline-result/Test_2.scala:11:41 --------------------------- +11 | val x2: 4 = Macros.betaReduce(dummy1)(3) // error + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | Found: Int + | Required: (4 : Int) diff --git a/tests/neg-macros/beta-reduce-inline-result/Macro_1.scala b/tests/neg-macros/beta-reduce-inline-result/Macro_1.scala new file mode 100644 index 000000000000..900001a24aaf --- /dev/null +++ b/tests/neg-macros/beta-reduce-inline-result/Macro_1.scala @@ -0,0 +1,10 @@ +import scala.quoted._ + +object Macros { + inline def betaReduce[Arg,Result](inline fn: Arg=>Result)(inline arg: Arg): Result = + ${ betaReduceImpl('{ fn })('{ arg }) } + + def betaReduceImpl[Arg,Result](fn: Expr[Arg=>Result])(arg: Expr[Arg])(using qctx: QuoteContext): Expr[Result] = + Expr.betaReduce(fn)(arg) +} + diff --git a/tests/neg-macros/beta-reduce-inline-result/Test_2.scala b/tests/neg-macros/beta-reduce-inline-result/Test_2.scala new file mode 100644 index 000000000000..86bea7d89167 --- /dev/null +++ b/tests/neg-macros/beta-reduce-inline-result/Test_2.scala @@ -0,0 +1,13 @@ + +object Test { + + inline def dummy1: Int => Int = + (i: Int) => i + 1 + + inline def dummy2: Int => Int = + (i: Int) => ??? + + val x1: Int = Macros.betaReduce(dummy1)(3) + val x2: 4 = Macros.betaReduce(dummy1)(3) // error +} + diff --git a/tests/run-macros/beta-reduce-inline-result.check b/tests/run-macros/beta-reduce-inline-result.check new file mode 100644 index 000000000000..082514df02f7 --- /dev/null +++ b/tests/run-macros/beta-reduce-inline-result.check @@ -0,0 +1,5 @@ +compile-time: 4 +run-time: 4 +compile-time: 1 +run-time: 1 +run-time: 5 diff --git a/tests/run-macros/beta-reduce-inline-result/Macro_1.scala b/tests/run-macros/beta-reduce-inline-result/Macro_1.scala new file mode 100644 index 000000000000..b5a0f2419e8b --- /dev/null +++ b/tests/run-macros/beta-reduce-inline-result/Macro_1.scala @@ -0,0 +1,16 @@ +import scala.quoted._ + +object Macros { + inline def betaReduce[Arg,Result](inline fn : Arg=>Result)(inline arg: Arg): Result = + ${ betaReduceImpl('{ fn })('{ arg }) } + + def betaReduceImpl[Arg,Result](fn: Expr[Arg=>Result])(arg: Expr[Arg])(using qctx : QuoteContext): Expr[Result] = + Expr.betaReduce(fn)(arg) + + inline def betaReduceAdd1[Arg](inline fn: Arg=>Int)(inline arg: Arg): Int = + ${ betaReduceAdd1Impl('{ fn })('{ arg }) } + + def betaReduceAdd1Impl[Arg](fn: Expr[Arg=>Int])(arg: Expr[Arg])(using qctx: QuoteContext): Expr[Int] = + '{ ${ Expr.betaReduce(fn)(arg) } + 1 } +} + diff --git a/tests/run-macros/beta-reduce-inline-result/Test_2.scala b/tests/run-macros/beta-reduce-inline-result/Test_2.scala new file mode 100644 index 000000000000..978b3e5d2f41 --- /dev/null +++ b/tests/run-macros/beta-reduce-inline-result/Test_2.scala @@ -0,0 +1,33 @@ +import scala.compiletime._ + +object Test { + + inline def dummy1: Int => Int = + (i: Int) => i + 1 + + inline def dummy2: (i: Int) => i.type = + (i: Int) => i + + inline def dummy3: Int => Int = + (i: Int) => ??? + + inline def dummy4: Int => Int = + ??? + + def main(argv : Array[String]) : Unit = { + println(code"compile-time: ${Macros.betaReduce(dummy1)(3)}") + println(s"run-time: ${Macros.betaReduce(dummy1)(3)}") + println(code"compile-time: ${Macros.betaReduce(dummy2)(1)}") + // paramrefs have to be properly substituted in this case + println(s"run-time: ${Macros.betaReduce(dummy2)(1)}") + + // ensure the inlined ??? is ascribed type Int so this compiles + def throwsNotImplemented1 = Macros.betaReduceAdd1(dummy3)(4) + // ensure we handle cases where the (non-inlineable) function itself needs ascribing + def throwsNotImplemented2 = Macros.betaReduce(dummy4)(6) + + // make sure paramref types work when inlining is not possible + println(s"run-time: ${Macros.betaReduce(Predef.identity)(5)}") + } +} +