Skip to content

Commit 6c164a5

Browse files
committed
Extend argument pretyping to case-closures
1 parent f37b2a1 commit 6c164a5

File tree

4 files changed

+32
-37
lines changed

4 files changed

+32
-37
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,8 @@ trait UntypedTreeInfo extends TreeInfo[Untyped] { self: Trees.Instance[Untyped]
287287
case ValDef(_, tpt, _) => tpt.isEmpty
288288
case _ => false
289289
}
290+
case Match(EmptyTree, _) =>
291+
true
290292
case _ => false
291293
}
292294

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 25 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,13 +1272,8 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
12721272
def narrowBySize(alts: List[TermRef]): List[TermRef] =
12731273
alts filter (alt => sizeFits(alt, alt.widen))
12741274

1275-
def isFunArg(arg: untpd.Tree) = arg match {
1276-
case untpd.Function(_, _) | Match(EmptyTree, _) => true
1277-
case _ => false
1278-
}
1279-
12801275
def narrowByShapes(alts: List[TermRef]): List[TermRef] = {
1281-
if (normArgs exists isFunArg)
1276+
if (normArgs exists untpd.isFunctionWithUnknownParamType)
12821277
if (hasNamedArg(args)) narrowByTrees(alts, args map treeShape, resultType)
12831278
else narrowByTypes(alts, normArgs map typeShape, resultType)
12841279
else
@@ -1358,33 +1353,31 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
13581353
case ValDef(_, tpt, _) => tpt.isEmpty
13591354
case _ => false
13601355
}
1361-
arg match {
1362-
case arg: untpd.Function if arg.args.exists(isUnknownParamType) =>
1363-
def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head))
1364-
val formalsForArg: List[Type] = altFormals.map(_.head)
1365-
// For alternatives alt_1, ..., alt_n, test whether formal types for current argument are of the form
1366-
// (p_1_1, ..., p_m_1) => r_1
1367-
// ...
1368-
// (p_1_n, ..., p_m_n) => r_n
1369-
val decomposedFormalsForArg: List[Option[(List[Type], Type, Boolean)]] =
1370-
formalsForArg.map(defn.FunctionOf.unapply)
1371-
if (decomposedFormalsForArg.forall(_.isDefined)) {
1372-
val formalParamTypessForArg: List[List[Type]] =
1373-
decomposedFormalsForArg.map(_.get._1)
1374-
if (isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
1375-
val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
1376-
// Given definitions above, for i = 1,...,m,
1377-
// ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
1378-
// If all p_i_k's are the same, assume the type as formal parameter
1379-
// type of the i'th parameter of the closure.
1380-
if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head
1381-
else WildcardType)
1382-
val commonFormal = defn.FunctionOf(commonParamTypes, WildcardType)
1383-
overload.println(i"pretype arg $arg with expected type $commonFormal")
1384-
pt.typedArg(arg, commonFormal)
1385-
}
1356+
if (untpd.isFunctionWithUnknownParamType(arg)) {
1357+
def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head))
1358+
val formalsForArg: List[Type] = altFormals.map(_.head)
1359+
// For alternatives alt_1, ..., alt_n, test whether formal types for current argument are of the form
1360+
// (p_1_1, ..., p_m_1) => r_1
1361+
// ...
1362+
// (p_1_n, ..., p_m_n) => r_n
1363+
val decomposedFormalsForArg: List[Option[(List[Type], Type, Boolean)]] =
1364+
formalsForArg.map(defn.FunctionOf.unapply)
1365+
if (decomposedFormalsForArg.forall(_.isDefined)) {
1366+
val formalParamTypessForArg: List[List[Type]] =
1367+
decomposedFormalsForArg.map(_.get._1)
1368+
if (isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
1369+
val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
1370+
// Given definitions above, for i = 1,...,m,
1371+
// ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
1372+
// If all p_i_k's are the same, assume the type as formal parameter
1373+
// type of the i'th parameter of the closure.
1374+
if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head
1375+
else WildcardType)
1376+
val commonFormal = defn.FunctionOf(commonParamTypes, WildcardType)
1377+
println(i"pretype arg $arg with expected type $commonFormal")
1378+
pt.typedArg(arg, commonFormal)
13861379
}
1387-
case _ =>
1380+
}
13881381
}
13891382
recur(altFormals.map(_.tail), args1)
13901383
case _ =>

tests/pos/inferOverloaded.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ object Test {
2828
m.map1({ case (k, v) => k - 1 }: PartialFunction[(Int, String), Int])
2929
m.map2({ case (k, v) => k - 1 }: PartialFunction[(Int, String), Int])
3030

31-
// These ones did not work before, still don't work in dotty:
32-
//m.map1 { case (k, v) => k }
33-
//val r = m.map1 { case (k, v) => (k, k*10) }
34-
//val rt: MyMap[Int, Int] = r
35-
//m.foo { case (k, v) => k - 1 }
31+
// These ones did not work before:
32+
m.map1 { case (k, v) => k }
33+
val r = m.map1 { case (k, v) => (k, k*10) }
34+
val rt: MyMap[Int, Int] = r
35+
m.foo { case (k, v) => k - 1 }
3636

3737
// Used to be ambiguous but overload resolution now favors PartialFunction
3838
def h[R](pf: Function2[Int, String, R]): Unit = ()

0 commit comments

Comments
 (0)