Skip to content

Commit 30b35f1

Browse files
committed
Fix #8290: Make Expr.betaReduce give up when it sees a non-function typed closure expression
This commit addresses 2 issues with existing betaReduce behaviour: - when given a non-function typed closure the previous iteration could easily fail to resolve the correct apply method, or even successfully inline the wrong code (see added test cases) - if betaReduce did not successfully inline, it would return a transformed tree. This was fine until the above change made it possible to give up while inside a closureDef, which could insert a type ascription inside the closureDef's block, leading to betaReduce returning invalid trees (the closureDef block can only contain a DefDef and Closure, no type ascriptions). Fixing this issue would add meaningless complexity, so instead this commit changes betaReduce to cleanly give up by returning the function tree unchanged, only generating the code necessary to call it. Note: this change affects a few tests that were checking for betaReduce's slight changes to the function tree. Testing the correctness of this change is done by adding cases to existing tests for betaReduce's treatment of type ascriptions.
1 parent 5b006fb commit 30b35f1

File tree

5 files changed

+73
-28
lines changed

5 files changed

+73
-28
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2050,17 +2050,17 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
20502050
}}
20512051
val argVals = argVals0.reverse
20522052
val argRefs = argRefs0.reverse
2053-
def rec(fn: Tree, topAscription: Option[TypeTree]): Tree = fn match {
2053+
val expectedSig = Signature.NotAMethod.prependTermParams(argRefs.map(_.tpe), false)
2054+
def rec(fn: Tree, topAscription: Option[TypeTree]): Option[Tree] = fn match {
20542055
case Typed(expr, tpt) =>
2055-
// we need to retain any type ascriptions we see and:
2056-
// a) if we succeed, ascribe the result type of the ascription to the inlined body
2057-
// b) if we fail, re-ascribe the same type to whatever it was we couldn't inline
2056+
// we need to retain any type ascriptions we see and if we succeed,
2057+
// ascribe the result type of the ascription to the inlined body
20582058
// note: if you see many nested ascriptions, keep the top one as that's what the enclosing expression expects
20592059
rec(expr, topAscription.orElse(Some(tpt)))
20602060
case Inlined(call, bindings, expansion) =>
20612061
// this case must go before closureDef to avoid dropping the inline node
2062-
cpy.Inlined(fn)(call, bindings, rec(expansion, topAscription))
2063-
case closureDef(ddef) =>
2062+
rec(expansion, topAscription).map(cpy.Inlined(fn)(call, bindings, _))
2063+
case cl @ closureDef(ddef) if defn.isFunctionType(cl.tpe) =>
20642064
val paramSyms = ddef.vparamss.head.map(param => param.symbol)
20652065
val paramToVals = paramSyms.zip(argRefs).toMap
20662066
val result = new TreeTypeMap(
@@ -2070,24 +2070,26 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
20702070
).transform(ddef.rhs)
20712071
topAscription match {
20722072
case Some(tpt) =>
2073-
// we assume the ascribed type has an apply that has a MethodType with a single param list (there should be no polys)
2074-
val methodType = tpt.tpe.member(nme.apply).info.asInstanceOf[MethodType]
2073+
// we checked that this is a plain Function closure, so there will be an apply method with a MethodType
2074+
// and the expected signature based on param types
2075+
val methodType = tpt.tpe.member(nme.apply).atSignature(expectedSig).info.asInstanceOf[MethodType]
20752076
// result might contain paramrefs, so we substitute them with arg termrefs
20762077
val resultTypeWithSubst = methodType.resultType.substParams(methodType, argRefs.map(_.tpe))
2077-
Typed(result, TypeTree(resultTypeWithSubst).withSpan(fn.span)).withSpan(fn.span)
2078+
Some(Typed(result, TypeTree(resultTypeWithSubst).withSpan(fn.span)).withSpan(fn.span))
20782079
case None =>
2079-
result
2080+
Some(result)
20802081
}
20812082
case tpd.Block(stats, expr) =>
2082-
seq(stats, rec(expr, topAscription)).withSpan(fn.span)
2083+
rec(expr, topAscription).map(seq(stats, _).withSpan(fn.span))
20832084
case _ =>
2084-
val maybeAscribed = topAscription match {
2085-
case Some(tpt) => Typed(fn, tpt).withSpan(fn.span)
2086-
case None => fn
2087-
}
2088-
maybeAscribed.select(nme.apply).appliedToArgs(argRefs).withSpan(fn.span)
2085+
None
2086+
}
2087+
rec(fn, None) match {
2088+
case Some(result) => seq(argVals, result)
2089+
case None =>
2090+
val expectedSig = Signature.NotAMethod.prependTermParams(args.map(_.tpe), false)
2091+
fn.selectWithSig(nme.apply, expectedSig).appliedToArgs(args).withSpan(fn.span)
20892092
}
2090-
seq(argVals, rec(fn, None))
20912093
}
20922094

20932095
/////////////

tests/run-macros/beta-reduce-inline-result.check

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,6 @@ run-time: 4
33
compile-time: 1
44
run-time: 1
55
run-time: 5
6+
run-time: 7
7+
run-time: -1
8+
run-time: 9

tests/run-macros/beta-reduce-inline-result/Test_2.scala

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,36 @@ object Test {
1414
inline def dummy4: Int => Int =
1515
???
1616

17+
object I extends (Int => Int) {
18+
def apply(i: Int): i.type = i
19+
}
20+
21+
abstract class II extends (Int => Int) {
22+
val apply = 123
23+
}
24+
25+
inline def dummy5: II =
26+
(i: Int) => i + 1
27+
28+
abstract class III extends (Int => Int) {
29+
def impl(i: Int): Int
30+
def apply(i: Int): Int = -1
31+
}
32+
33+
inline def dummy6: III =
34+
(i: Int) => i + 1
35+
36+
abstract class IV extends (Int => Int) {
37+
def apply(s: String): String
38+
}
39+
40+
abstract class V extends IV {
41+
def apply(s: String): String = "gotcha"
42+
}
43+
44+
inline def dummy7: IV =
45+
{ (i: Int) => i + 1 } : V
46+
1747
def main(argv : Array[String]) : Unit = {
1848
println(code"compile-time: ${Macros.betaReduce(dummy1)(3)}")
1949
println(s"run-time: ${Macros.betaReduce(dummy1)(3)}")
@@ -27,7 +57,21 @@ object Test {
2757
def throwsNotImplemented2 = Macros.betaReduce(dummy4)(6)
2858

2959
// make sure paramref types work when inlining is not possible
30-
println(s"run-time: ${Macros.betaReduce(Predef.identity)(5)}")
60+
println(s"run-time: ${Macros.betaReduce(I)(5)}")
61+
62+
// -- cases below are non-function types, which are currently not inlined for simplicity but may be in the future
63+
// (also, this tests that we return something valid when we see a closure that we can't inline)
64+
65+
// A non-function type with an apply value that can be confused with the apply method.
66+
println(s"run-time: ${Macros.betaReduce(dummy5)(6)}")
67+
68+
// should print -1 (without inlining), because the apparent apply method actually
69+
// has nothing to do with the function literal
70+
println(s"run-time: ${Macros.betaReduce(dummy6)(7)}")
71+
72+
// the literal does contain the implementation of the apply method, but there are two abstract apply methods
73+
// in the outermost abstract type
74+
println(s"run-time: ${Macros.betaReduce(dummy7)(8)}")
3175
}
3276
}
3377

tests/run-macros/quote-inline-function.check

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,11 @@ Normal function
33
var i: scala.Int = 0
44
val j: scala.Int = 5
55
while (i.<(j)) {
6-
val x$1: scala.Int = i
7-
f.apply(x$1)
6+
f.apply(i)
87
i = i.+(1)
98
}
109
while ({
11-
val x$2: scala.Int = i
12-
f.apply(x$2)
10+
f.apply(i)
1311
i = i.+(1)
1412
i.<(j)
1513
}) ()
@@ -20,13 +18,11 @@ By name function
2018
var i: scala.Int = 0
2119
val j: scala.Int = 5
2220
while (i.<(j)) {
23-
val x$3: scala.Int = i
24-
f.apply(x$3)
21+
f.apply(i)
2522
i = i.+(1)
2623
}
2724
while ({
28-
val x$4: scala.Int = i
29-
f.apply(x$4)
25+
f.apply(i)
3026
i = i.+(1)
3127
i.<(j)
3228
}) ()

tests/run-staging/i3876-c.check

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@
66

77
(f: scala.Function1[scala.Int, scala.Int] {
88
def apply(x: scala.Int): scala.Int
9-
}).apply(3)
10-
}
9+
})
10+
}.apply(3)

0 commit comments

Comments
 (0)