From d47605d29d47d1394378a2986e03a1c5ec4f813f Mon Sep 17 00:00:00 2001 From: Dale Wijnand Date: Mon, 11 Dec 2023 18:29:01 +0000 Subject: [PATCH] Fix type inferencing (constraining) regressions --- .../dotc/core/PatternTypeConstrainer.scala | 39 ++++++++----------- .../dotty/tools/dotc/core/TypeComparer.scala | 9 ++++- .../src/dotty/tools/dotc/typer/Typer.scala | 6 ++- .../dotc/semanticdb/SemanticdbTests.scala | 2 +- tests/neg/i18453.min.scala | 11 ++++++ tests/{pos => neg}/i18453.scala | 6 ++- tests/neg/i5976.scala | 2 +- tests/pos/i18453.zio.scala | 33 ++++++++++++++++ tests/pos/i19001.case1.scala | 19 +++++++++ tests/pos/i19001.case2.scala | 16 ++++++++ tests/pos/i19001.case3.scala | 12 ++++++ tests/pos/i19009.case1.scala | 18 +++++++++ tests/pos/i19009.case2.scala | 10 +++++ tests/pos/i19009.case3.scala | 31 +++++++++++++++ tests/pos/i19009.min3.scala | 9 +++++ tests/semanticdb/metac.expect | 5 ++- 16 files changed, 198 insertions(+), 30 deletions(-) create mode 100644 tests/neg/i18453.min.scala rename tests/{pos => neg}/i18453.scala (53%) create mode 100644 tests/pos/i18453.zio.scala create mode 100644 tests/pos/i19001.case1.scala create mode 100644 tests/pos/i19001.case2.scala create mode 100644 tests/pos/i19001.case3.scala create mode 100644 tests/pos/i19009.case1.scala create mode 100644 tests/pos/i19009.case2.scala create mode 100644 tests/pos/i19009.case3.scala create mode 100644 tests/pos/i19009.min3.scala diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index 4e3596ea8814..38f8e19e2737 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -263,29 +263,22 @@ trait PatternTypeConstrainer { self: TypeComparer => trace(i"constraining simple pattern type $tp >:< $pt", gadts, (res: Boolean) => i"$res gadt = ${ctx.gadt}") { (tp, pt) match { - case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) => - val saved = state.nn.constraint - val result = - ctx.gadtState.rollbackGadtUnless { - tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) => - val variance = param.paramVarianceSign - if variance == 0 || assumeInvariantRefinement || - // As a special case, when pattern and scrutinee types have the same type constructor, - // we infer better bounds for pattern-bound abstract types. - argP.typeSymbol.isPatternBound && patternTp.classSymbol == scrutineeTp.classSymbol - then - val TypeBounds(loS, hiS) = argS.bounds - val TypeBounds(loP, hiP) = argP.bounds - var res = true - if variance < 1 then res &&= isSubType(loS, hiP) - if variance > -1 then res &&= isSubType(loP, hiS) - res - else true - } - } - if !result then - constraint = saved - result + case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) => rollbackConstraintsUnless: + tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) => + val variance = param.paramVarianceSign + if variance == 0 || assumeInvariantRefinement || + // As a special case, when pattern and scrutinee types have the same type constructor, + // we infer better bounds for pattern-bound abstract types. + argP.typeSymbol.isPatternBound && patternTp.classSymbol == scrutineeTp.classSymbol + then + val TypeBounds(loS, hiS) = argS.bounds + val TypeBounds(loP, hiP) = argP.bounds + var res = true + if variance < 1 then res &&= isSubType(loS, hiP) + if variance > -1 then res &&= isSubType(loP, hiS) + res + else true + } case _ => // Give up if we don't get AppliedType, e.g. if we upcasted to Any. // Note that this doesn't mean that patternTp, scrutineeTp cannot possibly diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 7c1b6570b0e8..931196c0065a 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -1010,7 +1010,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling if (tp2a ne tp2) // Follow the alias; this might avoid truncating the search space in the either below return recur(tp1, tp2a) - // Rewrite (T111 | T112) & T12 <: T2 to (T111 & T12) <: T2 and (T112 | T12) <: T2 + // Rewrite (T111 | T112) & T12 <: T2 to (T111 & T12) <: T2 and (T112 & T12) <: T2 // and analogously for T11 & (T121 | T122) & T12 <: T2 // `&' types to the left of <: are problematic, because // we have to choose one constraint set or another, which might cut off @@ -1982,6 +1982,13 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling else op2 end necessaryEither + inline def rollbackConstraintsUnless(inline op: Boolean): Boolean = + val saved = constraint + var result = false + try result = ctx.gadtState.rollbackGadtUnless(op) + finally if !result then constraint = saved + result + /** Decompose into conjunction of types each of which has only a single refinement */ def decomposeRefinements(tp: Type, refines: List[(Name, Type)]): Type = tp match case RefinedType(parent, rname, rinfo) => diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index a1ef6c0b2f25..2f03c79754e8 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -4194,7 +4194,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val funExpected = functionExpected val arity = if funExpected then - defn.functionArity(ptNorm) + if !isFullyDefined(pt, ForceDegree.none) && isFullyDefined(wtp, ForceDegree.none) then + // if method type is fully defined, but expected type is not, + // prioritize method parameter types as parameter types of the eta-expanded closure + 0 + else defn.functionArity(ptNorm) else val nparams = wtp.paramInfos.length if nparams > 1 diff --git a/compiler/test/dotty/tools/dotc/semanticdb/SemanticdbTests.scala b/compiler/test/dotty/tools/dotc/semanticdb/SemanticdbTests.scala index 8839a6cd03b1..4db047d0951e 100644 --- a/compiler/test/dotty/tools/dotc/semanticdb/SemanticdbTests.scala +++ b/compiler/test/dotty/tools/dotc/semanticdb/SemanticdbTests.scala @@ -102,7 +102,7 @@ class SemanticdbTests: |inspect with: | diff $expect ${expect.resolveSibling("" + expect.getFileName + ".out")} |Or else update all expect files with - | sbt 'scala3-compiler-bootstrapped/test:runMain dotty.tools.dotc.semanticdb.updateExpect'""".stripMargin) + | sbt 'scala3-compiler-bootstrapped/Test/runMain dotty.tools.dotc.semanticdb.updateExpect'""".stripMargin) Files.walk(target).sorted(Comparator.reverseOrder).forEach(Files.delete) if errors.nonEmpty then fail(s"${errors.size} errors in expect test.") diff --git a/tests/neg/i18453.min.scala b/tests/neg/i18453.min.scala new file mode 100644 index 000000000000..e63a818e8f71 --- /dev/null +++ b/tests/neg/i18453.min.scala @@ -0,0 +1,11 @@ +// Slightly nicer version of i18453 +// which uses a non-abstract type Foo instead +trait Box[T] + +trait Foo + +class Test: + def meth[A](func: A => A & Foo)(using boxA: Box[A]): Unit = ??? + def test[B] (using boxB: Box[B]): Unit = + def nest(p: B): B & Foo = ??? + meth(nest) // error diff --git a/tests/pos/i18453.scala b/tests/neg/i18453.scala similarity index 53% rename from tests/pos/i18453.scala rename to tests/neg/i18453.scala index 40dd14935a10..9a865b420f65 100644 --- a/tests/pos/i18453.scala +++ b/tests/neg/i18453.scala @@ -1,3 +1,7 @@ +// Would be nice if this compiled +// but it doesn't +// because of how we constrain `A` +// and then try to "minimise" its instantiation trait Box[T] class Test: @@ -5,4 +9,4 @@ class Test: def g[X, Y](using bx: Box[X]): Unit = def d(t: X): X & Y = t.asInstanceOf[X & Y] - f(d) + f(d) // error diff --git a/tests/neg/i5976.scala b/tests/neg/i5976.scala index ef2e743e39fe..0b037f50a4ea 100644 --- a/tests/neg/i5976.scala +++ b/tests/neg/i5976.scala @@ -1,6 +1,6 @@ object Test { def f(i: => Int) = i + i - val res = List(42).map(f) + val res = List(42).map(f) // error val g: (=> Int) => Int = f val h: Int => Int = g // error diff --git a/tests/pos/i18453.zio.scala b/tests/pos/i18453.zio.scala new file mode 100644 index 000000000000..32a9ebd0321c --- /dev/null +++ b/tests/pos/i18453.zio.scala @@ -0,0 +1,33 @@ +// Minimised from zio's ZLayer ++ + +// In an attempt to fix i18453 +// this would break zio's ZLayer +// in the "would-error" cases +class Cov[+W]: + def add[X >: W, Y](y: Cov[Y]): Cov[X & Y] = ??? + def pre[Y >: W, X](x: Cov[X]): Cov[X & Y] = ??? + +class Test: + def a1[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A & B & C] = a.add(b).add(c) + def a2[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A with B with C] = a.add(b).add(c) // would-error + + def b1[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A & (B & C)] = a.add(b).add(c) // would-error (a2) + def b2[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[(A & B) & C] = a.add(b).add(c) + def b3[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A & (B & C)] = a.add(b.add(c)) + def b4[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[(A & B) & C] = a.add(b.add(c)) + + + def c3[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A & B & C] = a.pre(b).pre(c) + def c4[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A with B with C] = a.pre(b).pre(c) // would-error + + def d1[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A & (B & C)] = a.pre(b).pre(c) // would-error (c4) + def d2[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[(A & B) & C] = a.pre(b).pre(c) + def d3[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A & (B & C)] = a.pre(b.pre(c)) + def d4[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[(A & B) & C] = a.pre(b.pre(c)) + + + def add[X, Y](x: Cov[X], y: Cov[Y]): Cov[X & Y] = ??? + def e1[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A & (B & C)] = add(add(a, b), c) // alt assoc: ok! + def e2[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[(A & B) & C] = add(add(a, b), c) // reg assoc: ok + def e3[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A & (B & C)] = add(a, add(b, c)) // reg assoc: ok + def e4[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[(A & B) & C] = add(a, add(b, c)) // alt assoc: ok! diff --git a/tests/pos/i19001.case1.scala b/tests/pos/i19001.case1.scala new file mode 100644 index 000000000000..3e1a67caf308 --- /dev/null +++ b/tests/pos/i19001.case1.scala @@ -0,0 +1,19 @@ +import java.util.concurrent.CompletionStage +import scala.concurrent.Future + +trait ActorRef[-T]: + def ask[Res](replyTo: ActorRef[Res] => T): Future[Res] = ??? + +implicit final class FutureOps[T](private val f: Future[T]) extends AnyVal: + def asJava: CompletionStage[T] = ??? + +class AskPattern[Req, Res]: + val actor: ActorRef[Req] = ??? + val messageFactory: ActorRef[Res] => Req = ??? + + def failing(): CompletionStage[Res] = actor.ask(messageFactory.apply).asJava + def workaround1(): CompletionStage[Res] = actor.ask[Res](messageFactory.apply).asJava + def workaround2(): CompletionStage[Res] = actor.ask(messageFactory).asJava + + val jMessageFactory: java.util.function.Function[ActorRef[Res], Req] = ??? + def originalFailingCase(): CompletionStage[Res] = actor.ask(jMessageFactory.apply).asJava diff --git a/tests/pos/i19001.case2.scala b/tests/pos/i19001.case2.scala new file mode 100644 index 000000000000..547441c58ff2 --- /dev/null +++ b/tests/pos/i19001.case2.scala @@ -0,0 +1,16 @@ +import scala.util.{Try, Success, Failure} + +trait ActorRef[-T] +trait ActorContext[T]: + def ask[Req, Res](target: ActorRef[Req], createRequest: ActorRef[Res] => Req)(mapResponse: Try[Res] => T): Unit + +@main def Test = + val context: ActorContext[Int] = ??? + val askMeRef: ActorRef[Request] = ??? + + case class Request(replyTo: ActorRef[Int]) + + context.ask(askMeRef, Request.apply) { + case Success(res) => res // error: expected Int, got Any + case Failure(ex) => throw ex + } diff --git a/tests/pos/i19001.case3.scala b/tests/pos/i19001.case3.scala new file mode 100644 index 000000000000..cc3f8e558fc8 --- /dev/null +++ b/tests/pos/i19001.case3.scala @@ -0,0 +1,12 @@ +trait IO[A]: + def map[B](f: A => B): IO[B] = ??? + +trait RenderResult[T]: + def value: T + +def IOasync[T](f: (Either[Throwable, T] => Unit) => Unit): IO[T] = ??? + +def render[T]: IO[T] = { + def register(cb: Either[Throwable, RenderResult[T]] => Unit): Unit = ??? + IOasync(register).map(_.value) // map should take RenderResult[T], but uses Any +} diff --git a/tests/pos/i19009.case1.scala b/tests/pos/i19009.case1.scala new file mode 100644 index 000000000000..84738dcf384a --- /dev/null +++ b/tests/pos/i19009.case1.scala @@ -0,0 +1,18 @@ +trait Player[+P] +trait RatingPeriod[P]: + def games: Map[P, Vector[ScoreVsPlayer[P]]] + +trait ScoreVsPlayer[+P] + +def updated[P](playerID: P, matchResults: IndexedSeq[ScoreVsPlayer[P]], lookup: P => Option[Player[P]]): Player[P] = ??? + +trait Leaderboard[P]: + def playersByIdInNoParticularOrder: Map[P, Player[P]] + + def after[P2 >: P](ratingPeriod: RatingPeriod[? <: P]): Leaderboard[P2] = + val competingPlayers = ratingPeriod.games.iterator.map { (id, matchResults) => + updated(id, matchResults, playersByIdInNoParticularOrder.get) // error + // workaround: + updated[P](id, matchResults, playersByIdInNoParticularOrder.get) + } + ??? diff --git a/tests/pos/i19009.case2.scala b/tests/pos/i19009.case2.scala new file mode 100644 index 000000000000..8c395aa48e46 --- /dev/null +++ b/tests/pos/i19009.case2.scala @@ -0,0 +1,10 @@ +object NodeOrdering: + def postOrderNumbering[NodeType](cfgEntry: NodeType, expand: NodeType => Iterator[NodeType]): Map[NodeType, Int] = ??? + +trait CfgNode +trait Method extends CfgNode + +def postOrder = + def method: Method = ??? + def expand(x: CfgNode): Iterator[CfgNode] = ??? + NodeOrdering.postOrderNumbering(method, expand) diff --git a/tests/pos/i19009.case3.scala b/tests/pos/i19009.case3.scala new file mode 100644 index 000000000000..b2b17b312af0 --- /dev/null +++ b/tests/pos/i19009.case3.scala @@ -0,0 +1,31 @@ +trait Bound[+E] + +trait SegmentT[E, +S] +object SegmentT: + trait WithPrev[E, +S] extends SegmentT[E, S] + +trait SegmentSeqT[E, +S]: + def getSegmentForBound(bound: Bound[E]): SegmentT[E, S] with S + +abstract class AbstractSegmentSeq[E, +S] extends SegmentSeqT[E, S] + +trait MappedSegmentBase[E, S] + +type MappedSegment[E, S] = AbstractMappedSegmentSeq.MappedSegment[E, S] + +object AbstractMappedSegmentSeq: + type MappedSegment[E, S] = SegmentT[E, MappedSegmentBase[E, S]] with MappedSegmentBase[E, S] + +abstract class AbstractMappedSegmentSeq[E, S] + extends AbstractSegmentSeq[E, MappedSegmentBase[E, S]]: + def originalSeq: SegmentSeqT[E, S] + + final override def getSegmentForBound(bound: Bound[E]): MappedSegment[E, S] = + searchFrontMapper(frontMapperGeneral, originalSeq.getSegmentForBound(bound)) + + protected final def frontMapperGeneral(original: SegmentT[E, S]): MappedSegment[E, S] = ??? + + protected def searchFrontMapper[Seg >: SegmentT.WithPrev[E, S] <: SegmentT[E, S], R]( + mapper: Seg => R, + original: Seg + ): R = ??? diff --git a/tests/pos/i19009.min3.scala b/tests/pos/i19009.min3.scala new file mode 100644 index 000000000000..f59a4485b219 --- /dev/null +++ b/tests/pos/i19009.min3.scala @@ -0,0 +1,9 @@ +trait Foo[A] +trait Bar[B] extends Foo[B] + +class Test[C]: + def put[X >: Bar[C]](fn: X => Unit, x1: X): Unit = () + def id(foo: Foo[C]): Foo[C] = foo + + def t1(foo2: Foo[C]): Unit = + put(id, foo2) // was: error: exp: Bar[C], got (foo2 : Foo[C]) diff --git a/tests/semanticdb/metac.expect b/tests/semanticdb/metac.expect index 2fd8eca47a7b..c8bb3d3e4ec4 100644 --- a/tests/semanticdb/metac.expect +++ b/tests/semanticdb/metac.expect @@ -1094,7 +1094,7 @@ Language => Scala Symbols => 181 entries Occurrences => 159 entries Diagnostics => 1 entries -Synthetics => 5 entries +Synthetics => 6 entries Symbols: _empty_/Enums. => final object Enums extends Object { self: Enums.type => +30 decls } @@ -1277,7 +1277,7 @@ _empty_/Enums.unwrap().(ev) => implicit given param ev: <:<[A, Option[B]] _empty_/Enums.unwrap().(opt) => param opt: Option[A] _empty_/Enums.unwrap().[A] => typeparam A _empty_/Enums.unwrap().[B] => typeparam B -local0 => param x: A +local0 => param x: Option[B] Occurrences: [0:7..0:12): Enums <- _empty_/Enums. @@ -1445,6 +1445,7 @@ Diagnostics: Synthetics: [52:9..52:13):Refl => *.unapply[Option[B]] +[52:31..52:50):identity[Option[B]] => *[Function1[A, Option[B]]] [54:14..54:18):Some => *.apply[Some[Int]] [54:14..54:34):Some(Some(1)).unwrap => *(given_<:<_T_T[Option[Int]]) [54:19..54:23):Some => *.apply[Int]