Skip to content

Commit a1a89b6

Browse files
committed
Fix #7893: Less eager forcing of expected types for functions
Don't force expected parameter types if the callee type of a closure is known. In this case, we have info info to fill in parameter types from the callee.
1 parent 255a538 commit a1a89b6

File tree

2 files changed

+39
-31
lines changed

2 files changed

+39
-31
lines changed

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -955,35 +955,19 @@ class Typer extends Namer
955955
case _ => false
956956
}
957957

958-
pt match {
959-
case pt: TypeVar if untpd.isFunctionWithUnknownParamType(tree) =>
960-
// try to instantiate `pt` if this is possible. If it does not
961-
// work the error will be reported later in `inferredParam`,
962-
// when we try to infer the parameter type.
963-
isFullyDefined(pt, ForceDegree.noBottom)
964-
case _ =>
965-
}
966-
967-
val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length)
958+
/** The function body to be returned in the closure. Can become a TypedSplice
959+
* of a typed expression if this is necessary to infer a parameter type.
960+
*/
961+
var fnBody = tree.body
968962

969963
def refersTo(arg: untpd.Tree, param: untpd.ValDef): Boolean = arg match {
970964
case Ident(name) => name == param.name
971965
case _ => false
972966
}
973967

974-
/** The function body to be returned in the closure. Can become a TypedSplice
975-
* of a typed expression if this is necessary to infer a parameter type.
976-
*/
977-
var fnBody = tree.body
978-
979-
/** A map from parameter names to unique positions where the parameter
980-
* appears in the argument list of an application.
981-
*/
982-
var paramIndex = Map[Name, Int]()
983-
984968
/** If parameter `param` appears exactly once as an argument in `args`,
985-
* the singleton list consisting of its position in `args`, otherwise `Nil`.
986-
*/
969+
* the singleton list consisting of its position in `args`, otherwise `Nil`.
970+
*/
987971
def paramIndices(param: untpd.ValDef, args: List[untpd.Tree]): List[Int] = {
988972
def loop(args: List[untpd.Tree], start: Int): List[Int] = args match {
989973
case arg :: args1 =>
@@ -995,15 +979,20 @@ class Typer extends Namer
995979
if (allIndices.length == 1) allIndices else Nil
996980
}
997981

998-
/** If function is of the form
999-
* (x1, ..., xN) => f(... x1, ..., XN, ...)
1000-
* where each `xi` occurs exactly once in the argument list of `f` (in
1001-
* any order), the type of `f`, otherwise NoType.
1002-
* Updates `fnBody` and `paramIndex` as a side effect.
1003-
* @post: If result exists, `paramIndex` is defined for the name of
1004-
* every parameter in `params`.
1005-
*/
1006-
def calleeType: Type = fnBody match {
982+
/** A map from parameter names to unique positions where the parameter
983+
* appears in the argument list of an application.
984+
*/
985+
var paramIndex = Map[Name, Int]()
986+
987+
/** If function is of the form
988+
* (x1, ..., xN) => f(... x1, ..., XN, ...)
989+
* where each `xi` occurs exactly once in the argument list of `f` (in
990+
* any order), the type of `f`, otherwise NoType.
991+
* Updates `fnBody` and `paramIndex` as a side effect.
992+
* @post: If result exists, `paramIndex` is defined for the name of
993+
* every parameter in `params`.
994+
*/
995+
lazy val calleeType: Type = fnBody match {
1007996
case app @ Apply(expr, args) =>
1008997
paramIndex = {
1009998
for (param <- params; idx <- paramIndices(param, args))
@@ -1025,6 +1014,18 @@ class Typer extends Namer
10251014
NoType
10261015
}
10271016

1017+
pt match {
1018+
case pt: TypeVar
1019+
if untpd.isFunctionWithUnknownParamType(tree) && !calleeType.exists =>
1020+
// try to instantiate `pt` if this is possible. If it does not
1021+
// work the error will be reported later in `inferredParam`,
1022+
// when we try to infer the parameter type.
1023+
isFullyDefined(pt, ForceDegree.noBottom)
1024+
case _ =>
1025+
}
1026+
1027+
val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length)
1028+
10281029
/** Two attempts: First, if expected type is fully defined pick this one.
10291030
* Second, if function is of the form
10301031
* (x1, ..., xN) => f(... x1, ..., XN, ...)

tests/pos/i7893.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
object Test {
2+
val l1 = List(Predef.identity[Int](_))
3+
val lc1: List[Int => Int] = l1
4+
5+
val l2 = List(Predef.identity[Int](_))
6+
val lc2: List[Int => Int] = l2
7+
}

0 commit comments

Comments
 (0)