Skip to content

Commit 3ad35a8

Browse files
committed
Make SigPoly symbols resistant to withPrefix changes
1 parent f3946bc commit 3ad35a8

File tree

4 files changed

+43
-20
lines changed

4 files changed

+43
-20
lines changed

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,8 +1238,7 @@ class TreeUnpickler(reader: TastyReader,
12381238
val fn = readTerm()
12391239
val methType = readType()
12401240
val args = until(end)(readTerm())
1241-
val sym2 = fn.symbol.copy(info = methType) // symbol not entered (same as in simpleApply)
1242-
val fun2 = fn.withType(sym2.termRef)
1241+
val fun2 = typer.Applications.retypeSignaturePolymorphicFn(fn, methType)
12431242
tpd.Apply(fun2, args)
12441243
case TYPED =>
12451244
val expr = readTerm()

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,12 @@ object Applications {
341341
val getter = findDefaultGetter(fn, n, testOnly)
342342
if getter.isEmpty then getter
343343
else spliceMeth(getter.withSpan(fn.span), fn)
344+
345+
def retypeSignaturePolymorphicFn(fun: Tree, methType: Type)(using Context): Tree =
346+
val sym1 = fun.symbol
347+
val flags2 = sym1.flags | NonMember // ensures Select typing doesn't let TermRef#withPrefix revert the type
348+
val sym2 = sym1.copy(info = methType, flags = flags2) // symbol not entered, to avoid overload resolution problems
349+
fun.withType(sym2.termRef)
344350
}
345351

346352
trait Applications extends Compatibility {
@@ -950,8 +956,7 @@ trait Applications extends Compatibility {
950956
case resTp if isFullyDefined(resTp, ForceDegree.all) => resTp
951957
case _ => defn.ObjectType
952958
val methType = MethodType(proto.typedArgs().map(_.tpe.widen), resultType)
953-
val sym2 = funRef.symbol.copy(info = methType) // symbol not entered, to avoid overload resolution problems
954-
val fun2 = fun1.withType(sym2.termRef)
959+
val fun2 = Applications.retypeSignaturePolymorphicFn(fun1, methType)
955960
simpleApply(fun2, proto)
956961
case funRef: TermRef =>
957962
val app = ApplyTo(tree, fun1, funRef, proto, pt)

tests/explicit-nulls/pos/i11332.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// scalajs: --skip
2+
import scala.language.unsafeNulls
3+
4+
import java.lang.invoke._, MethodType.methodType
5+
6+
// A copy of tests/run/i11332.scala
7+
// to test the bootstrap minimisation which failed
8+
// (because bootstrap runs under explicit nulls)
9+
class Foo:
10+
def neg(x: Int): Int = -x
11+
12+
val l = MethodHandles.lookup()
13+
val self = new Foo()
14+
15+
val test = // testing as a expression tree - previously derivedSelect broke the type
16+
l
17+
.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
18+
.invokeExact(self, 4): Int

tests/run/i11332.scala

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class Foo {
1616
def id[T](x: T): T = x
1717

1818
val l = MethodHandles.lookup()
19+
val self = new Foo()
1920
val mhNeg = l.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
2021
val mhRev = l.findVirtual(classOf[Foo], "rev", methodType(classOf[String], classOf[String]))
2122
val mhOverL = l.findVirtual(classOf[Foo], "over", methodType(classOf[String], classOf[Long]))
@@ -24,25 +25,25 @@ class Foo {
2425
val mhObj = l.findVirtual(classOf[Foo], "obj", methodType(classOf[Any], classOf[String]))
2526
val mhCL = l.findStatic(classOf[ClassLoader], "getPlatformClassLoader", methodType(classOf[ClassLoader]))
2627

27-
val testNeg1 = assert(-42 == (mhNeg.invokeExact(this, 42): Int))
28-
val testNeg2 = assert(-33 == (mhNeg.invokeExact(this, 33): Int))
28+
val testNeg1 = assert(-42 == (mhNeg.invokeExact(self, 42): Int))
29+
val testNeg2 = assert(-33 == (mhNeg.invokeExact(self, 33): Int))
2930

30-
val testRev1 = assert("oof" == (mhRev.invokeExact(this, "foo"): String))
31-
val testRev2 = assert("rab" == (mhRev.invokeExact(this, "bar"): String))
31+
val testRev1 = assert("oof" == (mhRev.invokeExact(self, "foo"): String))
32+
val testRev2 = assert("rab" == (mhRev.invokeExact(self, "bar"): String))
3233

33-
val testOverL = assert("long" == (mhOverL.invokeExact(this, 1L): String))
34-
val testOVerI = assert("int" == (mhOverI.invokeExact(this, 1): String))
34+
val testOverL = assert("long" == (mhOverL.invokeExact(self, 1L): String))
35+
val testOVerI = assert("int" == (mhOverI.invokeExact(self, 1): String))
3536

36-
val testNeg_tvar = assert(-3 == (id(mhNeg.invokeExact(this, 3)): Int))
37-
val testNeg_obj = expectWrongMethod(mhNeg.invokeExact(this, 4))
37+
val testNeg_tvar = assert(-3 == (id(mhNeg.invokeExact(self, 3)): Int))
38+
val testNeg_obj = expectWrongMethod(mhNeg.invokeExact(self, 4))
3839

39-
val testUnit_exp = { mhUnit.invokeExact(this, "hi"): Unit; () }
40-
val testUnit_val = { val hi2: Unit = mhUnit.invokeExact(this, "hi2"); assert((()) == hi2) }
41-
val testUnit_def = { def hi3: Unit = mhUnit.invokeExact(this, "hi3"); assert((()) == hi3) }
40+
val testUnit_exp = { mhUnit.invokeExact(self, "hi"): Unit; () }
41+
val testUnit_val = { val hi2: Unit = mhUnit.invokeExact(self, "hi2"); assert((()) == hi2) }
42+
val testUnit_def = { def hi3: Unit = mhUnit.invokeExact(self, "hi3"); assert((()) == hi3) }
4243

43-
val testObj_exp = { mhObj.invokeExact(this, "any"); () }
44-
val testObj_val = { val any2 = mhObj.invokeExact(this, "any2"); assert("any2" == any2) }
45-
val testObj_def = { def any3 = mhObj.invokeExact(this, "any3"); assert("any3" == any3) }
44+
val testObj_exp = { mhObj.invokeExact(self, "any"); () }
45+
val testObj_val = { val any2 = mhObj.invokeExact(self, "any2"); assert("any2" == any2) }
46+
val testObj_def = { def any3 = mhObj.invokeExact(self, "any3"); assert("any3" == any3) }
4647

4748
val testCl_pass = assert(null != (mhCL.invoke(): ClassLoader))
4849
val testCl_cast = assert(null != (mhCL.invoke().asInstanceOf[ClassLoader]: ClassLoader))
@@ -51,10 +52,10 @@ class Foo {
5152

5253
val testNeg_inline_obj = expectWrongMethod(l
5354
.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
54-
.invokeExact(this, 3))
55+
.invokeExact(self, 3))
5556
val testNeg_inline_pass = assert(-4 == (l
5657
.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
57-
.invokeExact(this, 4): Int))
58+
.invokeExact(self, 4): Int))
5859

5960
def expectWrongMethod(op: => Any) = try {
6061
op

0 commit comments

Comments
 (0)