Skip to content

Commit 3e888f2

Browse files
authored
Merge pull request #8635 from dotty-staging/result-inf
Avoid inference getting stuck when the expected type contains a union/intersection
2 parents ca0ef84 + fe5be59 commit 3e888f2

File tree

13 files changed

+151
-58
lines changed

13 files changed

+151
-58
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -484,9 +484,10 @@ trait ConstraintHandling[AbstractContext] {
484484
* recording an isLess relationship instead (even though this is not implied
485485
* by the bound).
486486
*
487-
* Narrowing a constraint is better than widening it, because narrowing leads
488-
* to incompleteness (which we face anyway, see for instance eitherIsSubType)
489-
* but widening leads to unsoundness.
487+
* Normally, narrowing a constraint is better than widening it, because
488+
* narrowing leads to incompleteness (which we face anyway, see for
489+
* instance `TypeComparer#either`) but widening leads to unsoundness,
490+
* but note the special handling in `ConstrainResult` mode below.
490491
*
491492
* A test case that demonstrates the problem is i864.scala.
492493
* Turn Config.checkConstraintsSeparated on to get an accurate diagnostic
@@ -544,10 +545,23 @@ trait ConstraintHandling[AbstractContext] {
544545
case bound: TypeParamRef if constraint contains bound =>
545546
addParamBound(bound)
546547
case _ =>
548+
val savedConstraint = constraint
547549
val pbound = prune(bound)
548-
pbound.exists
549-
&& kindCompatible(param, pbound)
550-
&& (if fromBelow then addLowerBound(param, pbound) else addUpperBound(param, pbound))
550+
val constraintsNarrowed = constraint ne savedConstraint
551+
552+
val res =
553+
pbound.exists
554+
&& kindCompatible(param, pbound)
555+
&& (if fromBelow then addLowerBound(param, pbound) else addUpperBound(param, pbound))
556+
// If we're in `ConstrainResult` mode, we don't want to commit to a
557+
// set of constraints that would later prevent us from typechecking
558+
// arguments, so if `pruneParams` had to narrow the constraints, we
559+
// simply do not record any new constraint.
560+
// Unlike in `TypeComparer#either`, the same reasoning does not apply
561+
// to GADT mode because this code is never run on GADT constraints.
562+
if ctx.mode.is(Mode.ConstrainResult) && constraintsNarrowed then
563+
constraint = savedConstraint
564+
res
551565
}
552566
finally addConstraintInvocations -= 1
553567
}

compiler/src/dotty/tools/dotc/core/Mode.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ object Mode {
6060
*/
6161
val Printing: Mode = newMode(10, "Printing")
6262

63+
/** We are constraining a method based on its expected type. */
64+
val ConstrainResult: Mode = newMode(11, "ConstrainResult")
65+
6366
/** We are currently in a `viewExists` check. In that case, ambiguous
6467
* implicits checks are disabled and we succeed with the first implicit
6568
* found.

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 49 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,14 +1364,26 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
13641364

13651365
/** Returns true iff the result of evaluating either `op1` or `op2` is true and approximates resulting constraints.
13661366
*
1367-
* If we're _not_ in GADTFlexible mode, we try to keep the smaller of the two constraints.
1368-
* If we're _in_ GADTFlexible mode, we keep the smaller constraint if any, or no constraint at all.
1367+
* If we're inferring GADT bounds or constraining a method based on its
1368+
* expected type, we infer only the _necessary_ constraints, this means we
1369+
* keep the smaller constraint if any, or no constraint at all. This is
1370+
* necessary for GADT bounds inference to be sound. When constraining a
1371+
* method, this avoid committing of constraints that would later prevent us
1372+
* from typechecking method arguments, see or-inf.scala and and-inf.scala for
1373+
* examples.
13691374
*
1375+
* Otherwise, we infer _sufficient_ constraints: we try to keep the smaller of
1376+
* the two constraints, but if never is smaller than the other, we just pick
1377+
* the first one.
1378+
*
1379+
* @see [[necessaryEither]] for the GADT / result type case
13701380
* @see [[sufficientEither]] for the normal case
1371-
* @see [[necessaryEither]] for the GADTFlexible case
13721381
*/
13731382
protected def either(op1: => Boolean, op2: => Boolean): Boolean =
1374-
if (ctx.mode.is(Mode.GadtConstraintInference)) necessaryEither(op1, op2) else sufficientEither(op1, op2)
1383+
if ctx.mode.is(Mode.GadtConstraintInference) || ctx.mode.is(Mode.ConstrainResult) then
1384+
necessaryEither(op1, op2)
1385+
else
1386+
sufficientEither(op1, op2)
13751387

13761388
/** Returns true iff the result of evaluating either `op1` or `op2` is true,
13771389
* trying at the same time to keep the constraint as wide as possible.
@@ -1438,8 +1450,8 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
14381450
* T1 & T2 <:< T3
14391451
* T1 <:< T2 | T3
14401452
*
1441-
* Unlike [[sufficientEither]], this method is used in GADTFlexible mode, when we are attempting to infer GADT
1442-
* constraints that necessarily follow from the subtyping relationship. For instance, if we have
1453+
* Unlike [[sufficientEither]], this method is used in GADTConstraintInference mode, when we are attempting
1454+
* to infer GADT constraints that necessarily follow from the subtyping relationship. For instance, if we have
14431455
*
14441456
* enum Expr[T] {
14451457
* case IntExpr(i: Int) extends Expr[Int]
@@ -1466,48 +1478,49 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
14661478
*
14671479
* then the necessary constraint is { A = Int }, but correctly inferring that is, as far as we know, too expensive.
14681480
*
1481+
* This method is also used in ConstrainResult mode
1482+
* to avoid inference getting stuck due to lack of backtracking,
1483+
* see or-inf.scala and and-inf.scala for examples.
1484+
*
14691485
* Method name comes from the notion that we are keeping the constraint which is necessary to satisfy both
14701486
* subtyping relationships.
14711487
*/
1472-
private def necessaryEither(op1: => Boolean, op2: => Boolean): Boolean = {
1488+
private def necessaryEither(op1: => Boolean, op2: => Boolean): Boolean =
14731489
val preConstraint = constraint
1474-
14751490
val preGadt = ctx.gadt.fresh
1476-
// if GADTflexible mode is on, we expect to always have a ProperGadtConstraint
1477-
val pre = preGadt.asInstanceOf[ProperGadtConstraint]
1478-
if (op1) {
1479-
val leftConstraint = constraint
1480-
val leftGadt = ctx.gadt.fresh
1491+
1492+
def allSubsumes(leftGadt: GadtConstraint, rightGadt: GadtConstraint, left: Constraint, right: Constraint): Boolean =
1493+
subsumes(left, right, preConstraint) && preGadt.match
1494+
case preGadt: ProperGadtConstraint =>
1495+
preGadt.subsumes(leftGadt, rightGadt, preGadt)
1496+
case _ =>
1497+
true
1498+
1499+
if op1 then
1500+
val op1Constraint = constraint
1501+
val op1Gadt = ctx.gadt.fresh
14811502
constraint = preConstraint
14821503
ctx.gadt.restore(preGadt)
1483-
if (op2)
1484-
if (pre.subsumes(leftGadt, ctx.gadt, preGadt) && subsumes(leftConstraint, constraint, preConstraint)) {
1485-
gadts.println(i"GADT CUT - prefer ${ctx.gadt} over $leftGadt")
1486-
constr.println(i"CUT - prefer $constraint over $leftConstraint")
1487-
true
1488-
}
1489-
else if (pre.subsumes(ctx.gadt, leftGadt, preGadt) && subsumes(constraint, leftConstraint, preConstraint)) {
1490-
gadts.println(i"GADT CUT - prefer $leftGadt over ${ctx.gadt}")
1491-
constr.println(i"CUT - prefer $leftConstraint over $constraint")
1492-
constraint = leftConstraint
1493-
ctx.gadt.restore(leftGadt)
1494-
true
1495-
}
1496-
else {
1504+
if op2 then
1505+
if allSubsumes(op1Gadt, ctx.gadt, op1Constraint, constraint) then
1506+
gadts.println(i"GADT CUT - prefer ${ctx.gadt} over $op1Gadt")
1507+
constr.println(i"CUT - prefer $constraint over $op1Constraint")
1508+
else if allSubsumes(ctx.gadt, op1Gadt, constraint, op1Constraint) then
1509+
gadts.println(i"GADT CUT - prefer $op1Gadt over ${ctx.gadt}")
1510+
constr.println(i"CUT - prefer $op1Constraint over $constraint")
1511+
constraint = op1Constraint
1512+
ctx.gadt.restore(op1Gadt)
1513+
else
14971514
gadts.println(i"GADT CUT - no constraint is preferable, reverting to $preGadt")
14981515
constr.println(i"CUT - no constraint is preferable, reverting to $preConstraint")
14991516
constraint = preConstraint
15001517
ctx.gadt.restore(preGadt)
1501-
true
1502-
}
1503-
else {
1504-
constraint = leftConstraint
1505-
ctx.gadt.restore(leftGadt)
1506-
true
1507-
}
1508-
}
1518+
else
1519+
constraint = op1Constraint
1520+
ctx.gadt.restore(op1Gadt)
1521+
true
15091522
else op2
1510-
}
1523+
end necessaryEither
15111524

15121525
/** Does type `tp1` have a member with name `name` whose normalized type is a subtype of
15131526
* the normalized type of the refinement `tp2`?

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,14 @@ object ProtoTypes {
5959
else ctx.test(testCompat)
6060
}
6161

62-
private def disregardProto(pt: Type)(implicit ctx: Context): Boolean = pt.dealias match {
63-
case _: OrType => true
64-
// Don't constrain results with union types, since comparison with a union
65-
// type on the right might commit too early into one side.
66-
case pt => pt.isRef(defn.UnitClass)
67-
}
62+
private def disregardProto(pt: Type)(implicit ctx: Context): Boolean =
63+
pt.dealias.isRef(defn.UnitClass)
6864

6965
/** Check that the result type of the current method
7066
* fits the given expected result type.
7167
*/
72-
def constrainResult(mt: Type, pt: Type)(implicit ctx: Context): Boolean = {
68+
def constrainResult(mt: Type, pt: Type)(implicit parentCtx: Context): Boolean = {
69+
given ctx as Context = parentCtx.addMode(Mode.ConstrainResult)
7370
val savedConstraint = ctx.typerState.constraint
7471
val res = pt.widenExpr match {
7572
case pt: FunProto =>

compiler/test/dotty/tools/dotc/CompilationTests.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class CompilationTests extends ParallelTesting {
137137
compileFile("tests/neg-custom-args/i3882.scala", allowDeepSubtypes),
138138
compileFile("tests/neg-custom-args/i4372.scala", allowDeepSubtypes),
139139
compileFile("tests/neg-custom-args/i1754.scala", allowDeepSubtypes),
140+
compileFile("tests/neg-custom-args/interop-polytypes.scala", allowDeepSubtypes.and("-Yexplicit-nulls")),
140141
compileFile("tests/neg-custom-args/conditionalWarnings.scala", allowDeepSubtypes.and("-deprecation").and("-Xfatal-warnings")),
141142
compileFilesInDir("tests/neg-custom-args/isInstanceOf", allowDeepSubtypes and "-Xfatal-warnings"),
142143
compileFile("tests/neg-custom-args/i3627.scala", allowDeepSubtypes),

tests/neg/i6565.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ def (o: Lifted[O]) flatMap [O, U] (f: O => Lifted[U]): Lifted[U] = ???
99
val error: Err = Err()
1010

1111
lazy val ok: Lifted[String] = { // ok despite map returning a union
12-
point("a").map(_ => if true then "foo" else error) // error
12+
point("a").map(_ => if true then "foo" else error) // ok
1313
}
1414

1515
lazy val bad: Lifted[String] = { // found Lifted[Object]
1616
point("a").flatMap(_ => point("b").map(_ => if true then "foo" else error)) // error
17-
}
17+
}

tests/neg/union.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ object O {
1717

1818
val x: A = f(new A { }, new A)
1919

20-
val y1: A | B = f(new A { }, new B) // error
20+
val y1: A | B = f(new A { }, new B) // ok
2121
val y2: A | B = f[A | B](new A { }, new B) // ok
2222

2323
val z = if (???) new A{} else new B

tests/pos/and-inf.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
class A
2+
class B
3+
4+
class Inv[T]
5+
class Contra[-T]
6+
7+
class Test {
8+
def foo[T, S](x: T, y: S): Contra[Inv[T] & Inv[S]] = ???
9+
val a: A = new A
10+
val b: B = new B
11+
12+
val x: Contra[Inv[A] & Inv[B]] = foo(a, b)
13+
}

tests/pos/i7829.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
class X
2+
class Y
3+
4+
object Test {
5+
type Id[T] = T
6+
7+
val a: 1 = identity(1)
8+
val b: Id[1] = identity(1)
9+
10+
val c: X | Y = identity(if (true) new X else new Y)
11+
val d: Id[X | Y] = identity(if (true) new X else new Y)
12+
13+
def impUnion: Unit = {
14+
class Base
15+
class A extends Base
16+
class B extends Base
17+
class Inv[T]
18+
19+
implicit def invBase: Inv[Base] = new Inv[Base]
20+
21+
def getInv[T](x: T)(implicit inv: Inv[T]): Int = 1
22+
23+
val a: Int = getInv(if (true) new A else new B)
24+
// If we keep unions when doing the implicit search, this would give us: "no implicit argument of type Inv[X | Y]"
25+
val b: Int | Any = getInv(if (true) new A else new B)
26+
}
27+
}

tests/pos/i8378.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
trait Has[A]
2+
3+
trait A
4+
trait B
5+
trait C
6+
7+
trait ZLayer[-RIn, +E, +ROut]
8+
9+
object ZLayer {
10+
def fromServices[A0, A1, B](f: (A0, A1) => B): ZLayer[Has[A0] with Has[A1], Nothing, Has[B]] =
11+
???
12+
}
13+
14+
val live: ZLayer[Has[A] & Has[B], Nothing, Has[C]] =
15+
ZLayer.fromServices { (a: A, b: B) =>
16+
new C {}
17+
}

tests/pos/or-inf.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
object Test {
2+
3+
def a(lis: Set[Int] | Set[String]) = {}
4+
a(Set(1))
5+
a(Set(""))
6+
7+
def b(lis: List[Set[Int] | Set[String]]) = {}
8+
b(List(Set(1)))
9+
b(List(Set("")))
10+
11+
def c(x: Set[Any] | Array[Any]) = {}
12+
c(Set(1))
13+
c(Array(1))
14+
}

tests/pos/orinf.scala

Lines changed: 0 additions & 6 deletions
This file was deleted.

0 commit comments

Comments
 (0)