Skip to content

Commit d1ec2a1

Browse files
committed
Add syntactic sugar for polymorphic function values
Desugar the value [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N) = body }
1 parent 877167e commit d1ec2a1

File tree

3 files changed

+27
-11
lines changed

3 files changed

+27
-11
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,15 +1090,28 @@ object desugar {
10901090
}
10911091

10921092
val desugared = tree match {
1093-
case PolyFunction(targs, body) if (ctx.mode.is(Mode.Type)) =>
1094-
// Desugar [T_1, ..., T_M] -> (P_1, ..., P_N) => R
1095-
// Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1096-
val Function(vargs, resType) = body
1093+
case PolyFunction(targs, body) =>
1094+
val Function(vargs, res) = body
1095+
val polyFunctionTpt = ref(defn.PolyFunctionType)
10971096
val applyTParams = targs.asInstanceOf[List[TypeDef]]
1098-
val applyVParams = vargs.zipWithIndex.map({case (p, n) => makeSyntheticParameter(n + 1, p)})
1099-
RefinedTypeTree(ref(defn.PolyFunctionType), List(
1100-
DefDef(nme.apply, applyTParams, List(applyVParams), resType, EmptyTree)
1101-
))
1097+
if (ctx.mode.is(Mode.Type)) {
1098+
// Desugar [T_1, ..., T_M] -> (P_1, ..., P_N) => R
1099+
// Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
1100+
1101+
val applyVParams = vargs.zipWithIndex.map({case (p, n) => makeSyntheticParameter(n + 1, p)})
1102+
RefinedTypeTree(polyFunctionTpt, List(
1103+
DefDef(nme.apply, applyTParams, List(applyVParams), res, EmptyTree)
1104+
))
1105+
} else {
1106+
// Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body
1107+
// Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N) = body }
1108+
1109+
val applyVParams = vargs.asInstanceOf[List[ValDef]]
1110+
.map(varg => varg.withMods(varg.mods | SyntheticTermParam))
1111+
New(Template(emptyConstructor, List(polyFunctionTpt), EmptyValDef,
1112+
List(DefDef(nme.apply, applyTParams, List(applyVParams), TypeTree(), res))
1113+
))
1114+
}
11021115
case SymbolLit(str) =>
11031116
Literal(Constant(scala.Symbol(str)))
11041117
case Quote(expr) =>

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,6 +1194,11 @@ object Parsers {
11941194
atPos(in.skipToken()) { Return(if (isExprIntro) expr() else EmptyTree, EmptyTree) }
11951195
case FOR =>
11961196
forExpr()
1197+
case LBRACKET =>
1198+
val start = in.offset
1199+
val tparams = typeParamClause(ParamOwner.TypeParam)
1200+
assert(isIdent && in.name.toString == "->", "Expected `->`")
1201+
atPos(start, in.skipToken())(PolyFunction(tparams, expr()))
11971202
case _ =>
11981203
expr1Rest(postfixExpr(), location)
11991204
}

tests/run/polymorphic-functions.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@ object Test {
44
}
55

66
def main(args: Array[String]): Unit = {
7-
val fun = new PolyFunction {
8-
def apply[T <: AnyVal](x: List[T]): List[(T, T)] = x.map(e => (e, e))
9-
}
7+
val fun = [T <: AnyVal] -> (x: List[T]) => x.map(e => (e, e))
108

119
assert(test1(fun) == List((1, 1), (2, 2), (3, 3)))
1210
}

0 commit comments

Comments
 (0)