Skip to content

Commit 7dfab5f

Browse files
committed
Refine pretypeArgs
It worked more or less by accident before. Now it's more complicated, but we also have tests.
1 parent f63110e commit 7dfab5f

File tree

3 files changed

+65
-5
lines changed

3 files changed

+65
-5
lines changed

src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1151,17 +1151,38 @@ trait Applications extends Compatibility { self: Typer =>
11511151
}
11521152
arg match {
11531153
case arg: untpd.Function if arg.args.exists(isUnknownParamType) =>
1154-
val commonFormal = altFormals.map(_.head).reduceLeft(_ | _)
1155-
overload.println(i"pretype arg $arg with expected type $commonFormal")
1156-
pt.typedArg(arg, commonFormal)
1154+
def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head))
1155+
val formalsForArg: List[Type] = altFormals.map(_.head)
1156+
// For alternatives alt_1, ..., alt_n, test whether formal types for current argument are of the form
1157+
// (p_1_1, ..., p_m_1) => r_1
1158+
// ...
1159+
// (p_1_n, ..., p_m_n) => r_n
1160+
val decomposedFormalsForArg: List[Option[(List[Type], Type)]] =
1161+
formalsForArg.map(defn.FunctionOf.unapply)
1162+
if (decomposedFormalsForArg.forall(_.isDefined)) {
1163+
val formalParamTypessForArg: List[List[Type]] =
1164+
decomposedFormalsForArg.map(_.get._1)
1165+
if (isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
1166+
val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
1167+
// Given definitions above, for i = 1,...,m,
1168+
// ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
1169+
// If all p_i_k's are the same, assume the type as formal parameter
1170+
// type of the i'th parameter of the closure.
1171+
if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head
1172+
else WildcardType)
1173+
val commonFormal = defn.FunctionOf(commonParamTypes, WildcardType)
1174+
overload.println(i"pretype arg $arg with expected type $commonFormal")
1175+
pt.typedArg(arg, commonFormal)
1176+
}
1177+
}
11571178
case _ =>
11581179
}
11591180
recur(altFormals.map(_.tail), args1)
11601181
case _ =>
11611182
}
11621183
def paramTypes(alt: Type): List[Type] = alt match {
11631184
case mt: MethodType => mt.paramTypes
1164-
case mt: PolyType => paramTypes(mt.resultType).map(wildApprox(_))
1185+
case mt: PolyType => paramTypes(mt.resultType)
11651186
case _ => Nil
11661187
}
11671188
recur(alts.map(alt => paramTypes(alt.widen)), pt.args)

tests/neg/overloaded.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// testing the limits of parameter type inference
2+
3+
object Test {
4+
def mapX(f: Char => Char): String = ???
5+
def mapX[U](f: U => U): U = ???
6+
mapX(x => x) // error: missing parameter type
7+
8+
def foo(f: Char => Char): Unit = ???
9+
def foo(f: Int => Int): String = ???
10+
foo(x => x) // error: missing parameter type
11+
12+
def bar(f: (Char, Char) => Unit): Unit = ???
13+
def bar(f: Char => Unit) = ???
14+
bar((x, y) => ())
15+
bar (x => ())
16+
17+
}

tests/pos/overloaded.scala

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,27 @@ object overloaded {
2424

2525
def map(f: Char => Char): String = ???
2626
def map[U](f: Char => U): Seq[U] = ???
27-
map(x => x.toUpper)
27+
val r1 = map(x => x.toUpper)
28+
val t1: String = r1
29+
val r2 = map(x => x.toInt)
30+
val t2: Seq[Int] = r2
31+
32+
def flatMap(f: Char => String): String = ???
33+
def flatMap[U](f: Char => Seq[U]): Seq[U] = ???
34+
val r3 = flatMap(x => x.toString)
35+
val t3: String = r3
36+
val r4 = flatMap(x => List(x))
37+
val t4: Seq[Char] = r4
38+
39+
def bar(f: (Char, Char) => Unit): Unit = ???
40+
def bar(f: Char => Unit) = ???
41+
bar((x, y) => ())
42+
bar (x => ())
43+
44+
def combine(f: (Char, Int) => Int): Int = ???
45+
def combine(f: (String, Int) => String): String = ???
46+
val r5 = combine((x: Char, y) => x + y)
47+
val t5: Int = r5
48+
val r6 = combine((x: String, y) => x ++ y.toString)
49+
val t6: String = r6
2850
}

0 commit comments

Comments
 (0)