Skip to content

Commit f836589

Browse files
authored
Merge pull request #13736 from dotty-staging/fix-13691-v2
Fix erased context function types, 2nd attempt
2 parents aa25df2 + 7e6e9ec commit f836589

17 files changed

+194
-133
lines changed

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
121121
Closure(Nil, call, targetTpt))
122122
}
123123

124-
/** A closure whole anonymous function has the given method type */
124+
/** A closure whose anonymous function has the given method type */
125125
def Lambda(tpe: MethodType, rhsFn: List[Tree] => Tree)(using Context): Block = {
126-
val meth = newSymbol(ctx.owner, nme.ANON_FUN, Synthetic | Method, tpe)
126+
val meth = newAnonFun(ctx.owner, tpe)
127127
Closure(meth, tss => rhsFn(tss.head).changeOwner(ctx.owner, meth))
128128
}
129129

@@ -1104,6 +1104,21 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
11041104
if (sym.exists) sym.defTree = tree
11051105
tree
11061106
}
1107+
1108+
def etaExpandCFT(using Context): Tree =
1109+
def expand(target: Tree, tp: Type)(using Context): Tree = tp match
1110+
case defn.ContextFunctionType(argTypes, resType, isErased) =>
1111+
val anonFun = newAnonFun(
1112+
ctx.owner,
1113+
MethodType.companion(isContextual = true, isErased = isErased)(argTypes, resType),
1114+
coord = ctx.owner.coord)
1115+
def lambdaBody(refss: List[List[Tree]]) =
1116+
expand(target.select(nme.apply).appliedToArgss(refss), resType)(
1117+
using ctx.withOwner(anonFun))
1118+
Closure(anonFun, lambdaBody)
1119+
case _ =>
1120+
target
1121+
expand(tree, tree.tpe.widen)
11071122
}
11081123

11091124
inline val MapRecursionLimit = 10

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,23 +1353,6 @@ class Definitions {
13531353
def isBoxedUnitClass(cls: Symbol): Boolean =
13541354
cls.isClass && (cls.owner eq ScalaRuntimePackageClass) && cls.name == tpnme.BoxedUnit
13551355

1356-
/** Returns the erased class of the function class `cls`
1357-
* - FunctionN for N > 22 becomes FunctionXXL
1358-
* - FunctionN for 22 > N >= 0 remains as FunctionN
1359-
* - ContextFunctionN for N > 22 becomes FunctionXXL
1360-
* - ContextFunctionN for N <= 22 becomes FunctionN
1361-
* - ErasedFunctionN becomes Function0
1362-
* - ImplicitErasedFunctionN becomes Function0
1363-
* - anything else becomes a NoSymbol
1364-
*/
1365-
def erasedFunctionClass(cls: Symbol): Symbol = {
1366-
val arity = scalaClassName(cls).functionArity
1367-
if (cls.name.isErasedFunction) FunctionClass(0)
1368-
else if (arity > 22) FunctionXXLClass
1369-
else if (arity >= 0) FunctionClass(arity)
1370-
else NoSymbol
1371-
}
1372-
13731356
/** Returns the erased type of the function class `cls`
13741357
* - FunctionN for N > 22 becomes FunctionXXL
13751358
* - FunctionN for 22 > N >= 0 remains as FunctionN
@@ -1379,13 +1362,12 @@ class Definitions {
13791362
* - ImplicitErasedFunctionN becomes Function0
13801363
* - anything else becomes a NoType
13811364
*/
1382-
def erasedFunctionType(cls: Symbol): Type = {
1365+
def functionTypeErasure(cls: Symbol): Type =
13831366
val arity = scalaClassName(cls).functionArity
1384-
if (cls.name.isErasedFunction) FunctionType(0)
1385-
else if (arity > 22) FunctionXXLClass.typeRef
1386-
else if (arity >= 0) FunctionType(arity)
1367+
if cls.name.isErasedFunction then FunctionType(0)
1368+
else if arity > 22 then FunctionXXLClass.typeRef
1369+
else if arity >= 0 then FunctionType(arity)
13871370
else NoType
1388-
}
13891371

13901372
val predefClassNames: Set[Name] =
13911373
Set("Predef$", "DeprecatedPredef", "LowPriorityImplicits").map(_.toTypeName.unmangleClassName)

compiler/src/dotty/tools/dotc/core/NameKinds.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ object NameKinds {
365365
val ExtMethName: SuffixNameKind = new SuffixNameKind(EXTMETH, "$extension")
366366
val ParamAccessorName: SuffixNameKind = new SuffixNameKind(PARAMACC, "$accessor")
367367
val ModuleClassName: SuffixNameKind = new SuffixNameKind(OBJECTCLASS, "$", optInfoString = "ModuleClass")
368-
val ImplMethName: SuffixNameKind = new SuffixNameKind(IMPLMETH, "$")
368+
val DirectMethName: SuffixNameKind = new SuffixNameKind(DIRECT, "$direct")
369369
val AdaptedClosureName: SuffixNameKind = new SuffixNameKind(ADAPTEDCLOSURE, "$adapted") { override def definesNewName = true }
370370
val SyntheticSetterName: SuffixNameKind = new SuffixNameKind(SETTER, "_$eq")
371371

compiler/src/dotty/tools/dotc/core/NameTags.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ object NameTags extends TastyFormat.NameTags {
2424

2525
final val ADAPTEDCLOSURE = 31 // Used in Erasure to adapt closures over primitive types.
2626

27-
final val IMPLMETH = 32 // Used to define methods in implementation classes
28-
// (can probably be removed).
27+
final val DIRECT = 32 // Used to define implementations of methods with
28+
// erased context function results that can override some
29+
// other method.
2930

3031
final val PARAMACC = 33 // Used for a private parameter alias
3132

@@ -48,7 +49,7 @@ object NameTags extends TastyFormat.NameTags {
4849
case INITIALIZER => "INITIALIZER"
4950
case FIELD => "FIELD"
5051
case EXTMETH => "EXTMETH"
51-
case IMPLMETH => "IMPLMETH"
52+
case DIRECT => "DIRECT"
5253
case PARAMACC => "PARAMACC"
5354
case ADAPTEDCLOSURE => "ADAPTEDCLOSURE"
5455
case OBJECTCLASS => "OBJECTCLASS"

compiler/src/dotty/tools/dotc/core/Symbols.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,10 @@ object Symbols {
712712
coord: Coord = NoCoord)(using Context): TermSymbol =
713713
newSymbol(cls, nme.CONSTRUCTOR, flags | Method, MethodType(paramNames, paramTypes, cls.typeRef), privateWithin, coord)
714714

715+
/** Create an anonymous function symbol */
716+
def newAnonFun(owner: Symbol, info: Type, coord: Coord = NoCoord)(using Context): TermSymbol =
717+
newSymbol(owner, nme.ANON_FUN, Synthetic | Method, info, coord = coord)
718+
715719
/** Create an empty default constructor symbol for given class `cls`. */
716720
def newDefaultConstructor(cls: ClassSymbol)(using Context): TermSymbol =
717721
newConstructor(cls, EmptyFlags, Nil, Nil)

compiler/src/dotty/tools/dotc/core/TypeErasure.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
581581
val sym = tp.symbol
582582
if (!sym.isClass) this(tp.translucentSuperType)
583583
else if (semiEraseVCs && isDerivedValueClass(sym)) eraseDerivedValueClass(tp)
584-
else if (defn.isSyntheticFunctionClass(sym)) defn.erasedFunctionType(sym)
584+
else if (defn.isSyntheticFunctionClass(sym)) defn.functionTypeErasure(sym)
585585
else eraseNormalClassRef(tp)
586586
case tp: AppliedType =>
587587
val tycon = tp.tycon
@@ -791,7 +791,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
791791
if (erasedVCRef.exists) return sigName(erasedVCRef)
792792
}
793793
if (defn.isSyntheticFunctionClass(sym))
794-
sigName(defn.erasedFunctionType(sym))
794+
sigName(defn.functionTypeErasure(sym))
795795
else
796796
val cls = normalizeClass(sym.asClass)
797797
val fullName =

compiler/src/dotty/tools/dotc/transform/AccessProxies.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ abstract class AccessProxies {
5151
forwardedArgss.nonEmpty && forwardedArgss.head.nonEmpty) // defensive conditions
5252
accessRef.becomes(forwardedArgss.head.head)
5353
else
54-
accessRef.appliedToTypeTrees(forwardedTpts).appliedToArgss(forwardedArgss)
54+
accessRef
55+
.appliedToTypeTrees(forwardedTpts)
56+
.appliedToArgss(forwardedArgss)
57+
.etaExpandCFT(using ctx.withOwner(accessor))
5558
rhs.withSpan(accessed.span)
5659
})
5760

compiler/src/dotty/tools/dotc/transform/Bridges.scala

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ import ast.untpd
99
import collection.{mutable, immutable}
1010
import util.Spans.Span
1111
import util.SrcPos
12+
import ContextFunctionResults.{contextResultCount, contextFunctionResultTypeAfter}
13+
import StdNames.nme
14+
import Constants.Constant
15+
import TypeErasure.transformInfo
16+
import Erasure.Boxing.adaptClosure
1217

1318
/** A helper class for generating bridge methods in class `root`. */
1419
class Bridges(root: ClassSymbol, thisPhase: DenotTransformer)(using Context) {
@@ -112,12 +117,52 @@ class Bridges(root: ClassSymbol, thisPhase: DenotTransformer)(using Context) {
112117
toBeRemoved += other
113118
}
114119

115-
def bridgeRhs(argss: List[List[Tree]]) = {
120+
val memberCount = contextResultCount(member)
121+
122+
/** Eta expand application `ref(args)` as needed.
123+
* To do this correctly, we have to look at the member's original pre-erasure
124+
* type and figure out which context function types in its result are
125+
* not yet instantiated.
126+
*/
127+
def etaExpand(ref: Tree, args: List[Tree])(using Context): Tree =
128+
def expand(args: List[Tree], tp: Type, n: Int)(using Context): Tree =
129+
if n <= 0 then
130+
assert(ctx.typer.isInstanceOf[Erasure.Typer])
131+
ctx.typer.typed(untpd.cpy.Apply(ref)(ref, args), member.info.finalResultType)
132+
else
133+
val defn.ContextFunctionType(argTypes, resType, isErased) = tp: @unchecked
134+
val anonFun = newAnonFun(ctx.owner,
135+
MethodType(if isErased then Nil else argTypes, resType),
136+
coord = ctx.owner.coord)
137+
anonFun.info = transformInfo(anonFun, anonFun.info)
138+
139+
def lambdaBody(refss: List[List[Tree]]) =
140+
val refs :: Nil = refss: @unchecked
141+
val expandedRefs = refs.map(_.withSpan(ctx.owner.span.endPos)) match
142+
case (bunchedParam @ Ident(nme.ALLARGS)) :: Nil =>
143+
argTypes.indices.toList.map(n =>
144+
bunchedParam
145+
.select(nme.primitive.arrayApply)
146+
.appliedTo(Literal(Constant(n))))
147+
case refs1 => refs1
148+
expand(args ::: expandedRefs, resType, n - 1)(using ctx.withOwner(anonFun))
149+
150+
val unadapted = Closure(anonFun, lambdaBody)
151+
cpy.Block(unadapted)(unadapted.stats,
152+
adaptClosure(unadapted.expr.asInstanceOf[Closure]))
153+
end expand
154+
155+
val otherCount = contextResultCount(other)
156+
val start = contextFunctionResultTypeAfter(member, otherCount)(using preErasureCtx)
157+
expand(args, start, memberCount - otherCount)(using ctx.withOwner(bridge))
158+
end etaExpand
159+
160+
def bridgeRhs(argss: List[List[Tree]]) =
116161
assert(argss.tail.isEmpty)
117162
val ref = This(root).select(member)
118-
if (member.info.isParameterless) ref // can happen if `member` is a module
119-
else Erasure.partialApply(ref, argss.head)
120-
}
163+
if member.info.isParameterless then ref // can happen if `member` is a module
164+
else if memberCount == 0 then ref.appliedToTermArgs(argss.head)
165+
else etaExpand(ref, argss.head)
121166

122167
bridges += DefDef(bridge, bridgeRhs(_).withSpan(bridge.span))
123168
}

compiler/src/dotty/tools/dotc/transform/ByNameClosures.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,9 @@ class ByNameClosures extends TransformByNameApply with IdentityDenotTransformer
3030
// ExpanSAMs applied to partial functions creates methods that need
3131
// to be fully defined before converting. Test case is pos/i9391.scala.
3232

33-
override def mkByNameClosure(arg: Tree, argType: Type)(using Context): Tree = {
34-
val meth = newSymbol(
35-
ctx.owner, nme.ANON_FUN, Synthetic | Method, MethodType(Nil, Nil, argType))
33+
override def mkByNameClosure(arg: Tree, argType: Type)(using Context): Tree =
34+
val meth = newAnonFun(ctx.owner, MethodType(Nil, argType))
3635
Closure(meth, _ => arg.changeOwnerAfter(ctx.owner, meth, thisPhase)).withSpan(arg.span)
37-
}
3836
}
3937

4038
object ByNameClosures {

compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,16 @@ import StdNames.nme
88
import ast.untpd
99
import ast.tpd._
1010
import config.Config
11+
import Decorators.*
1112

1213
object ContextFunctionResults:
1314

1415
/** Annotate methods that have context function result types directly matched by context
1516
* closures on their right-hand side. Parameters to such closures will be integrated
1617
* as additional method parameters in erasure.
18+
*
19+
* A @ContextResultCount(n) annotation means that the method's result type
20+
* consists of a string of `n` nested context closures.
1721
*/
1822
def annotateContextResults(mdef: DefDef)(using Context): Unit =
1923
def contextResultCount(rhs: Tree, tp: Type): Int = tp match
@@ -50,6 +54,15 @@ object ContextFunctionResults:
5054
crCount
5155
case none => 0
5256

57+
/** True iff `ContextResultCount` is not zero and all context functions in the result
58+
* type are erased.
59+
*/
60+
def contextResultsAreErased(sym: Symbol)(using Context): Boolean =
61+
def allErased(tp: Type): Boolean = tp.dealias match
62+
case defn.ContextFunctionType(_, resTpe, isErased) => isErased && allErased(resTpe)
63+
case _ => true
64+
contextResultCount(sym) > 0 && allErased(sym.info.finalResultType)
65+
5366
/** Turn the first `crCount` context function types in the result type of `tp`
5467
* into the curried method types.
5568
*/
@@ -86,33 +99,13 @@ object ContextFunctionResults:
8699
normalParamCount(sym.info)
87100
end totalParamCount
88101

89-
/** The rightmost context function type in the result type of `meth`
90-
* that represents `paramCount` curried, non-erased parameters that
91-
* are included in the `contextResultCount` of `meth`.
92-
* Example:
93-
*
94-
* Say we have `def m(x: A): B ?=> (C1, C2, C3) ?=> D ?=> E ?=> F`,
95-
* paramCount == 4, and the contextResultCount of `m` is 3.
96-
* Then we return the type `(C1, C2, C3) ?=> D ?=> E ?=> F`, since this
97-
* type covers the 4 rightmost parameters C1, C2, C3 and D before the
98-
* contextResultCount runs out at E ?=> F.
99-
* Erased parameters are ignored; they contribute nothing to the
100-
* parameter count.
101-
*/
102-
def contextFunctionResultTypeCovering(meth: Symbol, paramCount: Int)(using Context) =
103-
atPhase(erasurePhase) {
104-
// Recursive instances return pairs of context types and the
105-
// # of parameters they represent.
106-
def missingCR(tp: Type, crCount: Int): (Type, Int) =
107-
if crCount == 0 then (tp, 0)
108-
else
109-
val defn.ContextFunctionType(formals, resTpe, isErased) = tp: @unchecked
110-
val result @ (rt, nparams) = missingCR(resTpe, crCount - 1)
111-
assert(nparams <= paramCount)
112-
if nparams == paramCount || isErased then result
113-
else (tp, nparams + formals.length)
114-
missingCR(meth.info.finalResultType, contextResultCount(meth))._1
115-
}
102+
/** The `depth` levels nested context function type in the result type of `meth` */
103+
def contextFunctionResultTypeAfter(meth: Symbol, depth: Int)(using Context) =
104+
def recur(tp: Type, n: Int): Type =
105+
if n == 0 then tp
106+
else tp match
107+
case defn.ContextFunctionType(_, resTpe, _) => recur(resTpe, n - 1)
108+
recur(meth.info.finalResultType, depth)
116109

117110
/** Should selection `tree` be eliminated since it refers to an `apply`
118111
* node of a context function type whose parameters will end up being

0 commit comments

Comments
 (0)