Skip to content

Commit 1564329

Browse files
committed
Avoid inference getting stuck when the expected type contains a union/intersection
When we type a method call, we infer constraints based on its expected type before typing its arguments. This way, we can type these arguments with a precise expected type. This works fine as long as the constraints we infer based on the expected type are _necessary_ constraints, but in general type inference can go further and infer _sufficient_ constraints, meaning that we might get stuck with a set of constraints which does not allow the method arguments to be typed at all. Since 8067b95 we work around the problem by simply not propagating any constraint when the expected type is a union, but this solution is incomplete: - It only handles unions at the top-level, but the same problem can happen with unions in any covariant position (method b of or-inf.scala) as well as intersections in contravariant positions (and-inf.scala, i8378.scala) - Even when a union appear at the top-level, there might be constraints we can propagate, for example if only one branch can possibly match (method c of or-inf.scala) Thankfully, we already have a solution that works for all these problems: `TypeComparer#either` is capable of inferring only necessary constraints. So far, this was only done when inferring GADT bounds to preserve soundness, this commit extends this to use the same logic when constraining a method based on its expected type. Fixes #8378 which I previously thought was unfixable :).
1 parent 0d03af4 commit 1564329

File tree

9 files changed

+75
-23
lines changed

9 files changed

+75
-23
lines changed

compiler/src/dotty/tools/dotc/core/Mode.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ object Mode {
6060
*/
6161
val Printing: Mode = newMode(10, "Printing")
6262

63+
/** We are constraining a method based on its expected type. */
64+
val ConstrainResult: Mode = newMode(11, "ConstrainResult")
65+
6366
/** We are currently in a `viewExists` check. In that case, ambiguous
6467
* implicits checks are disabled and we succeed with the first implicit
6568
* found.

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,14 +1364,25 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
13641364

13651365
/** Returns true iff the result of evaluating either `op1` or `op2` is true and approximates resulting constraints.
13661366
*
1367-
* If we're _not_ in GADTFlexible mode, we try to keep the smaller of the two constraints.
1368-
* If we're _in_ GADTFlexible mode, we keep the smaller constraint if any, or no constraint at all.
1367+
* If we're inferring GADT bounds or constraining a method based on its
1368+
* expected type, we infer only the _necessary_ constraints, this means we
1369+
* keep the smaller constraint if any, or no constraint at all. This is
1370+
* necessary for GADT bounds inference to be sound. When constraining a
1371+
* method, this avoid painting ourselves into a corner we cannot backtrack
1372+
* out of, see or-inf.scala and and-inf.scala for examples.
13691373
*
1374+
* Otherwise, we infer _sufficient_ constraints: we try to keep the smaller of
1375+
* the two constraints, but if never is smaller than the other, we just pick
1376+
* the first one.
1377+
*
1378+
* @see [[necessaryEither]] for the GADT / result type case
13701379
* @see [[sufficientEither]] for the normal case
1371-
* @see [[necessaryEither]] for the GADTFlexible case
13721380
*/
13731381
protected def either(op1: => Boolean, op2: => Boolean): Boolean =
1374-
if (ctx.mode.is(Mode.GadtConstraintInference)) necessaryEither(op1, op2) else sufficientEither(op1, op2)
1382+
if ctx.mode.is(Mode.GadtConstraintInference) || ctx.mode.is(Mode.ConstrainResult) then
1383+
necessaryEither(op1, op2)
1384+
else
1385+
sufficientEither(op1, op2)
13751386

13761387
/** Returns true iff the result of evaluating either `op1` or `op2` is true,
13771388
* trying at the same time to keep the constraint as wide as possible.
@@ -1438,8 +1449,8 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
14381449
* T1 & T2 <:< T3
14391450
* T1 <:< T2 | T3
14401451
*
1441-
* Unlike [[sufficientEither]], this method is used in GADTFlexible mode, when we are attempting to infer GADT
1442-
* constraints that necessarily follow from the subtyping relationship. For instance, if we have
1452+
* Unlike [[sufficientEither]], this method is used in GADTConstraintInference mode, when we are attempting
1453+
* to infer GADT constraints that necessarily follow from the subtyping relationship. For instance, if we have
14431454
*
14441455
* enum Expr[T] {
14451456
* case IntExpr(i: Int) extends Expr[Int]
@@ -1466,6 +1477,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
14661477
*
14671478
* then the necessary constraint is { A = Int }, but correctly inferring that is, as far as we know, too expensive.
14681479
*
1480+
* This method is also used in ConstrainResult mode to avoid inference getting stuck due to the lack of backtracking,
1481+
* see or-inf.scala and and-inf.scala for examples.
1482+
*
14691483
* Method name comes from the notion that we are keeping the constraint which is necessary to satisfy both
14701484
* subtyping relationships.
14711485
*/

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,14 @@ object ProtoTypes {
5959
else ctx.test(testCompat)
6060
}
6161

62-
private def disregardProto(pt: Type)(implicit ctx: Context): Boolean = pt.dealias match {
63-
case _: OrType => true
64-
// Don't constrain results with union types, since comparison with a union
65-
// type on the right might commit too early into one side.
66-
case pt => pt.isRef(defn.UnitClass)
67-
}
62+
private def disregardProto(pt: Type)(implicit ctx: Context): Boolean =
63+
pt.dealias.isRef(defn.UnitClass)
6864

6965
/** Check that the result type of the current method
7066
* fits the given expected result type.
7167
*/
72-
def constrainResult(mt: Type, pt: Type)(implicit ctx: Context): Boolean = {
68+
def constrainResult(mt: Type, pt: Type)(implicit parentCtx: Context): Boolean = {
69+
given ctx as Context = parentCtx.addMode(Mode.ConstrainResult)
7370
val savedConstraint = ctx.typerState.constraint
7471
val res = pt.widenExpr match {
7572
case pt: FunProto =>

tests/neg/union.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ object O {
1717

1818
val x: A = f(new A { }, new A)
1919

20-
val y1: A | B = f(new A { }, new B) // error
20+
val y1: A | B = f(new A { }, new B) // ok
2121
val y2: A | B = f[A | B](new A { }, new B) // ok
2222

2323
val z = if (???) new A{} else new B

tests/pos/and-inf.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
class A
2+
class B
3+
4+
class Inv[T]
5+
class Contra[-T]
6+
7+
class Test {
8+
def foo[T, S](x: T, y: S): Contra[Inv[T] & Inv[S]] = ???
9+
val a: A = new A
10+
val b: B = new B
11+
12+
val x: Contra[Inv[A] & Inv[B]] = foo(a, b)
13+
}

tests/neg/i6565.scala renamed to tests/pos/i6565.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ def (o: Lifted[O]) flatMap [O, U] (f: O => Lifted[U]): Lifted[U] = ???
99
val error: Err = Err()
1010

1111
lazy val ok: Lifted[String] = { // ok despite map returning a union
12-
point("a").map(_ => if true then "foo" else error) // error
12+
point("a").map(_ => if true then "foo" else error) // ok
1313
}
1414

1515
lazy val bad: Lifted[String] = { // found Lifted[Object]
16-
point("a").flatMap(_ => point("b").map(_ => if true then "foo" else error)) // error
17-
}
16+
point("a").flatMap(_ => point("b").map(_ => if true then "foo" else error)) // ok
17+
}

tests/pos/i8378.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
trait Has[A]
2+
3+
trait A
4+
trait B
5+
trait C
6+
7+
trait ZLayer[-RIn, +E, +ROut]
8+
9+
object ZLayer {
10+
def fromServices[A0, A1, B](f: (A0, A1) => B): ZLayer[Has[A0] with Has[A1], Nothing, Has[B]] =
11+
???
12+
}
13+
14+
val live: ZLayer[Has[A] & Has[B], Nothing, Has[C]] =
15+
ZLayer.fromServices { (a: A, b: B) =>
16+
new C {}
17+
}

tests/pos/or-inf.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
object Test {
2+
3+
def a(lis: Set[Int] | Set[String]) = {}
4+
a(Set(1))
5+
a(Set(""))
6+
7+
def b(lis: List[Set[Int] | Set[String]]) = {}
8+
b(List(Set(1)))
9+
b(List(Set("")))
10+
11+
def c(x: Set[Any] | Array[Any]) = {}
12+
c(Set(1))
13+
c(Array(1))
14+
}

tests/pos/orinf.scala

Lines changed: 0 additions & 6 deletions
This file was deleted.

0 commit comments

Comments
 (0)