Skip to content

Commit 877167e

Browse files
committed
Support polymorphic function values
A polymorphic function value can be written as: new PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N): R = body } This is erased to: new FunctionN { def apply(x_1: Object, ..., x_N: Object): Object = body } Getting everything to erase correctly was tricky, the current implementation is a bit messy currently.
1 parent 203b3a9 commit 877167e

File tree

6 files changed

+95
-4
lines changed

6 files changed

+95
-4
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ class Compiler {
9999
List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements.
100100
List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types
101101
new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations
102+
new ElimPolyFunction, // Rewrite PolyFunction subclasses to FunctionN subclasses
102103
new Mixin, // Expand trait fields and trait initializers
103104
new LazyVals, // Expand lazy vals
104105
new Memoize, // Add private fields to getters and setters

compiler/src/dotty/tools/dotc/core/TypeErasure.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,10 +184,24 @@ object TypeErasure {
184184
MethodType(Nil, defn.BoxedUnitType)
185185
else if (sym.isAnonymousFunction && einfo.paramInfos.length > MaxImplementedFunctionArity)
186186
MethodType(nme.ALLARGS :: Nil, JavaArrayType(defn.ObjectType) :: Nil, einfo.resultType)
187+
else if (sym.name == nme.apply && sym.owner.derivesFrom(defn.PolyFunctionClass)) {
188+
// The erasure of `apply` in subclasses of PolyFunction has to match
189+
// the erasure of FunctionN#apply, since after `ElimPolyFunction` we replace
190+
// a `PolyFunction` parent by a `FunctionN` parent.
191+
einfo.derivedLambdaType(
192+
paramInfos = einfo.paramInfos.map(_ => defn.ObjectType),
193+
resType = defn.ObjectType
194+
)
195+
}
187196
else
188197
einfo
189198
case einfo =>
190-
einfo
199+
// Erase the parameters of `apply` in subclasses of PolyFunction
200+
if (sym.is(TermParam) && sym.owner.name == nme.apply
201+
&& sym.owner.owner.derivesFrom(defn.PolyFunctionClass))
202+
defn.ObjectType
203+
else
204+
einfo
191205
}
192206
}
193207

compiler/src/dotty/tools/dotc/transform/ElimErasedValueType.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,10 @@ class ElimErasedValueType extends MiniPhase with InfoTransformer {
8989
val info1 = site.memberInfo(sym1)
9090
val info2 = site.memberInfo(sym2)
9191
def isDefined(sym: Symbol) = sym.originDenotation.validFor.firstPhaseId <= ctx.phaseId
92-
if (isDefined(sym1) && isDefined(sym2) && !info1.matchesLoosely(info2))
92+
if (isDefined(sym1) && isDefined(sym2) && !info1.matchesLoosely(info2)
93+
&& !(sym1.name == nme.apply &&
94+
(sym1.owner.derivesFrom(defn.PolyFunctionClass)
95+
|| sym2.owner.derivesFrom(defn.PolyFunctionClass))))
9396
// The reason for the `isDefined` condition is that we need to exclude mixin forwarders
9497
// from the tests. For instance, in compileStdLib, compiling scala.immutable.SetProxy, line 29:
9598
// new AbstractSet[B] with SetProxy[B] { val self = newSelf }
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import ast.{Trees, tpd}
5+
import core._, core.Decorators._
6+
import MegaPhase._, Phases.Phase
7+
import Types._, Contexts._, Constants._, Names._, NameOps._, Flags._, DenotTransformers._
8+
import SymDenotations._, Symbols._, StdNames._, Annotations._, Trees._, Scopes._, Denotations._
9+
import TypeErasure.ErasedValueType, ValueClasses._
10+
11+
/** This phase rewrite PolyFunction subclasses to FunctionN subclasses
12+
*
13+
* class Foo extends PolyFunction {
14+
* def apply(x_1: P_1, ..., x_N: P_N): R = rhs
15+
* }
16+
* becomes:
17+
* class Foo extends FunctionN {
18+
* def apply(x_1: P_1, ..., x_N: P_N): R = rhs
19+
* }
20+
*/
21+
class ElimPolyFunction extends MiniPhase with DenotTransformer {
22+
23+
import tpd._
24+
25+
override def phaseName: String = ElimPolyFunction.name
26+
27+
override def runsAfter = Set(Erasure.name)
28+
29+
override def changesParents: Boolean = true // Replaces PolyFunction by FunctionN
30+
31+
override def transform(ref: SingleDenotation)(implicit ctx: Context) = ref match {
32+
case ref: ClassDenotation if ref.symbol != defn.PolyFunctionClass && ref.derivesFrom(defn.PolyFunctionClass) =>
33+
val cinfo = ref.classInfo
34+
val newParent = functionTypeOfPoly(cinfo)
35+
val newParents = cinfo.classParents.map(parent =>
36+
if (parent.typeSymbol == defn.PolyFunctionClass)
37+
newParent
38+
else
39+
parent
40+
)
41+
ref.copySymDenotation(info = cinfo.derivedClassInfo(classParents = newParents))
42+
case _ =>
43+
ref
44+
}
45+
46+
def functionTypeOfPoly(cinfo: ClassInfo)(implicit ctx: Context): Type = {
47+
val applyMeth = cinfo.decls.lookup(nme.apply).info
48+
val arity = applyMeth.paramNamess.head.length
49+
defn.FunctionType(arity)
50+
}
51+
52+
override def transformTemplate(tree: Template)(implicit ctx: Context): Tree = {
53+
val newParents = tree.parents.mapconserve(parent =>
54+
if (parent.tpe.typeSymbol == defn.PolyFunctionClass) {
55+
val cinfo = tree.symbol.owner.asClass.classInfo
56+
tpd.TypeTree(functionTypeOfPoly(cinfo))
57+
}
58+
else
59+
parent
60+
)
61+
cpy.Template(tree)(parents = newParents)
62+
}
63+
}
64+
65+
object ElimPolyFunction {
66+
val name = "elimPolyFunction"
67+
}
68+

compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ trait TypeAssigner {
5353
parentType.findMember(decl.name, cls.thisType, excluded = Private)
5454
.suchThat(decl.matches(_))
5555
val inheritedInfo = inherited.info
56-
if (inheritedInfo.exists && decl.info <:< inheritedInfo && !(inheritedInfo <:< decl.info)) {
56+
val isPolyFunctionApply = decl.name == nme.apply && (parent <:< defn.PolyFunctionType)
57+
if (isPolyFunctionApply || inheritedInfo.exists && decl.info <:< inheritedInfo && !(inheritedInfo <:< decl.info)) {
5758
val r = RefinedType(parent, decl.name, decl.info)
5859
typr.println(i"add ref $parent $decl --> " + r)
5960
r

tests/run/polymorphic-functions.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ object Test {
44
}
55

66
def main(args: Array[String]): Unit = {
7-
//test1(...)
7+
val fun = new PolyFunction {
8+
def apply[T <: AnyVal](x: List[T]): List[(T, T)] = x.map(e => (e, e))
9+
}
10+
11+
assert(test1(fun) == List((1, 1), (2, 2), (3, 3)))
812
}
913
}

0 commit comments

Comments
 (0)