Skip to content

Commit 03792ab

Browse files
committed
Avoid forcing lambda parameter types in overloading resolution
When taking the `typedArgs` of a `FunProto` we now always avoid failing on untyped parameters of lambdas. Instead we give the parameter a ? type and continue. This allows to use lambdas with untyped parameters in more situations than before.
1 parent 284d270 commit 03792ab

File tree

9 files changed

+78
-62
lines changed

9 files changed

+78
-62
lines changed

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
832832
case funRef: TermRef =>
833833
val app =
834834
if (proto.allArgTypesAreCurrent())
835-
new ApplyToTyped(tree, fun1, funRef, proto.typedArgs, pt)
835+
new ApplyToTyped(tree, fun1, funRef, proto.unforcedTypedArgs, pt)
836836
else
837837
new ApplyToUntyped(tree, fun1, funRef, proto, pt)(argCtx(tree))
838838
convertNewGenericArray(app.result)
@@ -857,7 +857,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
857857
}
858858

859859
fun1.tpe match {
860-
case err: ErrorType => cpy.Apply(tree)(fun1, proto.typedArgs).withType(err)
860+
case err: ErrorType => cpy.Apply(tree)(fun1, proto.unforcedTypedArgs).withType(err)
861861
case TryDynamicCallType => typedDynamicApply(tree, pt)
862862
case _ =>
863863
if (originalProto.isDropped) fun1
@@ -1604,7 +1604,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
16041604
if (isDetermined(alts2)) alts2
16051605
else {
16061606
pretypeArgs(alts2, pt)
1607-
narrowByTrees(alts2, pt.typedArgs, resultType)
1607+
narrowByTrees(alts2, pt.unforcedTypedArgs, resultType)
16081608
}
16091609
}
16101610

@@ -1665,7 +1665,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
16651665
else pt match {
16661666
case pt @ FunProto(_, resType: FunProto) =>
16671667
// try to narrow further with snd argument list
1668-
val advanced = advanceCandidates(pt.typedArgs.tpes)
1668+
val advanced = advanceCandidates(pt.unforcedTypedArgs.tpes)
16691669
resolveOverloaded(advanced.map(_._1), resType, Nil) // resolve with candidates where first params are stripped
16701670
.map(advanced.toMap) // map surviving result(s) back to original candidates
16711671
case _ =>
@@ -1697,40 +1697,38 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
16971697
private def pretypeArgs(alts: List[TermRef], pt: FunProto)(implicit ctx: Context): Unit = {
16981698
def recur(altFormals: List[List[Type]], args: List[untpd.Tree]): Unit = args match {
16991699
case arg :: args1 if !altFormals.exists(_.isEmpty) =>
1700-
def isUnknownParamType(t: untpd.Tree) = t match {
1701-
case ValDef(_, tpt, _) => tpt.isEmpty
1702-
case _ => false
1703-
}
1704-
val fn = untpd.functionWithUnknownParamType(arg)
1705-
if (fn.isDefined) {
1706-
def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head))
1707-
val formalsForArg: List[Type] = altFormals.map(_.head)
1708-
def argTypesOfFormal(formal: Type): List[Type] =
1709-
formal match {
1710-
case defn.FunctionOf(args, result, isImplicit, isErased) => args
1711-
case defn.PartialFunctionOf(arg, result) => arg :: Nil
1712-
case _ => Nil
1700+
untpd.functionWithUnknownParamType(arg) match {
1701+
case Some(fn) =>
1702+
def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head))
1703+
val formalsForArg: List[Type] = altFormals.map(_.head)
1704+
def argTypesOfFormal(formal: Type): List[Type] =
1705+
formal match {
1706+
case defn.FunctionOf(args, result, isImplicit, isErased) => args
1707+
case defn.PartialFunctionOf(arg, result) => arg :: Nil
1708+
case _ => Nil
1709+
}
1710+
val formalParamTypessForArg: List[List[Type]] =
1711+
formalsForArg.map(argTypesOfFormal)
1712+
if (formalParamTypessForArg.forall(_.nonEmpty) &&
1713+
isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
1714+
val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
1715+
// Given definitions above, for i = 1,...,m,
1716+
// ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
1717+
// If all p_i_k's are the same, assume the type as formal parameter
1718+
// type of the i'th parameter of the closure.
1719+
if (isUniform(ps)(_ frozen_=:= _)) ps.head
1720+
else WildcardType)
1721+
def isPartial = // we should generate a partial function for the arg
1722+
fn.isInstanceOf[untpd.Match] &&
1723+
formalsForArg.exists(_.isRef(defn.PartialFunctionClass))
1724+
val commonFormal =
1725+
if (isPartial) defn.PartialFunctionOf(commonParamTypes.head, WildcardType)
1726+
else defn.FunctionOf(commonParamTypes, WildcardType)
1727+
overload.println(i"pretype arg $arg with expected type $commonFormal")
1728+
if (commonParamTypes.forall(isFullyDefined(_, ForceDegree.noBottom)))
1729+
pt.typedArg(arg, commonFormal)(ctx.addMode(Mode.ImplicitsEnabled))
17131730
}
1714-
val formalParamTypessForArg: List[List[Type]] =
1715-
formalsForArg.map(argTypesOfFormal)
1716-
if (formalParamTypessForArg.forall(_.nonEmpty) &&
1717-
isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
1718-
val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
1719-
// Given definitions above, for i = 1,...,m,
1720-
// ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
1721-
// If all p_i_k's are the same, assume the type as formal parameter
1722-
// type of the i'th parameter of the closure.
1723-
if (isUniform(ps)(_ frozen_=:= _)) ps.head
1724-
else WildcardType)
1725-
def isPartial = // we should generate a partial function for the arg
1726-
fn.get.isInstanceOf[untpd.Match] &&
1727-
formalsForArg.exists(_.isRef(defn.PartialFunctionClass))
1728-
val commonFormal =
1729-
if (isPartial) defn.PartialFunctionOf(commonParamTypes.head, WildcardType)
1730-
else defn.FunctionOf(commonParamTypes, WildcardType)
1731-
overload.println(i"pretype arg $arg with expected type $commonFormal")
1732-
pt.typedArg(arg, commonFormal)(ctx.addMode(Mode.ImplicitsEnabled))
1733-
}
1731+
case None =>
17341732
}
17351733
recur(altFormals.map(_.tail), args1)
17361734
case _ =>

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -288,21 +288,28 @@ object ProtoTypes {
288288
private def cacheTypedArg(arg: untpd.Tree, typerFn: untpd.Tree => Tree, force: Boolean)(implicit ctx: Context): Tree = {
289289
var targ = state.typedArg(arg)
290290
if (targ == null) {
291-
if (!force && untpd.functionWithUnknownParamType(arg).isDefined)
292-
// If force = false, assume ? rather than reporting an error.
293-
// That way we don't cause a "missing parameter" error in `typerFn(arg)`
294-
targ = arg.withType(WildcardType)
295-
else {
296-
targ = typerFn(arg)
297-
if (!ctx.reporter.hasUnreportedErrors) {
298-
// FIXME: This can swallow warnings by updating the typerstate from a nested
299-
// context that gets discarded later. But we do have to update the
300-
// typerstate if there are no errors. If we also omitted the next two lines
301-
// when warning were emitted, `pos/t1756.scala` would fail when run with -feature.
302-
// It would produce an orphan type parameter for CI when pickling.
303-
state.typedArg = state.typedArg.updated(arg, targ)
304-
state.evalState = state.evalState.updated(arg, (ctx.typerState, ctx.typerState.constraint))
305-
}
291+
untpd.functionWithUnknownParamType(arg) match {
292+
case Some(untpd.Function(args, _)) if !force =>
293+
// If force = false, assume what we know about the parameter types rather than reporting an error.
294+
// That way we don't cause a "missing parameter" error in `typerFn(arg)`
295+
val paramTypes = args map {
296+
case ValDef(_, tpt, _) if !tpt.isEmpty => typer.typedType(tpt).typeOpt
297+
case _ => WildcardType
298+
}
299+
targ = arg.withType(defn.FunctionOf(paramTypes, WildcardType))
300+
case Some(_) if !force =>
301+
targ = arg.withType(WildcardType)
302+
case _ =>
303+
targ = typerFn(arg)
304+
if (!ctx.reporter.hasUnreportedErrors) {
305+
// FIXME: This can swallow warnings by updating the typerstate from a nested
306+
// context that gets discarded later. But we do have to update the
307+
// typerstate if there are no errors. If we also omitted the next two lines
308+
// when warning were emitted, `pos/t1756.scala` would fail when run with -feature.
309+
// It would produce an orphan type parameter for CI when pickling.
310+
state.typedArg = state.typedArg.updated(arg, targ)
311+
state.evalState = state.evalState.updated(arg, (ctx.typerState, ctx.typerState.constraint))
312+
}
306313
}
307314
}
308315
targ
@@ -314,7 +321,7 @@ object ProtoTypes {
314321
* with unknown parameter types - this will then cause a
315322
* "missing parameter type" error
316323
*/
317-
private def typedArgs(force: Boolean): List[Tree] =
324+
protected[this] def typedArgs(force: Boolean): List[Tree] =
318325
if (state.typedArgs.size == args.length) state.typedArgs
319326
else {
320327
val args1 = args.mapconserve(cacheTypedArg(_, typer.typed(_), force))
@@ -376,7 +383,7 @@ object ProtoTypes {
376383
derivedFunProto(args, tm(resultType), typer)
377384

378385
def fold[T](x: T, ta: TypeAccumulator[T])(implicit ctx: Context): T =
379-
ta(ta.foldOver(x, typedArgs.tpes), resultType)
386+
ta(ta.foldOver(x, unforcedTypedArgs.tpes), resultType)
380387

381388
override def deepenProto(implicit ctx: Context): FunProto = derivedFunProto(args, resultType.deepenProto, typer)
382389

@@ -390,7 +397,7 @@ object ProtoTypes {
390397
* [](args): resultType, where args are known to be typed
391398
*/
392399
class FunProtoTyped(args: List[tpd.Tree], resultType: Type)(typer: Typer, isContextual: Boolean)(implicit ctx: Context) extends FunProto(args, resultType)(typer, isContextual)(ctx) {
393-
override def typedArgs: List[tpd.Tree] = args
400+
override def typedArgs(force: Boolean): List[tpd.Tree] = args
394401
override def withContext(ctx: Context): FunProtoTyped = this
395402
}
396403

tests/neg/i1640.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
object Test extends App {
22
List(1, 2, 3) map (_ match { case x => x + 1 })
3-
List((1, 2)) x (_ match { case (x, z) => x + z }) // error // error // error
3+
List((1, 2)) x (_ match { case (x, z) => x + z }) // error
44
}

tests/neg/overloaded.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
object Test {
44
def mapX(f: Char => Char): String = ???
55
def mapX[U](f: U => U): U = ???
6-
mapX(x => x) // error: missing parameter type
6+
mapX(x => x) //OK
77

88
def foo(f: Char => Char): Unit = ???
99
def foo(f: Int => Int): String = ???
10-
foo(x => x) // error: missing parameter type
10+
foo(x => x) // error: ambiguous
1111

1212
def bar(f: (Char, Char) => Unit): Unit = ???
1313
def bar(f: Char => Unit) = ???

tests/neg/parser-stability-10.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@ def unapply(i1: Int)(i6: List[Int]): Int = {
99
} // error
1010
object i5 {
1111
import collection.mutable._
12-
try { ??? mutable { case i1(i5, i3, i4) => i5 }} // error // error
12+
try { ??? mutable { case i1(i5, i3, i4) => i5 }} // error
1313
}
1414
// error

tests/neg/t6455.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@ object O { def filter(p: Int => Boolean): O.type = this }
22

33
class Test {
44
// should not compile because we no longer rewrite withFilter => filter under -Xfuture
5-
O.withFilter(f => true) // error // error
5+
O.withFilter(f => true) // error
66
}

tests/neg/t7239.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ object Test {
1717
(implicit F0: NoImplicit): HasWithFilter = ???
1818
}
1919

20-
BrokenMethod().withFilter(_ => true) // error // error
20+
BrokenMethod().withFilter(_ => true) // error
2121
BrokenMethod().filter(_ => true) // ok
2222

2323
locally {
@@ -35,6 +35,6 @@ object Test {
3535
// `(B => Boolean)`. Only later during pickling does the
3636
// defensive check for erroneous types in the tree pick up
3737
// the problem.
38-
BrokenMethod().withFilter(x => true) // error // error
38+
BrokenMethod().withFilter(x => true) // error
3939
}
4040
}

tests/run/overloads.check

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@ ok: Funcs.foo('a') = 2
1313
ok: Funcs.foo(97) = 3
1414
ok: M1.f(3) = 11
1515
ok: M2.f(3) = 22
16+
ok: M3.f("abc", _.length) = cba
17+
ok: M3.f(2, _ + 2) = 4

tests/run/overloads.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ object M2 {
3131
def f[A](x: A) = 22;
3232
}
3333

34+
object M3 {
35+
def f(x: Int, f: Int => Int) = f(x)
36+
def f(x: String, f: String => String) = f(x)
37+
}
38+
3439
object overloads {
3540

3641
def check(what: String, actual: Any, expected: Any): Unit = {
@@ -78,6 +83,10 @@ object overloads {
7883
// val y = new scala.collection.mutable.Stack[Int];
7984
// check("M1.f(" + y +")", M1.f(y), 12);
8085
// check("M2.f(" + y +")", M2.f(y), 21);
86+
87+
check("M3.f(\"abc\", _.length)", M3.f("abc", _.reverse), "cba")
88+
check("M3.f(2, _ + 2)", M3.f(2, _ + 2), 4)
89+
8190
}
8291

8392
}

0 commit comments

Comments
 (0)