Skip to content

Commit 5b2d7bf

Browse files
committed
Better expected type for arguments of overloaded methods
`pretypeArgs` allows arguments of overloaded methods to be typed with a more precise expected type when the formal parameter types of each overload are all compatible function types, but previously this logic only kicked in for arguments which were syntactically known to be functions themselves, which means that it worked when the argument was `foo(_)` or `x => foo(x)`, but not when it was just `foo`. This commit simply removes this restriction. Fixes #10325.
1 parent 3c04f9b commit 5b2d7bf

File tree

2 files changed

+51
-32
lines changed

2 files changed

+51
-32
lines changed

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2026,38 +2026,38 @@ trait Applications extends Compatibility {
20262026
private def pretypeArgs(alts: List[TermRef], pt: FunProto)(using Context): Unit = {
20272027
def recur(altFormals: List[List[Type]], args: List[untpd.Tree]): Unit = args match {
20282028
case arg :: args1 if !altFormals.exists(_.isEmpty) =>
2029-
untpd.functionWithUnknownParamType(arg) match {
2030-
case Some(fn) =>
2031-
def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head))
2032-
val formalsForArg: List[Type] = altFormals.map(_.head)
2033-
def argTypesOfFormal(formal: Type): List[Type] =
2034-
formal match {
2035-
case defn.FunctionOf(args, result, isImplicit, isErased) => args
2036-
case defn.PartialFunctionOf(arg, result) => arg :: Nil
2037-
case _ => Nil
2038-
}
2039-
val formalParamTypessForArg: List[List[Type]] =
2040-
formalsForArg.map(argTypesOfFormal)
2041-
if (formalParamTypessForArg.forall(_.nonEmpty) &&
2042-
isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
2043-
val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
2044-
// Given definitions above, for i = 1,...,m,
2045-
// ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
2046-
// If all p_i_k's are the same, assume the type as formal parameter
2047-
// type of the i'th parameter of the closure.
2048-
if (isUniform(ps)(_ frozen_=:= _)) ps.head
2049-
else WildcardType)
2050-
def isPartial = // we should generate a partial function for the arg
2051-
fn.isInstanceOf[untpd.Match] &&
2052-
formalsForArg.exists(_.isRef(defn.PartialFunctionClass))
2053-
val commonFormal =
2054-
if (isPartial) defn.PartialFunctionOf(commonParamTypes.head, WildcardType)
2055-
else defn.FunctionOf(commonParamTypes, WildcardType)
2056-
overload.println(i"pretype arg $arg with expected type $commonFormal")
2057-
if (commonParamTypes.forall(isFullyDefined(_, ForceDegree.flipBottom)))
2058-
withMode(Mode.ImplicitsEnabled)(pt.typedArg(arg, commonFormal))
2059-
}
2060-
case None =>
2029+
def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head))
2030+
val formalsForArg: List[Type] = altFormals.map(_.head)
2031+
def argTypesOfFormal(formal: Type): List[Type] =
2032+
formal match {
2033+
case defn.FunctionOf(args, result, isImplicit, isErased) => args
2034+
case defn.PartialFunctionOf(arg, result) => arg :: Nil
2035+
case _ => Nil
2036+
}
2037+
val formalParamTypessForArg: List[List[Type]] =
2038+
formalsForArg.map(argTypesOfFormal)
2039+
if (formalParamTypessForArg.forall(_.nonEmpty) &&
2040+
isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
2041+
val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
2042+
// Given definitions above, for i = 1,...,m,
2043+
// ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
2044+
// If all p_i_k's are the same, assume the type as formal parameter
2045+
// type of the i'th parameter of the closure.
2046+
if (isUniform(ps)(_ frozen_=:= _)) ps.head
2047+
else WildcardType)
2048+
/** Should we generate a partial function for the arg ? */
2049+
def isPartial = untpd.functionWithUnknownParamType(arg) match
2050+
case Some(fn) =>
2051+
fn.isInstanceOf[untpd.Match] &&
2052+
formalsForArg.exists(_.isRef(defn.PartialFunctionClass))
2053+
case None =>
2054+
false
2055+
val commonFormal =
2056+
if (isPartial) defn.PartialFunctionOf(commonParamTypes.head, WildcardType)
2057+
else defn.FunctionOf(commonParamTypes, WildcardType)
2058+
overload.println(i"pretype arg $arg with expected type $commonFormal")
2059+
if (commonParamTypes.forall(isFullyDefined(_, ForceDegree.flipBottom)))
2060+
withMode(Mode.ImplicitsEnabled)(pt.typedArg(arg, commonFormal))
20612061
}
20622062
recur(altFormals.map(_.tail), args1)
20632063
case _ =>

tests/pos/i10325.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
object Test {
2+
def nullToNone[K, V](tuple: (K, V)): (K, Option[V]) = {
3+
val (k, v) = tuple
4+
(k, Option(v))
5+
}
6+
7+
def test: Unit = {
8+
val scalaMap: Map[String, String] = Map()
9+
10+
val a = scalaMap.map(nullToNone)
11+
val a1: Map[String, Option[String]] = a
12+
13+
val b = scalaMap.map(nullToNone(_))
14+
val b1: Map[String, Option[String]] = b
15+
16+
val c = scalaMap.map(x => nullToNone(x))
17+
val c1: Map[String, Option[String]] = c
18+
}
19+
}

0 commit comments

Comments
 (0)