Skip to content

Commit a29178c

Browse files
Merge pull request #8251 from fhackett/fhackett-fix-8250
Fix #8250: Allow Expr.betaReduce to drop type ascriptions.
2 parents 8fa5c5c + 0c34994 commit a29178c

File tree

7 files changed

+109
-6
lines changed

7 files changed

+109
-6
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2050,24 +2050,44 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
20502050
}}
20512051
val argVals = argVals0.reverse
20522052
val argRefs = argRefs0.reverse
2053-
def rec(fn: Tree): Tree = fn match {
2053+
def rec(fn: Tree, topAscription: Option[TypeTree]): Tree = fn match {
2054+
case Typed(expr, tpt) =>
2055+
// we need to retain any type ascriptions we see and:
2056+
// a) if we succeed, ascribe the result type of the ascription to the inlined body
2057+
// b) if we fail, re-ascribe the same type to whatever it was we couldn't inline
2058+
// note: if you see many nested ascriptions, keep the top one as that's what the enclosing expression expects
2059+
rec(expr, topAscription.orElse(Some(tpt)))
20542060
case Inlined(call, bindings, expansion) =>
20552061
// this case must go before closureDef to avoid dropping the inline node
2056-
cpy.Inlined(fn)(call, bindings, rec(expansion))
2062+
cpy.Inlined(fn)(call, bindings, rec(expansion, topAscription))
20572063
case closureDef(ddef) =>
20582064
val paramSyms = ddef.vparamss.head.map(param => param.symbol)
20592065
val paramToVals = paramSyms.zip(argRefs).toMap
2060-
new TreeTypeMap(
2066+
val result = new TreeTypeMap(
20612067
oldOwners = ddef.symbol :: Nil,
20622068
newOwners = ctx.owner :: Nil,
20632069
treeMap = tree => paramToVals.get(tree.symbol).map(_.withSpan(tree.span)).getOrElse(tree)
20642070
).transform(ddef.rhs)
2071+
topAscription match {
2072+
case Some(tpt) =>
2073+
// we assume the ascribed type has an apply that has a MethodType with a single param list (there should be no polys)
2074+
val methodType = tpt.tpe.member(nme.apply).info.asInstanceOf[MethodType]
2075+
// result might contain paramrefs, so we substitute them with arg termrefs
2076+
val resultTypeWithSubst = methodType.resultType.substParams(methodType, argRefs.map(_.tpe))
2077+
Typed(result, TypeTree(resultTypeWithSubst).withSpan(fn.span)).withSpan(fn.span)
2078+
case None =>
2079+
result
2080+
}
20652081
case tpd.Block(stats, expr) =>
2066-
seq(stats, rec(expr)).withSpan(fn.span)
2082+
seq(stats, rec(expr, topAscription)).withSpan(fn.span)
20672083
case _ =>
2068-
fn.select(nme.apply).appliedToArgs(argRefs).withSpan(fn.span)
2084+
val maybeAscribed = topAscription match {
2085+
case Some(tpt) => Typed(fn, tpt).withSpan(fn.span)
2086+
case None => fn
2087+
}
2088+
maybeAscribed.select(nme.apply).appliedToArgs(argRefs).withSpan(fn.span)
20692089
}
2070-
seq(argVals, rec(fn))
2090+
seq(argVals, rec(fn, None))
20712091
}
20722092

20732093
/////////////
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
-- [E007] Type Mismatch Error: tests/neg-macros/beta-reduce-inline-result/Test_2.scala:11:41 ---------------------------
3+
11 | val x2: 4 = Macros.betaReduce(dummy1)(3) // error
4+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
5+
| Found: Int
6+
| Required: (4 : Int)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import scala.quoted._
2+
3+
object Macros {
4+
inline def betaReduce[Arg,Result](inline fn: Arg=>Result)(inline arg: Arg): Result =
5+
${ betaReduceImpl('{ fn })('{ arg }) }
6+
7+
def betaReduceImpl[Arg,Result](fn: Expr[Arg=>Result])(arg: Expr[Arg])(using qctx: QuoteContext): Expr[Result] =
8+
Expr.betaReduce(fn)(arg)
9+
}
10+
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
2+
object Test {
3+
4+
inline def dummy1: Int => Int =
5+
(i: Int) => i + 1
6+
7+
inline def dummy2: Int => Int =
8+
(i: Int) => ???
9+
10+
val x1: Int = Macros.betaReduce(dummy1)(3)
11+
val x2: 4 = Macros.betaReduce(dummy1)(3) // error
12+
}
13+
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
compile-time: 4
2+
run-time: 4
3+
compile-time: 1
4+
run-time: 1
5+
run-time: 5
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import scala.quoted._
2+
3+
object Macros {
4+
inline def betaReduce[Arg,Result](inline fn : Arg=>Result)(inline arg: Arg): Result =
5+
${ betaReduceImpl('{ fn })('{ arg }) }
6+
7+
def betaReduceImpl[Arg,Result](fn: Expr[Arg=>Result])(arg: Expr[Arg])(using qctx : QuoteContext): Expr[Result] =
8+
Expr.betaReduce(fn)(arg)
9+
10+
inline def betaReduceAdd1[Arg](inline fn: Arg=>Int)(inline arg: Arg): Int =
11+
${ betaReduceAdd1Impl('{ fn })('{ arg }) }
12+
13+
def betaReduceAdd1Impl[Arg](fn: Expr[Arg=>Int])(arg: Expr[Arg])(using qctx: QuoteContext): Expr[Int] =
14+
'{ ${ Expr.betaReduce(fn)(arg) } + 1 }
15+
}
16+
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import scala.compiletime._
2+
3+
object Test {
4+
5+
inline def dummy1: Int => Int =
6+
(i: Int) => i + 1
7+
8+
inline def dummy2: (i: Int) => i.type =
9+
(i: Int) => i
10+
11+
inline def dummy3: Int => Int =
12+
(i: Int) => ???
13+
14+
inline def dummy4: Int => Int =
15+
???
16+
17+
def main(argv : Array[String]) : Unit = {
18+
println(code"compile-time: ${Macros.betaReduce(dummy1)(3)}")
19+
println(s"run-time: ${Macros.betaReduce(dummy1)(3)}")
20+
println(code"compile-time: ${Macros.betaReduce(dummy2)(1)}")
21+
// paramrefs have to be properly substituted in this case
22+
println(s"run-time: ${Macros.betaReduce(dummy2)(1)}")
23+
24+
// ensure the inlined ??? is ascribed type Int so this compiles
25+
def throwsNotImplemented1 = Macros.betaReduceAdd1(dummy3)(4)
26+
// ensure we handle cases where the (non-inlineable) function itself needs ascribing
27+
def throwsNotImplemented2 = Macros.betaReduce(dummy4)(6)
28+
29+
// make sure paramref types work when inlining is not possible
30+
println(s"run-time: ${Macros.betaReduce(Predef.identity)(5)}")
31+
}
32+
}
33+

0 commit comments

Comments
 (0)