Skip to content

Commit 9ba6547

Browse files
committed
Fix eta expand at erasure to take erased CFTs into account
1 parent 3c20740 commit 9ba6547

File tree

6 files changed

+190
-22
lines changed

6 files changed

+190
-22
lines changed

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,6 +1104,22 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
11041104
if (sym.exists) sym.defTree = tree
11051105
tree
11061106
}
1107+
1108+
def etaExpandCFT(using Context): Tree =
1109+
def expand(target: Tree, tp: Type)(using Context): Tree = tp match
1110+
case defn.ContextFunctionType(argTypes, resType, isErased) =>
1111+
val anonFun = newSymbol(
1112+
ctx.owner, nme.ANON_FUN, Flags.Synthetic | Flags.Method,
1113+
MethodType.companion(isContextual = true, isErased = isErased)(argTypes, resType),
1114+
coord = ctx.owner.coord)
1115+
def lambdaBody(refss: List[List[Tree]]) =
1116+
expand(target.select(nme.apply).appliedToArgss(refss), resType)(
1117+
using ctx.withOwner(anonFun))
1118+
Closure(anonFun, lambdaBody)
1119+
.showing(i"expand $tree --> $result")
1120+
case _ =>
1121+
target
1122+
expand(tree, tree.tpe.widen)
11071123
}
11081124

11091125
inline val MapRecursionLimit = 10

compiler/src/dotty/tools/dotc/transform/AccessProxies.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ abstract class AccessProxies {
5151
forwardedArgss.nonEmpty && forwardedArgss.head.nonEmpty) // defensive conditions
5252
accessRef.becomes(forwardedArgss.head.head)
5353
else
54-
accessRef.appliedToTypeTrees(forwardedTpts).appliedToArgss(forwardedArgss)
54+
accessRef
55+
.appliedToTypeTrees(forwardedTpts)
56+
.appliedToArgss(forwardedArgss)
57+
.etaExpandCFT(using ctx.withOwner(accessor))
5558
rhs.withSpan(accessed.span)
5659
})
5760

compiler/src/dotty/tools/dotc/transform/Bridges.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,12 @@ class Bridges(root: ClassSymbol, thisPhase: DenotTransformer)(using Context) {
112112
toBeRemoved += other
113113
}
114114

115-
def bridgeRhs(argss: List[List[Tree]]) = {
115+
def bridgeRhs(argss: List[List[Tree]]) =
116116
assert(argss.tail.isEmpty)
117117
val ref = This(root).select(member)
118-
if (member.info.isParameterless) ref // can happen if `member` is a module
118+
if member.info.isParameterless then ref // can happen if `member` is a module
119+
else if true then Erasure.Boxing.forwarder(ref, argss.head, bridge, member.info.finalResultType, other)
119120
else Erasure.partialApply(ref, argss.head)
120-
}
121121

122122
bridges += DefDef(bridge, bridgeRhs(_).withSpan(bridge.span))
123123
}

compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,26 @@ object ContextFunctionResults:
127127
missingCR(meth.info.finalResultType, contextResultCount(meth))._1
128128
}
129129

130+
/** The rightmost context function type in the result type of `meth`
131+
* that represents `paramCount` curried, non-erased parameters that
132+
* are included in the `contextResultCount` of `meth`.
133+
* Example:
134+
*
135+
* Say we have `def m(x: A): B ?=> (C1, C2, C3) ?=> D ?=> E ?=> F`,
136+
* paramCount == 4, and the contextResultCount of `m` is 3.
137+
* Then we return the type `(C1, C2, C3) ?=> D ?=> E ?=> F`, since this
138+
* type covers the 4 rightmost parameters C1, C2, C3 and D before the
139+
* contextResultCount runs out at E ?=> F.
140+
* Erased parameters are ignored; they contribute nothing to the
141+
* parameter count.
142+
*/
143+
def contextFunctionResultTypeAfter(meth: Symbol, skipCount: Int)(using Context) =
144+
def recur(tp: Type, n: Int): Type =
145+
if n == 0 then tp
146+
else tp match
147+
case defn.ContextFunctionType(_, resTpe, _) => recur(resTpe, n - 1)
148+
recur(meth.info.finalResultType, skipCount)
149+
130150
/** Should selection `tree` be eliminated since it refers to an `apply`
131151
* node of a context function type whose parameters will end up being
132152
* integrated in the preceding method?

compiler/src/dotty/tools/dotc/transform/Erasure.scala

Lines changed: 94 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -381,26 +381,32 @@ object Erasure {
381381
*/
382382
def adaptToType(tree: Tree, pt: Type)(using Context): Tree = pt match
383383
case _: FunProto | AnyFunctionProto => tree
384-
case _ => tree.tpe.widen match
385-
case mt: MethodType if tree.isTerm =>
386-
if mt.paramInfos.isEmpty then adaptToType(tree.appliedToNone, pt)
387-
else etaExpand(tree, mt, pt)
388-
case tpw =>
389-
if (pt.isInstanceOf[ProtoType] || tree.tpe <:< pt)
390-
tree
391-
else if (tpw.isErasedValueType)
392-
if (pt.isErasedValueType) then
393-
tree.asInstance(pt)
384+
case _ =>
385+
/*val ahead = tree.attachmentOrElse(needsEta, 0)
386+
if ahead > 0 then etaExpand(tree, ahead, pt)
387+
else*/ tree.tpe.widen match
388+
case mt: MethodType if tree.isTerm =>
389+
if mt.paramInfos.isEmpty then adaptToType(tree.appliedToNone, pt)
394390
else
391+
assert(false)
392+
etaExpand(tree, mt, pt)
393+
394+
case tpw =>
395+
if (pt.isInstanceOf[ProtoType] || tree.tpe <:< pt)
396+
tree
397+
else if (tpw.isErasedValueType)
398+
if (pt.isErasedValueType) then
399+
tree.asInstance(pt)
400+
else
401+
adaptToType(box(tree), pt)
402+
else if (pt.isErasedValueType)
403+
adaptToType(unbox(tree, pt), pt)
404+
else if (tpw.isPrimitiveValueType && !pt.isPrimitiveValueType)
395405
adaptToType(box(tree), pt)
396-
else if (pt.isErasedValueType)
397-
adaptToType(unbox(tree, pt), pt)
398-
else if (tpw.isPrimitiveValueType && !pt.isPrimitiveValueType)
399-
adaptToType(box(tree), pt)
400-
else if (pt.isPrimitiveValueType && !tpw.isPrimitiveValueType)
401-
adaptToType(unbox(tree, pt), pt)
402-
else
403-
cast(tree, pt)
406+
else if (pt.isPrimitiveValueType && !tpw.isPrimitiveValueType)
407+
adaptToType(unbox(tree, pt), pt)
408+
else
409+
cast(tree, pt)
404410
end adaptToType
405411

406412
/** The following code:
@@ -534,6 +540,76 @@ object Erasure {
534540
tree
535541
end adaptClosure
536542

543+
/** Eta expand given `tree` that has the given method type `mt`, so that
544+
* it conforms to erased result type `pt`.
545+
* To do this correctly, we have to look at the tree's original pre-erasure
546+
* type and figure out which context function types in its result are
547+
* not yet instantiated.
548+
*/
549+
def forwarder(ref: Tree, args: List[Tree], owner: Symbol, pt: Type, other: Symbol)(using Context): Tree =
550+
val origOwner = ctx.owner
551+
val member = ref.symbol
552+
val memberCount = contextResultCount(member)
553+
if memberCount == 0 then
554+
ref.appliedToTermArgs(args)
555+
else
556+
def expandArgs(args: List[Tree], owner: Symbol)(using Context): List[Tree] = args match
557+
case (bunchedParam @ Ident(nme.ALLARGS)) :: Nil =>
558+
owner.info.firstParamTypes.indices.toList.map(n =>
559+
bunchedParam
560+
.select(nme.primitive.arrayApply)
561+
.appliedTo(Literal(Constant(n)))).showing("Expand $args%, % --> $result%, %")
562+
case _ => args
563+
564+
val toAbstract: List[TermSymbol] =
565+
def anonFuns(tp: Type, n: Int, owner: Symbol): List[TermSymbol] =
566+
if n <= 0 then Nil
567+
else
568+
val defn.ContextFunctionType(argTpes, resTpe, isErased) = tp: @unchecked
569+
val anonFun = newSymbol(
570+
owner, nme.ANON_FUN, Flags.Synthetic | Flags.Method,
571+
MethodType(if isErased then Nil else argTpes, resTpe),
572+
coord = owner.coord)
573+
anonFun :: anonFuns(resTpe, n - 1, anonFun)
574+
if memberCount == 0 then Nil
575+
else
576+
val otherCount = contextResultCount(other)
577+
val resType = contextFunctionResultTypeAfter(member, otherCount)(using preErasureCtx)
578+
anonFuns(resType, memberCount - otherCount, owner)
579+
580+
def etaExpand(args: List[Tree], anonFuns: List[TermSymbol], owner: Symbol): Tree =
581+
anonFuns match
582+
case Nil =>
583+
val app = untpd.cpy.Apply(ref)(ref, args)
584+
assert(ctx.typer.isInstanceOf[Erasure.Typer])
585+
ctx.typer.typed(app, pt)
586+
// .changeOwnerAfter(origOwner, ctx.owner, erasurePhase.asInstanceOf[Erasure])
587+
//ref.appliedToTermArgs(args)//.changeOwner(ctx.owner, owner)
588+
case anonFun :: anonFuns1 =>
589+
val origType = anonFun.info
590+
anonFun.info = transformInfo(anonFun, anonFun.info)
591+
inContext(ctx.withOwner(owner)) {
592+
def lambdaBody(refss: List[List[Tree]]) =
593+
val refs :: Nil = refss: @unchecked
594+
val expandedRefs = refs.map(_.withSpan(owner.span.endPos)) match
595+
case (bunchedParam @ Ident(nme.ALLARGS)) :: Nil =>
596+
origType.firstParamTypes.indices.toList.map(n =>
597+
bunchedParam
598+
.select(nme.primitive.arrayApply)
599+
.appliedTo(Literal(Constant(n))))
600+
case refs1 => refs1
601+
etaExpand(args ::: expandedRefs, anonFuns1, anonFun)
602+
603+
val unadapted = Closure(anonFun, lambdaBody)
604+
cpy.Block(unadapted)(unadapted.stats,
605+
adaptClosure(unadapted.expr.asInstanceOf[Closure]))
606+
}
607+
608+
//println(i"forward $ref with $args%, %, owner = $owner, $pt")
609+
etaExpand(expandArgs(args, owner)(using preErasureCtx), toAbstract, owner)
610+
//.showing(i"forward $ref with $args%, %, owner = $owner, $pt = $result")
611+
end forwarder
612+
537613
/** Eta expand given `tree` that has the given method type `mt`, so that
538614
* it conforms to erased result type `pt`.
539615
* To do this correctly, we have to look at the tree's original pre-erasure

tests/run/i13691.scala

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import language.experimental.erasedDefinitions
2+
3+
erased class CanThrow[-E <: Exception]
4+
erased class Foo
5+
class Bar
6+
7+
object unsafeExceptions:
8+
given canThrowAny: CanThrow[Exception] = null
9+
10+
object test1:
11+
trait Decoder[+T]:
12+
def apply(): T
13+
14+
def deco: Decoder[CanThrow[Exception] ?=> Int] = new Decoder[CanThrow[Exception] ?=> Int]:
15+
def apply(): CanThrow[Exception] ?=> Int = 1
16+
17+
object test2:
18+
trait Decoder[+T]:
19+
def apply(): T
20+
21+
def deco: Decoder[(CanThrow[Exception], Foo) ?=> Int] = new Decoder[(CanThrow[Exception], Foo) ?=> Int]:
22+
def apply(): (CanThrow[Exception], Foo) ?=> Int = 1
23+
24+
object test3:
25+
trait Decoder[+T]:
26+
def apply(): T
27+
28+
def deco: Decoder[CanThrow[Exception] ?=> Foo ?=> Int] = new Decoder[CanThrow[Exception] ?=> Foo ?=> Int]:
29+
def apply(): CanThrow[Exception] ?=> Foo ?=> Int = 1
30+
31+
object test4:
32+
trait Decoder[+T]:
33+
def apply(): T
34+
35+
def deco: Decoder[CanThrow[Exception] ?=> Bar ?=> Int] = new Decoder[CanThrow[Exception] ?=> Bar ?=> Int]:
36+
def apply(): CanThrow[Exception] ?=> Bar ?=> Int = 1
37+
38+
object test5:
39+
trait Decoder[+T]:
40+
def apply(): T
41+
42+
def deco: Decoder[Bar ?=> CanThrow[Exception] ?=> Int] = new Decoder[Bar ?=> CanThrow[Exception] ?=> Int]:
43+
def apply(): Bar ?=> CanThrow[Exception] ?=> Int = 1
44+
45+
@main def Test(): Unit =
46+
import unsafeExceptions.canThrowAny
47+
given Foo = ???
48+
given Bar = Bar()
49+
test1.deco.apply().apply
50+
test2.deco.apply().apply
51+
test3.deco.apply().apply
52+
test4.deco.apply().apply
53+
test5.deco.apply().apply

0 commit comments

Comments
 (0)