Skip to content

Commit e91afd4

Browse files
committed
Allow partial eta expansion
Allow partial eta expansion using the following rule: Assume we have an application (x_1, ..., x_m) => f(<arg_1, ..., arg_n>) every `x_i` occurs exactly once as argument to `f`, and typing `f` with `(?, ..., ?) => ?` as expected type (where there are n occurrences of `?` in the argument list) yields a method type `(T_1, ..., T_n)R`. Let `p` be the mapping from `1..m` to `1..n` which maps every parameter `x_i` to the position where it occurs in `<arg_1, ..., arg_n>`. In this case, if one of the parameter types `x_i` is not given and not inferred from the expected type, assume `T_p(i)` as the type of `x_i`, provided `T_p(i)` is fully defined and not a repeated type `T*` or `=>T*`. The reason for excluding repeated types `T*` is that `T*` is not a valid type for a lambda parameter.
1 parent a082b9c commit e91afd4

File tree

4 files changed

+84
-24
lines changed

4 files changed

+84
-24
lines changed

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1044,7 +1044,13 @@ object Types {
10441044
case _ => NoType
10451045
}
10461046

1047-
/** If this is a FunProto or PolyProto, WildcardType, otherwise this. */
1047+
/** If this is a repeated type, its element type, otherwise the type itself */
1048+
def repeatedToSingle(implicit ctx: Context): Type = this match {
1049+
case tp @ ExprType(tp1) => tp.derivedExprType(tp1.repeatedToSingle)
1050+
case _ => if (isRepeatedParam) this.argTypesHi.head else this
1051+
}
1052+
1053+
/** If this is a FunProto or PolyProto, WildcardType, otherwise this. */
10481054
def notApplied: Type = this
10491055

10501056
// ----- Normalizing typerefs over refined types ----------------------------

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,12 +1056,8 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
10561056
*/
10571057
def isAsSpecific(alt1: TermRef, tp1: Type, alt2: TermRef, tp2: Type): Boolean = ctx.traceIndented(i"isAsSpecific $tp1 $tp2", overload) { tp1 match {
10581058
case tp1: MethodType => // (1)
1059-
def repeatedToSingle(tp: Type): Type = tp match {
1060-
case tp @ ExprType(tp1) => tp.derivedExprType(repeatedToSingle(tp1))
1061-
case _ => if (tp.isRepeatedParam) tp.argTypesHi.head else tp
1062-
}
10631059
val formals1 =
1064-
if (tp1.isVarArgsMethod && tp2.isVarArgsMethod) tp1.paramInfos map repeatedToSingle
1060+
if (tp1.isVarArgsMethod && tp2.isVarArgsMethod) tp1.paramInfos.map(_.repeatedToSingle)
10651061
else tp1.paramInfos
10661062
isApplicable(alt2, formals1, WildcardType) ||
10671063
tp1.paramInfos.isEmpty && tp2.isInstanceOf[LambdaType]

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -724,41 +724,70 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
724724
*/
725725
var fnBody = tree.body
726726

727+
/** A map from parameter names to unique positions where the parameter
728+
* appears in the argument list of an application.
729+
*/
730+
var paramIndex = Map[Name, Int]()
731+
732+
/** If parameter `param` appears exactly once as an argument in `args`,
733+
* the singleton list consisting of its position in `args`, otherwise `Nil`.
734+
*/
735+
def paramIndices(param: untpd.ValDef, args: List[untpd.Tree], start: Int): List[Int] = args match {
736+
case arg :: args1 =>
737+
if (refersTo(arg, param))
738+
if (paramIndices(param, args1, start + 1).isEmpty) start :: Nil
739+
else Nil
740+
else paramIndices(param, args1, start + 1)
741+
case _ => Nil
742+
}
743+
727744
/** If function is of the form
728-
* (x1, ..., xN) => f(x1, ..., XN)
729-
* the type of `f`, otherwise NoType. (updates `fnBody` as a side effect).
745+
* (x1, ..., xN) => f(... x1, ..., XN, ...)
746+
* where each `xi` occurs exactly once in the argument list of `f` (in
747+
* any order), the type of `f`, otherwise NoType.
748+
* Updates `fnBody` and `paramIndex` as a side effect.
749+
* @post: If result exists, `paramIndex` is defined for the name of
750+
* every parameter in `params`.
730751
*/
731752
def calleeType: Type = fnBody match {
732-
case Apply(expr, args) if (args corresponds params)(refersTo) =>
733-
expr match {
734-
case untpd.TypedSplice(expr1) =>
735-
expr1.tpe
736-
case _ =>
737-
val protoArgs = args map (_ withType WildcardType)
738-
val callProto = FunProto(protoArgs, WildcardType, this)
739-
val expr1 = typedExpr(expr, callProto)
740-
fnBody = cpy.Apply(fnBody)(untpd.TypedSplice(expr1), args)
741-
expr1.tpe
742-
}
753+
case Apply(expr, args) =>
754+
paramIndex = {
755+
for (param <- params; idx <- paramIndices(param, args, 0))
756+
yield param.name -> idx
757+
}.toMap
758+
if (paramIndex.size == params.length)
759+
expr match {
760+
case untpd.TypedSplice(expr1) =>
761+
expr1.tpe
762+
case _ =>
763+
val protoArgs = args map (_ withType WildcardType)
764+
val callProto = FunProto(protoArgs, WildcardType, this)
765+
val expr1 = typedExpr(expr, callProto)
766+
fnBody = cpy.Apply(fnBody)(untpd.TypedSplice(expr1), args)
767+
expr1.tpe
768+
}
769+
else NoType
743770
case _ =>
744771
NoType
745772
}
746773

747774
/** Two attempts: First, if expected type is fully defined pick this one.
748775
* Second, if function is of the form
749-
* (x1, ..., xN) => f(x1, ..., XN)
750-
* and f has a method type MT, pick the corresponding parameter type in MT,
751-
* if this one is fully defined.
776+
* (x1, ..., xN) => f(... x1, ..., XN, ...)
777+
* where each `xi` occurs exactly once in the argument list of `f` (in
778+
* any order), and f has a method type MT, pick the corresponding parameter
779+
* type in MT, if this one is fully defined.
752780
* If both attempts fail, issue a "missing parameter type" error.
753781
*/
754782
def inferredParamType(param: untpd.ValDef, formal: Type): Type = {
755783
if (isFullyDefined(formal, ForceDegree.noBottom)) return formal
756784
calleeType.widen match {
757785
case mtpe: MethodType =>
758-
val pos = params indexWhere (_.name == param.name)
786+
val pos = paramIndex(param.name)
759787
if (pos < mtpe.paramInfos.length) {
760788
val ptype = mtpe.paramInfos(pos)
761-
if (isFullyDefined(ptype, ForceDegree.noBottom)) return ptype
789+
if (isFullyDefined(ptype, ForceDegree.noBottom) && !ptype.isRepeatedParam)
790+
return ptype
762791
}
763792
case _ =>
764793
}

tests/pos/i2570.scala

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
object Test {
2+
3+
def repeat(s: String, i: Int, j: Int = 22) = s * i
4+
5+
val f1 = repeat("abc", _)
6+
val f2: Int => String = f1
7+
val f3 = repeat(_, 3)
8+
val f4: String => String = f3
9+
val f5 = repeat("abc", _, _)
10+
val f6: (Int, Int) => String = f5
11+
val f7 = repeat(_, 11, _)
12+
val f8: (String, Int) => String = f7
13+
14+
def sum(x: Int, y: => Int) = x + y
15+
16+
val g1 = sum(2, _)
17+
val g2: (=> Int) => Int = g1
18+
19+
val h0: ((Int, => Int) => Int) = sum
20+
21+
def sum2(x: Int, ys: Int*) = (x /: ys)(_ + _)
22+
val h1: ((Int, Seq[Int]) => Int) = sum2
23+
24+
// Not yet:
25+
// val h1 = repeat
26+
// val h2: (String, Int, Int) = h1
27+
// val h3 = sum
28+
// val h4: (Int, => Int) = h3
29+
}

0 commit comments

Comments
 (0)