Skip to content

Commit 975739a

Browse files
committed
Auto-uncurry n-ary functions.
Implements SIP #897.
1 parent 945334c commit 975739a

File tree

5 files changed

+87
-19
lines changed

5 files changed

+87
-19
lines changed

src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,25 @@ object desugar {
580580
Function(params, Match(selector, cases))
581581
}
582582

583+
/** Map n-ary function `(p1, ..., pn) => body` where n != 1 to unary function as follows:
584+
*
585+
* x$1 => {
586+
* val p1 = x$1._1
587+
* ...
588+
* val pn = x$1._n
589+
* body
590+
* }
591+
*/
592+
def makeUnaryCaseLambda(params: List[ValDef], body: Tree)(implicit ctx: Context): Tree = {
593+
val param = makeSyntheticParameter()
594+
def selector(n: Int) = Select(refOfDef(param), nme.selectorName(n))
595+
val vdefs =
596+
params.zipWithIndex.map{
597+
case(param, idx) => cpy.ValDef(param)(rhs = selector(idx))
598+
}
599+
Function(param :: Nil, Block(vdefs, body))
600+
}
601+
583602
/** Add annotation with class `cls` to tree:
584603
* tree @cls
585604
*/

src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -583,26 +583,44 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
583583
if (protoFormals.length == params.length) protoFormals(i)
584584
else errorType(i"wrong number of parameters, expected: ${protoFormals.length}", tree.pos)
585585

586-
val inferredParams: List[untpd.ValDef] =
587-
for ((param, i) <- params.zipWithIndex) yield
588-
if (!param.tpt.isEmpty) param
589-
else cpy.ValDef(param)(
590-
tpt = untpd.TypeTree(
591-
inferredParamType(param, protoFormal(i)).underlyingIfRepeated(isJava = false)))
592-
593-
// Define result type of closure as the expected type, thereby pushing
594-
// down any implicit searches. We do this even if the expected type is not fully
595-
// defined, which is a bit of a hack. But it's needed to make the following work
596-
// (see typers.scala and printers/PlainPrinter.scala for examples).
597-
//
598-
// def double(x: Char): String = s"$x$x"
599-
// "abc" flatMap double
600-
//
601-
val resultTpt = protoResult match {
602-
case WildcardType(_) => untpd.TypeTree()
603-
case _ => untpd.TypeTree(protoResult)
586+
/** Is `formal` a product type which is elementwise compatible with `params`? */
587+
def ptIsCorrectProduct(formal: Type) = {
588+
val pclass = defn.ProductNClass(params.length)
589+
isFullyDefined(formal, ForceDegree.noBottom) &&
590+
formal.derivesFrom(pclass) &&
591+
formal.baseArgTypes(pclass).corresponds(params) {
592+
(argType, param) =>
593+
param.tpt.isEmpty || isCompatible(argType, typedAheadType(param.tpt).tpe)
594+
}
604595
}
605-
typed(desugar.makeClosure(inferredParams, fnBody, resultTpt), pt)
596+
597+
val desugared =
598+
if (protoFormals.length == 1 && params.length != 1 && ptIsCorrectProduct(protoFormals.head)) {
599+
desugar.makeUnaryCaseLambda(params, fnBody)
600+
}
601+
else {
602+
val inferredParams: List[untpd.ValDef] =
603+
for ((param, i) <- params.zipWithIndex) yield
604+
if (!param.tpt.isEmpty) param
605+
else cpy.ValDef(param)(
606+
tpt = untpd.TypeTree(
607+
inferredParamType(param, protoFormal(i)).underlyingIfRepeated(isJava = false)))
608+
609+
// Define result type of closure as the expected type, thereby pushing
610+
// down any implicit searches. We do this even if the expected type is not fully
611+
// defined, which is a bit of a hack. But it's needed to make the following work
612+
// (see typers.scala and printers/PlainPrinter.scala for examples).
613+
//
614+
// def double(x: Char): String = s"$x$x"
615+
// "abc" flatMap double
616+
//
617+
val resultTpt = protoResult match {
618+
case WildcardType(_) => untpd.TypeTree()
619+
case _ => untpd.TypeTree(protoResult)
620+
}
621+
desugar.makeClosure(inferredParams, fnBody, resultTpt)
622+
}
623+
typed(desugared, pt)
606624
}
607625
}
608626

test/dotc/tests.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ class tests extends CompilerTest {
109109
@Test def neg_abstractOverride() = compileFile(negDir, "abstract-override", xerrors = 2)
110110
@Test def neg_blockescapes() = compileFile(negDir, "blockescapesNeg", xerrors = 1)
111111
@Test def neg_bounds() = compileFile(negDir, "bounds", xerrors = 2)
112+
@Test def neg_functionArity() = compileFile(negDir, "function-arity", xerrors = 5)
112113
@Test def neg_typedapply() = compileFile(negDir, "typedapply", xerrors = 4)
113114
@Test def neg_typedIdents() = compileDir(negDir, "typedIdents", xerrors = 2)
114115
@Test def neg_assignments() = compileFile(negDir, "assignments", xerrors = 3)

tests/neg/function-arity.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
object Test {
2+
3+
// From #873:
4+
5+
trait X extends Function1[Int, String]
6+
implicit def f2x(f: Function1[Int, String]): X = ???
7+
({case _ if "".isEmpty => 0} : X) // error: expected String, found Int
8+
9+
// Tests where parameter list cannot be made into a pattern
10+
11+
def unary[T](x: T => Unit) = ???
12+
unary((x, y) => ()) // error
13+
14+
unary[(Int, Int)]((x, y) => ())
15+
16+
unary[(Int, Int)](() => ()) // error
17+
unary[(Int, Int)]((x, y, _) => ()) // error
18+
19+
unary[(Int, Int)]((x: String, y) => ()) // error
20+
21+
22+
}

tests/pos/i873.scala renamed to tests/pos/function-arity.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,12 @@ object Test {
77
({case _ if "".isEmpty => ""} : X) // allowed, implicit view used to adapt
88

99
// ({case _ if "".isEmpty => 0} : X) // expected String, found Int
10+
11+
def unary[T](a: T, b: T, f: ((T, T)) => T): T = f((a, b))
12+
unary(1, 2, (x, y) => x)
13+
unary(1, 2, (x: Int, y) => x)
14+
unary(1, 2, (x: Int, y: Float) => x)
15+
16+
val xs = List(1, 2, 3)
17+
xs.zipWithIndex.map(_ + _)
1018
}

0 commit comments

Comments
 (0)