Skip to content

Commit 182331b

Browse files
Support polymorphic functions with erased parameters (#18293)
This adds support for ```scala [T1, ..., Tn] => ([erased] x1: X1, ..., [erased] xm: Xm) => r: R ``` Polymorphic function types with erased parameters are represented as using a refinement on `PolyFunction`. `ErasedFunction` is not needed. ```scala PolyFunction { def apply[[T1, ..., Tn]]([given] [erased] x1: X1, ..., [erased] xm: Xm): R } ```
2 parents 97677cc + 5625107 commit 182331b

10 files changed

+95
-56
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,19 +1100,21 @@ object desugar {
11001100
*/
11011101
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
11021102
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
1103-
val funFlags = fun match
1103+
val paramFlags = fun match
11041104
case fun: FunctionWithMods =>
1105-
fun.mods.flags
1106-
case _ => EmptyFlags
1105+
// TODO: make use of this in the desugaring when pureFuns is enabled.
1106+
// val isImpure = funFlags.is(Impure)
11071107

1108-
// TODO: make use of this in the desugaring when pureFuns is enabled.
1109-
// val isImpure = funFlags.is(Impure)
1108+
// Function flags to be propagated to each parameter in the desugared method type.
1109+
val givenFlag = fun.mods.flags.toTermFlags & Given
1110+
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
1111+
case _ =>
1112+
vparamTypes.map(_ => EmptyFlags)
11101113

1111-
// Function flags to be propagated to each parameter in the desugared method type.
1112-
val paramFlags = funFlags.toTermFlags & Given
1113-
val vparams = vparamTypes.zipWithIndex.map:
1114-
case (p: ValDef, _) => p.withAddedFlags(paramFlags)
1115-
case (p, n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
1114+
val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
1115+
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
1116+
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
1117+
}.toList
11161118

11171119
RefinedTypeTree(ref(defn.PolyFunctionType), List(
11181120
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic)

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1426,23 +1426,6 @@ object Parsers {
14261426
case _ => None
14271427
}
14281428

1429-
private def checkFunctionNotErased(f: Function, context: String) =
1430-
def fail(span: Span) =
1431-
syntaxError(em"Implementation restriction: erased parameters are not supported in $context", span)
1432-
// erased parameter in type
1433-
val hasErasedParam = f match
1434-
case f: FunctionWithMods => f.hasErasedParams
1435-
case _ => false
1436-
if hasErasedParam then
1437-
fail(f.span)
1438-
// erased parameter in term
1439-
val hasErasedMods = f.args.collectFirst {
1440-
case v: ValDef if v.mods.is(Flags.Erased) => v
1441-
}
1442-
hasErasedMods match
1443-
case Some(param) => fail(param.span)
1444-
case _ =>
1445-
14461429
/** CaptureRef ::= ident | `this`
14471430
*/
14481431
def captureRef(): Tree =
@@ -1592,7 +1575,6 @@ object Parsers {
15921575
atSpan(start, arrowOffset) {
15931576
getFunction(body) match {
15941577
case Some(f) =>
1595-
checkFunctionNotErased(f, "poly function")
15961578
PolyFunction(tparams, body)
15971579
case None =>
15981580
syntaxError(em"Implementation restriction: polymorphic function types must have a value parameter", arrowOffset)
@@ -2159,7 +2141,6 @@ object Parsers {
21592141
atSpan(start, arrowOffset) {
21602142
getFunction(body) match
21612143
case Some(f) =>
2162-
checkFunctionNotErased(f, "poly function")
21632144
PolyFunction(tparams, f)
21642145
case None =>
21652146
syntaxError(em"Implementation restriction: polymorphic function literals must have a value parameter", arrowOffset)

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2604,17 +2604,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
26042604
report.error(em"Cannot return repeated parameter type ${sym.info.finalResultType}", sym.srcPos)
26052605
if !sym.is(Module) && !sym.isConstructor && sym.info.finalResultType.isErasedClass then
26062606
sym.setFlag(Erased)
2607-
if
2608-
sym.info.isInstanceOf[PolyType] &&
2609-
((sym.name eq nme.ANON_FUN) ||
2610-
(sym.name eq nme.apply) && sym.owner.derivesFrom(defn.PolyFunctionClass))
2611-
then
2612-
mdef match
2613-
case DefDef(_, _ :: vparams :: Nil, _, _) =>
2614-
vparams.foreach: vparam =>
2615-
if vparam.symbol.is(Erased) then
2616-
report.error(em"Implementation restriction: erased classes are not allowed in a poly function definition", vparam.srcPos)
2617-
case _ =>
26182607

26192608
def typedTypeDef(tdef: untpd.TypeDef, sym: Symbol)(using Context): Tree = {
26202609
val TypeDef(name, rhs) = tdef

tests/neg-custom-args/erased/poly-functions.scala

Lines changed: 0 additions & 16 deletions
This file was deleted.
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
-- [E007] Type Mismatch Error: tests/neg/polymorphic-erased-functions-types.scala:3:28 ---------------------------------
2+
3 |def t1a: [T] => T => Unit = [T] => (erased t: T) => () // error
3+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^
4+
| Found: [T] => (erased t: T) => Unit
5+
| Required: [T] => (x$1: T) => Unit
6+
|
7+
| longer explanation available when compiling with `-explain`
8+
-- [E007] Type Mismatch Error: tests/neg/polymorphic-erased-functions-types.scala:4:37 ---------------------------------
9+
4 |def t1b: [T] => (erased T) => Unit = [T] => (t: T) => () // error
10+
| ^^^^^^^^^^^^^^^^^^^
11+
| Found: [T] => (t: T) => Unit
12+
| Required: [T] => (erased x$1: T) => Unit
13+
|
14+
| longer explanation available when compiling with `-explain`
15+
-- [E007] Type Mismatch Error: tests/neg/polymorphic-erased-functions-types.scala:6:36 ---------------------------------
16+
6 |def t2a: [T, U] => (T, U) => Unit = [T, U] => (t: T, erased u: U) => () // error
17+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
18+
| Found: [T, U] => (t: T, erased u: U) => Unit
19+
| Required: [T, U] => (x$1: T, x$2: U) => Unit
20+
|
21+
| longer explanation available when compiling with `-explain`
22+
-- [E007] Type Mismatch Error: tests/neg/polymorphic-erased-functions-types.scala:7:43 ---------------------------------
23+
7 |def t2b: [T, U] => (T, erased U) => Unit = [T, U] => (t: T, u: U) => () // error
24+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
25+
| Found: [T, U] => (t: T, u: U) => Unit
26+
| Required: [T, U] => (x$1: T, erased x$2: U) => Unit
27+
|
28+
| longer explanation available when compiling with `-explain`
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import language.experimental.erasedDefinitions
2+
3+
def t1a: [T] => T => Unit = [T] => (erased t: T) => () // error
4+
def t1b: [T] => (erased T) => Unit = [T] => (t: T) => () // error
5+
6+
def t2a: [T, U] => (T, U) => Unit = [T, U] => (t: T, erased u: U) => () // error
7+
def t2b: [T, U] => (T, erased U) => Unit = [T, U] => (t: T, u: U) => () // error
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
-- Error: tests/neg/polymorphic-erased-functions-used.scala:3:33 -------------------------------------------------------
2+
3 |def t1 = [T] => (erased t: T) => t // error
3+
| ^
4+
| parameter t is declared as `erased`, but is in fact used
5+
-- Error: tests/neg/polymorphic-erased-functions-used.scala:4:42 -------------------------------------------------------
6+
4 |def t2 = [T, U] => (t: T, erased u: U) => u // error
7+
| ^
8+
| parameter u is declared as `erased`, but is in fact used
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import language.experimental.erasedDefinitions
2+
3+
def t1 = [T] => (erased t: T) => t // error
4+
def t2 = [T, U] => (t: T, erased u: U) => u // error

tests/pos/poly-erased-functions.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import language.experimental.erasedDefinitions
2+
3+
object Test:
4+
type T1 = [X] => (erased x: X, y: Int) => Int
5+
type T2 = [X] => (x: X, erased y: Int) => X
6+
7+
val t1 = [X] => (erased x: X, y: Int) => y
8+
val t2 = [X] => (x: X, erased y: Int) => x
9+
10+
erased class A
11+
12+
type T3 = [X] => (x: A, y: X) => X
13+
14+
val t3 = [X] => (x: A, y: X) => y
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import language.experimental.erasedDefinitions
2+
3+
object Test extends App {
4+
5+
// Types
6+
type F1 = [T] => (erased T) => Int
7+
type F2 = [T, U] => (T, erased U) => T
8+
9+
// Terms
10+
val t1 = [T] => (erased t: T) => 3
11+
assert(t1(List(1, 2, 3)) == 3)
12+
val t1a: F1 = t1
13+
val t1b: F1 = [T] => (erased t) => 3
14+
assert(t1b(List(1, 2, 3)) == 3)
15+
16+
val t2 = [T, U] => (t: T, erased u: U) => t
17+
assert(t2(1, "abc") == 1)
18+
val t2a: F2 = t2
19+
val t2b: F2 = [T, U] => (t, erased u) => t
20+
assert(t2b(1, "abc") == 1)
21+
22+
}

0 commit comments

Comments
 (0)