diff --git a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala index 62c60fbf93c0..0f61fd2e25fe 100644 --- a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala +++ b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala @@ -484,9 +484,10 @@ trait ConstraintHandling[AbstractContext] { * recording an isLess relationship instead (even though this is not implied * by the bound). * - * Narrowing a constraint is better than widening it, because narrowing leads - * to incompleteness (which we face anyway, see for instance eitherIsSubType) - * but widening leads to unsoundness. + * Normally, narrowing a constraint is better than widening it, because + * narrowing leads to incompleteness (which we face anyway, see for + * instance `TypeComparer#either`) but widening leads to unsoundness, + * but note the special handling in `ConstrainResult` mode below. * * A test case that demonstrates the problem is i864.scala. * Turn Config.checkConstraintsSeparated on to get an accurate diagnostic @@ -544,10 +545,23 @@ trait ConstraintHandling[AbstractContext] { case bound: TypeParamRef if constraint contains bound => addParamBound(bound) case _ => + val savedConstraint = constraint val pbound = prune(bound) - pbound.exists - && kindCompatible(param, pbound) - && (if fromBelow then addLowerBound(param, pbound) else addUpperBound(param, pbound)) + val constraintsNarrowed = constraint ne savedConstraint + + val res = + pbound.exists + && kindCompatible(param, pbound) + && (if fromBelow then addLowerBound(param, pbound) else addUpperBound(param, pbound)) + // If we're in `ConstrainResult` mode, we don't want to commit to a + // set of constraints that would later prevent us from typechecking + // arguments, so if `pruneParams` had to narrow the constraints, we + // simply do not record any new constraint. + // Unlike in `TypeComparer#either`, the same reasoning does not apply + // to GADT mode because this code is never run on GADT constraints. + if ctx.mode.is(Mode.ConstrainResult) && constraintsNarrowed then + constraint = savedConstraint + res } finally addConstraintInvocations -= 1 } diff --git a/compiler/src/dotty/tools/dotc/core/Mode.scala b/compiler/src/dotty/tools/dotc/core/Mode.scala index bc49bd8ec2ed..f6a6c97c25e1 100644 --- a/compiler/src/dotty/tools/dotc/core/Mode.scala +++ b/compiler/src/dotty/tools/dotc/core/Mode.scala @@ -60,6 +60,9 @@ object Mode { */ val Printing: Mode = newMode(10, "Printing") + /** We are constraining a method based on its expected type. */ + val ConstrainResult: Mode = newMode(11, "ConstrainResult") + /** We are currently in a `viewExists` check. In that case, ambiguous * implicits checks are disabled and we succeed with the first implicit * found. diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index f06928a57ce8..62fae9637151 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -1364,14 +1364,26 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w /** Returns true iff the result of evaluating either `op1` or `op2` is true and approximates resulting constraints. * - * If we're _not_ in GADTFlexible mode, we try to keep the smaller of the two constraints. - * If we're _in_ GADTFlexible mode, we keep the smaller constraint if any, or no constraint at all. + * If we're inferring GADT bounds or constraining a method based on its + * expected type, we infer only the _necessary_ constraints, this means we + * keep the smaller constraint if any, or no constraint at all. This is + * necessary for GADT bounds inference to be sound. When constraining a + * method, this avoid committing of constraints that would later prevent us + * from typechecking method arguments, see or-inf.scala and and-inf.scala for + * examples. * + * Otherwise, we infer _sufficient_ constraints: we try to keep the smaller of + * the two constraints, but if never is smaller than the other, we just pick + * the first one. + * + * @see [[necessaryEither]] for the GADT / result type case * @see [[sufficientEither]] for the normal case - * @see [[necessaryEither]] for the GADTFlexible case */ protected def either(op1: => Boolean, op2: => Boolean): Boolean = - if (ctx.mode.is(Mode.GadtConstraintInference)) necessaryEither(op1, op2) else sufficientEither(op1, op2) + if ctx.mode.is(Mode.GadtConstraintInference) || ctx.mode.is(Mode.ConstrainResult) then + necessaryEither(op1, op2) + else + sufficientEither(op1, op2) /** Returns true iff the result of evaluating either `op1` or `op2` is true, * trying at the same time to keep the constraint as wide as possible. @@ -1438,8 +1450,8 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w * T1 & T2 <:< T3 * T1 <:< T2 | T3 * - * Unlike [[sufficientEither]], this method is used in GADTFlexible mode, when we are attempting to infer GADT - * constraints that necessarily follow from the subtyping relationship. For instance, if we have + * Unlike [[sufficientEither]], this method is used in GADTConstraintInference mode, when we are attempting + * to infer GADT constraints that necessarily follow from the subtyping relationship. For instance, if we have * * enum Expr[T] { * case IntExpr(i: Int) extends Expr[Int] @@ -1466,48 +1478,49 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w * * then the necessary constraint is { A = Int }, but correctly inferring that is, as far as we know, too expensive. * + * This method is also used in ConstrainResult mode + * to avoid inference getting stuck due to lack of backtracking, + * see or-inf.scala and and-inf.scala for examples. + * * Method name comes from the notion that we are keeping the constraint which is necessary to satisfy both * subtyping relationships. */ - private def necessaryEither(op1: => Boolean, op2: => Boolean): Boolean = { + private def necessaryEither(op1: => Boolean, op2: => Boolean): Boolean = val preConstraint = constraint - val preGadt = ctx.gadt.fresh - // if GADTflexible mode is on, we expect to always have a ProperGadtConstraint - val pre = preGadt.asInstanceOf[ProperGadtConstraint] - if (op1) { - val leftConstraint = constraint - val leftGadt = ctx.gadt.fresh + + def allSubsumes(leftGadt: GadtConstraint, rightGadt: GadtConstraint, left: Constraint, right: Constraint): Boolean = + subsumes(left, right, preConstraint) && preGadt.match + case preGadt: ProperGadtConstraint => + preGadt.subsumes(leftGadt, rightGadt, preGadt) + case _ => + true + + if op1 then + val op1Constraint = constraint + val op1Gadt = ctx.gadt.fresh constraint = preConstraint ctx.gadt.restore(preGadt) - if (op2) - if (pre.subsumes(leftGadt, ctx.gadt, preGadt) && subsumes(leftConstraint, constraint, preConstraint)) { - gadts.println(i"GADT CUT - prefer ${ctx.gadt} over $leftGadt") - constr.println(i"CUT - prefer $constraint over $leftConstraint") - true - } - else if (pre.subsumes(ctx.gadt, leftGadt, preGadt) && subsumes(constraint, leftConstraint, preConstraint)) { - gadts.println(i"GADT CUT - prefer $leftGadt over ${ctx.gadt}") - constr.println(i"CUT - prefer $leftConstraint over $constraint") - constraint = leftConstraint - ctx.gadt.restore(leftGadt) - true - } - else { + if op2 then + if allSubsumes(op1Gadt, ctx.gadt, op1Constraint, constraint) then + gadts.println(i"GADT CUT - prefer ${ctx.gadt} over $op1Gadt") + constr.println(i"CUT - prefer $constraint over $op1Constraint") + else if allSubsumes(ctx.gadt, op1Gadt, constraint, op1Constraint) then + gadts.println(i"GADT CUT - prefer $op1Gadt over ${ctx.gadt}") + constr.println(i"CUT - prefer $op1Constraint over $constraint") + constraint = op1Constraint + ctx.gadt.restore(op1Gadt) + else gadts.println(i"GADT CUT - no constraint is preferable, reverting to $preGadt") constr.println(i"CUT - no constraint is preferable, reverting to $preConstraint") constraint = preConstraint ctx.gadt.restore(preGadt) - true - } - else { - constraint = leftConstraint - ctx.gadt.restore(leftGadt) - true - } - } + else + constraint = op1Constraint + ctx.gadt.restore(op1Gadt) + true else op2 - } + end necessaryEither /** Does type `tp1` have a member with name `name` whose normalized type is a subtype of * the normalized type of the refinement `tp2`? diff --git a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala index e99a86088a8e..418b3538d34a 100644 --- a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala +++ b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala @@ -59,17 +59,14 @@ object ProtoTypes { else ctx.test(testCompat) } - private def disregardProto(pt: Type)(implicit ctx: Context): Boolean = pt.dealias match { - case _: OrType => true - // Don't constrain results with union types, since comparison with a union - // type on the right might commit too early into one side. - case pt => pt.isRef(defn.UnitClass) - } + private def disregardProto(pt: Type)(implicit ctx: Context): Boolean = + pt.dealias.isRef(defn.UnitClass) /** Check that the result type of the current method * fits the given expected result type. */ - def constrainResult(mt: Type, pt: Type)(implicit ctx: Context): Boolean = { + def constrainResult(mt: Type, pt: Type)(implicit parentCtx: Context): Boolean = { + given ctx as Context = parentCtx.addMode(Mode.ConstrainResult) val savedConstraint = ctx.typerState.constraint val res = pt.widenExpr match { case pt: FunProto => diff --git a/compiler/test/dotty/tools/dotc/CompilationTests.scala b/compiler/test/dotty/tools/dotc/CompilationTests.scala index e8f2df1a8802..8cccc15d3e85 100644 --- a/compiler/test/dotty/tools/dotc/CompilationTests.scala +++ b/compiler/test/dotty/tools/dotc/CompilationTests.scala @@ -137,6 +137,7 @@ class CompilationTests extends ParallelTesting { compileFile("tests/neg-custom-args/i3882.scala", allowDeepSubtypes), compileFile("tests/neg-custom-args/i4372.scala", allowDeepSubtypes), compileFile("tests/neg-custom-args/i1754.scala", allowDeepSubtypes), + compileFile("tests/neg-custom-args/interop-polytypes.scala", allowDeepSubtypes.and("-Yexplicit-nulls")), compileFile("tests/neg-custom-args/conditionalWarnings.scala", allowDeepSubtypes.and("-deprecation").and("-Xfatal-warnings")), compileFilesInDir("tests/neg-custom-args/isInstanceOf", allowDeepSubtypes and "-Xfatal-warnings"), compileFile("tests/neg-custom-args/i3627.scala", allowDeepSubtypes), diff --git a/tests/explicit-nulls/neg/interop-polytypes.scala b/tests/neg-custom-args/interop-polytypes.scala similarity index 100% rename from tests/explicit-nulls/neg/interop-polytypes.scala rename to tests/neg-custom-args/interop-polytypes.scala diff --git a/tests/neg/i6565.scala b/tests/neg/i6565.scala index a51eeb24c308..d5fab12842d3 100644 --- a/tests/neg/i6565.scala +++ b/tests/neg/i6565.scala @@ -9,9 +9,9 @@ def (o: Lifted[O]) flatMap [O, U] (f: O => Lifted[U]): Lifted[U] = ??? val error: Err = Err() lazy val ok: Lifted[String] = { // ok despite map returning a union - point("a").map(_ => if true then "foo" else error) // error + point("a").map(_ => if true then "foo" else error) // ok } lazy val bad: Lifted[String] = { // found Lifted[Object] point("a").flatMap(_ => point("b").map(_ => if true then "foo" else error)) // error -} \ No newline at end of file +} diff --git a/tests/neg/union.scala b/tests/neg/union.scala index c594e83d74bc..0a702ab70058 100644 --- a/tests/neg/union.scala +++ b/tests/neg/union.scala @@ -17,7 +17,7 @@ object O { val x: A = f(new A { }, new A) - val y1: A | B = f(new A { }, new B) // error + val y1: A | B = f(new A { }, new B) // ok val y2: A | B = f[A | B](new A { }, new B) // ok val z = if (???) new A{} else new B diff --git a/tests/pos/and-inf.scala b/tests/pos/and-inf.scala new file mode 100644 index 000000000000..3008014a00a9 --- /dev/null +++ b/tests/pos/and-inf.scala @@ -0,0 +1,13 @@ +class A +class B + +class Inv[T] +class Contra[-T] + +class Test { + def foo[T, S](x: T, y: S): Contra[Inv[T] & Inv[S]] = ??? + val a: A = new A + val b: B = new B + + val x: Contra[Inv[A] & Inv[B]] = foo(a, b) +} diff --git a/tests/pos/i7829.scala b/tests/pos/i7829.scala new file mode 100644 index 000000000000..2f3d71366b7c --- /dev/null +++ b/tests/pos/i7829.scala @@ -0,0 +1,27 @@ +class X +class Y + +object Test { + type Id[T] = T + + val a: 1 = identity(1) + val b: Id[1] = identity(1) + + val c: X | Y = identity(if (true) new X else new Y) + val d: Id[X | Y] = identity(if (true) new X else new Y) + + def impUnion: Unit = { + class Base + class A extends Base + class B extends Base + class Inv[T] + + implicit def invBase: Inv[Base] = new Inv[Base] + + def getInv[T](x: T)(implicit inv: Inv[T]): Int = 1 + + val a: Int = getInv(if (true) new A else new B) + // If we keep unions when doing the implicit search, this would give us: "no implicit argument of type Inv[X | Y]" + val b: Int | Any = getInv(if (true) new A else new B) + } +} diff --git a/tests/pos/i8378.scala b/tests/pos/i8378.scala new file mode 100644 index 000000000000..b69fec928c76 --- /dev/null +++ b/tests/pos/i8378.scala @@ -0,0 +1,17 @@ +trait Has[A] + +trait A +trait B +trait C + +trait ZLayer[-RIn, +E, +ROut] + +object ZLayer { + def fromServices[A0, A1, B](f: (A0, A1) => B): ZLayer[Has[A0] with Has[A1], Nothing, Has[B]] = + ??? +} + +val live: ZLayer[Has[A] & Has[B], Nothing, Has[C]] = + ZLayer.fromServices { (a: A, b: B) => + new C {} + } diff --git a/tests/pos/or-inf.scala b/tests/pos/or-inf.scala new file mode 100644 index 000000000000..e6022b888e14 --- /dev/null +++ b/tests/pos/or-inf.scala @@ -0,0 +1,14 @@ +object Test { + + def a(lis: Set[Int] | Set[String]) = {} + a(Set(1)) + a(Set("")) + + def b(lis: List[Set[Int] | Set[String]]) = {} + b(List(Set(1))) + b(List(Set(""))) + + def c(x: Set[Any] | Array[Any]) = {} + c(Set(1)) + c(Array(1)) +} diff --git a/tests/pos/orinf.scala b/tests/pos/orinf.scala deleted file mode 100644 index 30b7fd2f6353..000000000000 --- a/tests/pos/orinf.scala +++ /dev/null @@ -1,6 +0,0 @@ -object Test { - - def foo(lis: scala.collection.immutable.Set[Int] | scala.collection.immutable.Set[String]) = lis - foo(Set(1)) - foo(Set("")) -}