Skip to content

Commit b41bebd

Browse files
committed
Fix eta expand at erasure to take erased CFTs into account
1 parent 3c20740 commit b41bebd

File tree

12 files changed

+154
-99
lines changed

12 files changed

+154
-99
lines changed

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
121121
Closure(Nil, call, targetTpt))
122122
}
123123

124-
/** A closure whole anonymous function has the given method type */
124+
/** A closure whose anonymous function has the given method type */
125125
def Lambda(tpe: MethodType, rhsFn: List[Tree] => Tree)(using Context): Block = {
126-
val meth = newSymbol(ctx.owner, nme.ANON_FUN, Synthetic | Method, tpe)
126+
val meth = newAnonFun(ctx.owner, tpe)
127127
Closure(meth, tss => rhsFn(tss.head).changeOwner(ctx.owner, meth))
128128
}
129129

@@ -1104,6 +1104,21 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
11041104
if (sym.exists) sym.defTree = tree
11051105
tree
11061106
}
1107+
1108+
def etaExpandCFT(using Context): Tree =
1109+
def expand(target: Tree, tp: Type)(using Context): Tree = tp match
1110+
case defn.ContextFunctionType(argTypes, resType, isErased) =>
1111+
val anonFun = newAnonFun(
1112+
ctx.owner,
1113+
MethodType.companion(isContextual = true, isErased = isErased)(argTypes, resType),
1114+
coord = ctx.owner.coord)
1115+
def lambdaBody(refss: List[List[Tree]]) =
1116+
expand(target.select(nme.apply).appliedToArgss(refss), resType)(
1117+
using ctx.withOwner(anonFun))
1118+
Closure(anonFun, lambdaBody)
1119+
case _ =>
1120+
target
1121+
expand(tree, tree.tpe.widen)
11071122
}
11081123

11091124
inline val MapRecursionLimit = 10

compiler/src/dotty/tools/dotc/core/Symbols.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,10 @@ object Symbols {
712712
coord: Coord = NoCoord)(using Context): TermSymbol =
713713
newSymbol(cls, nme.CONSTRUCTOR, flags | Method, MethodType(paramNames, paramTypes, cls.typeRef), privateWithin, coord)
714714

715+
/** Create an anonymous function symbol */
716+
def newAnonFun(owner: Symbol, info: Type, coord: Coord = NoCoord)(using Context): TermSymbol =
717+
newSymbol(owner, nme.ANON_FUN, Synthetic | Method, info, coord = coord)
718+
715719
/** Create an empty default constructor symbol for given class `cls`. */
716720
def newDefaultConstructor(cls: ClassSymbol)(using Context): TermSymbol =
717721
newConstructor(cls, EmptyFlags, Nil, Nil)

compiler/src/dotty/tools/dotc/transform/AccessProxies.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ abstract class AccessProxies {
5151
forwardedArgss.nonEmpty && forwardedArgss.head.nonEmpty) // defensive conditions
5252
accessRef.becomes(forwardedArgss.head.head)
5353
else
54-
accessRef.appliedToTypeTrees(forwardedTpts).appliedToArgss(forwardedArgss)
54+
accessRef
55+
.appliedToTypeTrees(forwardedTpts)
56+
.appliedToArgss(forwardedArgss)
57+
.etaExpandCFT(using ctx.withOwner(accessor))
5558
rhs.withSpan(accessed.span)
5659
})
5760

compiler/src/dotty/tools/dotc/transform/Bridges.scala

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ import ast.untpd
99
import collection.{mutable, immutable}
1010
import util.Spans.Span
1111
import util.SrcPos
12+
import ContextFunctionResults.{contextResultCount, contextFunctionResultTypeAfter}
13+
import StdNames.nme
14+
import Constants.Constant
15+
import TypeErasure.transformInfo
16+
import Erasure.Boxing.adaptClosure
1217

1318
/** A helper class for generating bridge methods in class `root`. */
1419
class Bridges(root: ClassSymbol, thisPhase: DenotTransformer)(using Context) {
@@ -112,12 +117,52 @@ class Bridges(root: ClassSymbol, thisPhase: DenotTransformer)(using Context) {
112117
toBeRemoved += other
113118
}
114119

115-
def bridgeRhs(argss: List[List[Tree]]) = {
120+
val memberCount = contextResultCount(member)
121+
122+
/** Eta expand application `ref(args)` as needed.
123+
* To do this correctly, we have to look at the member's original pre-erasure
124+
* type and figure out which context function types in its result are
125+
* not yet instantiated.
126+
*/
127+
def etaExpand(ref: Tree, args: List[Tree])(using Context): Tree =
128+
def expand(args: List[Tree], tp: Type, n: Int)(using Context): Tree =
129+
if n <= 0 then
130+
assert(ctx.typer.isInstanceOf[Erasure.Typer])
131+
ctx.typer.typed(untpd.cpy.Apply(ref)(ref, args), member.info.finalResultType)
132+
else
133+
val defn.ContextFunctionType(argTypes, resType, isErased) = tp: @unchecked
134+
val anonFun = newAnonFun(ctx.owner,
135+
MethodType(if isErased then Nil else argTypes, resType),
136+
coord = ctx.owner.coord)
137+
anonFun.info = transformInfo(anonFun, anonFun.info)
138+
139+
def lambdaBody(refss: List[List[Tree]]) =
140+
val refs :: Nil = refss: @unchecked
141+
val expandedRefs = refs.map(_.withSpan(ctx.owner.span.endPos)) match
142+
case (bunchedParam @ Ident(nme.ALLARGS)) :: Nil =>
143+
argTypes.indices.toList.map(n =>
144+
bunchedParam
145+
.select(nme.primitive.arrayApply)
146+
.appliedTo(Literal(Constant(n))))
147+
case refs1 => refs1
148+
expand(args ::: expandedRefs, resType, n - 1)(using ctx.withOwner(anonFun))
149+
150+
val unadapted = Closure(anonFun, lambdaBody)
151+
cpy.Block(unadapted)(unadapted.stats,
152+
adaptClosure(unadapted.expr.asInstanceOf[Closure]))
153+
end expand
154+
155+
val otherCount = contextResultCount(other)
156+
val start = contextFunctionResultTypeAfter(member, otherCount)(using preErasureCtx)
157+
expand(args, start, memberCount - otherCount)(using ctx.withOwner(bridge))
158+
end etaExpand
159+
160+
def bridgeRhs(argss: List[List[Tree]]) =
116161
assert(argss.tail.isEmpty)
117162
val ref = This(root).select(member)
118-
if (member.info.isParameterless) ref // can happen if `member` is a module
119-
else Erasure.partialApply(ref, argss.head)
120-
}
163+
if member.info.isParameterless then ref // can happen if `member` is a module
164+
else if memberCount == 0 then ref.appliedToTermArgs(argss.head)
165+
else etaExpand(ref, argss.head)
121166

122167
bridges += DefDef(bridge, bridgeRhs(_).withSpan(bridge.span))
123168
}

compiler/src/dotty/tools/dotc/transform/ByNameClosures.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,9 @@ class ByNameClosures extends TransformByNameApply with IdentityDenotTransformer
3030
// ExpanSAMs applied to partial functions creates methods that need
3131
// to be fully defined before converting. Test case is pos/i9391.scala.
3232

33-
override def mkByNameClosure(arg: Tree, argType: Type)(using Context): Tree = {
34-
val meth = newSymbol(
35-
ctx.owner, nme.ANON_FUN, Synthetic | Method, MethodType(Nil, Nil, argType))
33+
override def mkByNameClosure(arg: Tree, argType: Type)(using Context): Tree =
34+
val meth = newAnonFun(ctx.owner, MethodType(Nil, argType))
3635
Closure(meth, _ => arg.changeOwnerAfter(ctx.owner, meth, thisPhase)).withSpan(arg.span)
37-
}
3836
}
3937

4038
object ByNameClosures {

compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -99,33 +99,13 @@ object ContextFunctionResults:
9999
normalParamCount(sym.info)
100100
end totalParamCount
101101

102-
/** The rightmost context function type in the result type of `meth`
103-
* that represents `paramCount` curried, non-erased parameters that
104-
* are included in the `contextResultCount` of `meth`.
105-
* Example:
106-
*
107-
* Say we have `def m(x: A): B ?=> (C1, C2, C3) ?=> D ?=> E ?=> F`,
108-
* paramCount == 4, and the contextResultCount of `m` is 3.
109-
* Then we return the type `(C1, C2, C3) ?=> D ?=> E ?=> F`, since this
110-
* type covers the 4 rightmost parameters C1, C2, C3 and D before the
111-
* contextResultCount runs out at E ?=> F.
112-
* Erased parameters are ignored; they contribute nothing to the
113-
* parameter count.
114-
*/
115-
def contextFunctionResultTypeCovering(meth: Symbol, paramCount: Int)(using Context) =
116-
atPhase(erasurePhase) {
117-
// Recursive instances return pairs of context types and the
118-
// # of parameters they represent.
119-
def missingCR(tp: Type, crCount: Int): (Type, Int) =
120-
if crCount == 0 then (tp, 0)
121-
else
122-
val defn.ContextFunctionType(formals, resTpe, isErased) = tp: @unchecked
123-
val result @ (rt, nparams) = missingCR(resTpe, crCount - 1)
124-
assert(nparams <= paramCount)
125-
if nparams == paramCount || isErased then result
126-
else (tp, nparams + formals.length)
127-
missingCR(meth.info.finalResultType, contextResultCount(meth))._1
128-
}
102+
/** The `depth` levels nested context function type in the result type of `meth` */
103+
def contextFunctionResultTypeAfter(meth: Symbol, depth: Int)(using Context) =
104+
def recur(tp: Type, n: Int): Type =
105+
if n == 0 then tp
106+
else tp match
107+
case defn.ContextFunctionType(_, resTpe, _) => recur(resTpe, n - 1)
108+
recur(meth.info.finalResultType, depth)
129109

130110
/** Should selection `tree` be eliminated since it refers to an `apply`
131111
* node of a context function type whose parameters will end up being

compiler/src/dotty/tools/dotc/transform/Erasure.scala

Lines changed: 3 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class Erasure extends Phase with DenotTransformer {
5757
case _ => false
5858
}
5959
}
60+
6061
def erasedName =
6162
if ref.is(Flags.Method)
6263
&& contextResultsAreErased(ref.symbol)
@@ -383,8 +384,8 @@ object Erasure {
383384
case _: FunProto | AnyFunctionProto => tree
384385
case _ => tree.tpe.widen match
385386
case mt: MethodType if tree.isTerm =>
386-
if mt.paramInfos.isEmpty then adaptToType(tree.appliedToNone, pt)
387-
else etaExpand(tree, mt, pt)
387+
assert(mt.paramInfos.isEmpty)
388+
adaptToType(tree.appliedToNone, pt)
388389
case tpw =>
389390
if (pt.isInstanceOf[ProtoType] || tree.tpe <:< pt)
390391
tree
@@ -533,61 +534,6 @@ object Erasure {
533534
else
534535
tree
535536
end adaptClosure
536-
537-
/** Eta expand given `tree` that has the given method type `mt`, so that
538-
* it conforms to erased result type `pt`.
539-
* To do this correctly, we have to look at the tree's original pre-erasure
540-
* type and figure out which context function types in its result are
541-
* not yet instantiated.
542-
*/
543-
def etaExpand(tree: Tree, mt: MethodType, pt: Type)(using Context): Tree =
544-
report.log(i"eta expanding $tree")
545-
val defs = new mutable.ListBuffer[Tree]
546-
val tree1 = LiftErased.liftApp(defs, tree)
547-
val xmt = if tree.isInstanceOf[Apply] then mt else expandedMethodType(mt, tree)
548-
val targetLength = xmt.paramInfos.length
549-
val origOwner = ctx.owner
550-
551-
// The original type from which closures should be constructed
552-
val origType = contextFunctionResultTypeCovering(tree.symbol, targetLength)
553-
554-
def abstracted(args: List[Tree], tp: Type, pt: Type)(using Context): Tree =
555-
if args.length < targetLength then
556-
try
557-
val defn.ContextFunctionType(argTpes, resTpe, isErased) = tp: @unchecked
558-
if isErased then abstracted(args, resTpe, pt)
559-
else
560-
val anonFun = newSymbol(
561-
ctx.owner, nme.ANON_FUN, Flags.Synthetic | Flags.Method,
562-
MethodType(argTpes, resTpe), coord = tree.span.endPos)
563-
anonFun.info = transformInfo(anonFun, anonFun.info)
564-
def lambdaBody(refss: List[List[Tree]]) =
565-
val refs :: Nil = refss: @unchecked
566-
val expandedRefs = refs.map(_.withSpan(tree.span.endPos)) match
567-
case (bunchedParam @ Ident(nme.ALLARGS)) :: Nil =>
568-
argTpes.indices.toList.map(n =>
569-
bunchedParam
570-
.select(nme.primitive.arrayApply)
571-
.appliedTo(Literal(Constant(n))))
572-
case refs1 => refs1
573-
abstracted(args ::: expandedRefs, resTpe, anonFun.info.finalResultType)(
574-
using ctx.withOwner(anonFun))
575-
576-
val unadapted = Closure(anonFun, lambdaBody)
577-
cpy.Block(unadapted)(unadapted.stats, adaptClosure(unadapted.expr.asInstanceOf[Closure]))
578-
catch case ex: MatchError =>
579-
println(i"error while abstracting tree = $tree | mt = $mt | args = $args%, % | tp = $tp | pt = $pt")
580-
throw ex
581-
else
582-
assert(args.length == targetLength, i"wrong # args tree = $tree | args = $args%, % | mt = $mt | tree type = ${tree.tpe}")
583-
val app = untpd.cpy.Apply(tree1)(tree1, args)
584-
assert(ctx.typer.isInstanceOf[Erasure.Typer])
585-
ctx.typer.typed(app, pt)
586-
.changeOwnerAfter(origOwner, ctx.owner, erasurePhase.asInstanceOf[Erasure])
587-
588-
seq(defs.toList, abstracted(Nil, origType, pt))
589-
end etaExpand
590-
591537
end Boxing
592538

593539
class Typer(erasurePhase: DenotTransformer) extends typer.ReTyper with NoChecking {

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
7070
ref(defn.NoneModule))
7171
}
7272
val tpe = MethodType(List(nme.s))(_ => List(tp1), mth => defn.OptionClass.typeRef.appliedTo(mth.newParamRef(0) & tp2))
73-
val meth = newSymbol(ctx.owner, nme.ANON_FUN, Synthetic | Method, tpe, coord = span)
73+
val meth = newAnonFun(ctx.owner, tpe, coord = span)
7474
val typeTestType = defn.TypeTestClass.typeRef.appliedTo(List(tp1, tp2))
7575
Closure(meth, tss => body(tss.head).changeOwner(ctx.owner, meth), targetType = typeTestType).withSpan(span)
7676
case _ =>

compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ object QuoteMatcher {
212212
}
213213
val argTypes = args.map(x => x.tpe.widenTermRefExpr)
214214
val methTpe = MethodType(names)(_ => argTypes, _ => pattern.tpe)
215-
val meth = newSymbol(ctx.owner, nme.ANON_FUN, Synthetic | Method, methTpe)
215+
val meth = newAnonFun(ctx.owner, methTpe)
216216
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
217217
val argsMap = args.map(_.symbol).zip(lambdaArgss.head).toMap
218218
val body = new TreeMap {

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
385385
case t => t
386386
}
387387
val closureTpe = Types.MethodType(mtpe.paramNames, mtpe.paramInfos, closureResType)
388-
val closureMethod = dotc.core.Symbols.newSymbol(owner, nme.ANON_FUN, Synthetic | Method, closureTpe)
388+
val closureMethod = dotc.core.Symbols.newAnonFun(owner, closureTpe)
389389
tpd.Closure(closureMethod, tss => new tpd.TreeOps(self).appliedToTermArgs(tss.head).etaExpand(closureMethod))
390390
case _ => self
391391
}
@@ -793,7 +793,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
793793

794794
object Lambda extends LambdaModule:
795795
def apply(owner: Symbol, tpe: MethodType, rhsFn: (Symbol, List[Tree]) => Tree): Block =
796-
val meth = dotc.core.Symbols.newSymbol(owner, nme.ANON_FUN, Synthetic | Method, tpe)
796+
val meth = dotc.core.Symbols.newAnonFun(owner, tpe)
797797
tpd.Closure(meth, tss => xCheckMacroedOwners(xCheckMacroValidExpr(rhsFn(meth, tss.head.map(withDefaultPos))), meth))
798798

799799
def unapply(tree: Block): Option[(List[ValDef], Term)] = tree match {

tests/run/i13691.scala

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import language.experimental.erasedDefinitions
2+
3+
erased class CanThrow[-E <: Exception]
4+
erased class Foo
5+
class Bar
6+
7+
object unsafeExceptions:
8+
given canThrowAny: CanThrow[Exception] = null
9+
10+
object test1:
11+
trait Decoder[+T]:
12+
def apply(): T
13+
14+
def deco: Decoder[CanThrow[Exception] ?=> Int] = new Decoder[CanThrow[Exception] ?=> Int]:
15+
def apply(): CanThrow[Exception] ?=> Int = 1
16+
17+
object test2:
18+
trait Decoder[+T]:
19+
def apply(): T
20+
21+
def deco: Decoder[(CanThrow[Exception], Foo) ?=> Int] = new Decoder[(CanThrow[Exception], Foo) ?=> Int]:
22+
def apply(): (CanThrow[Exception], Foo) ?=> Int = 1
23+
24+
object test3:
25+
trait Decoder[+T]:
26+
def apply(): T
27+
28+
def deco: Decoder[CanThrow[Exception] ?=> Foo ?=> Int] = new Decoder[CanThrow[Exception] ?=> Foo ?=> Int]:
29+
def apply(): CanThrow[Exception] ?=> Foo ?=> Int = 1
30+
31+
object test4:
32+
trait Decoder[+T]:
33+
def apply(): T
34+
35+
def deco: Decoder[CanThrow[Exception] ?=> Bar ?=> Int] = new Decoder[CanThrow[Exception] ?=> Bar ?=> Int]:
36+
def apply(): CanThrow[Exception] ?=> Bar ?=> Int = 1
37+
38+
object test5:
39+
trait Decoder[+T]:
40+
def apply(): T
41+
42+
def deco: Decoder[Bar ?=> CanThrow[Exception] ?=> Int] = new Decoder[Bar ?=> CanThrow[Exception] ?=> Int]:
43+
def apply(): Bar ?=> CanThrow[Exception] ?=> Int = 1
44+
45+
@main def Test(): Unit =
46+
import unsafeExceptions.canThrowAny
47+
given Foo = ???
48+
given Bar = Bar()
49+
test1.deco.apply().apply
50+
test2.deco.apply().apply
51+
test3.deco.apply().apply
52+
test4.deco.apply().apply
53+
test5.deco.apply().apply

tests/run/i13961a.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import language.experimental.saferExceptions
2+
3+
trait Decoder[+T]:
4+
def apply(): T
5+
6+
given Decoder[Int throws Exception] = new Decoder[Int throws Exception]:
7+
def apply(): Int throws Exception = 1
8+
9+
@main def Test(): Unit =
10+
import unsafeExceptions.canThrowAny
11+
summon[Decoder[Int throws Exception]]()

0 commit comments

Comments
 (0)