Skip to content

Fix type inferencing (constraining) regressions #19189

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 16 additions & 23 deletions compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -263,29 +263,22 @@ trait PatternTypeConstrainer { self: TypeComparer =>

trace(i"constraining simple pattern type $tp >:< $pt", gadts, (res: Boolean) => i"$res gadt = ${ctx.gadt}") {
(tp, pt) match {
case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) =>
val saved = state.nn.constraint
val result =
ctx.gadtState.rollbackGadtUnless {
tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) =>
val variance = param.paramVarianceSign
if variance == 0 || assumeInvariantRefinement ||
// As a special case, when pattern and scrutinee types have the same type constructor,
// we infer better bounds for pattern-bound abstract types.
argP.typeSymbol.isPatternBound && patternTp.classSymbol == scrutineeTp.classSymbol
then
val TypeBounds(loS, hiS) = argS.bounds
val TypeBounds(loP, hiP) = argP.bounds
var res = true
if variance < 1 then res &&= isSubType(loS, hiP)
if variance > -1 then res &&= isSubType(loP, hiS)
res
else true
}
}
if !result then
constraint = saved
result
case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) => rollbackConstraintsUnless:
tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) =>
val variance = param.paramVarianceSign
if variance == 0 || assumeInvariantRefinement ||
// As a special case, when pattern and scrutinee types have the same type constructor,
// we infer better bounds for pattern-bound abstract types.
argP.typeSymbol.isPatternBound && patternTp.classSymbol == scrutineeTp.classSymbol
then
val TypeBounds(loS, hiS) = argS.bounds
val TypeBounds(loP, hiP) = argP.bounds
var res = true
if variance < 1 then res &&= isSubType(loS, hiP)
if variance > -1 then res &&= isSubType(loP, hiS)
res
else true
}
case _ =>
// Give up if we don't get AppliedType, e.g. if we upcasted to Any.
// Note that this doesn't mean that patternTp, scrutineeTp cannot possibly
Expand Down
9 changes: 8 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
if (tp2a ne tp2) // Follow the alias; this might avoid truncating the search space in the either below
return recur(tp1, tp2a)

// Rewrite (T111 | T112) & T12 <: T2 to (T111 & T12) <: T2 and (T112 | T12) <: T2
// Rewrite (T111 | T112) & T12 <: T2 to (T111 & T12) <: T2 and (T112 & T12) <: T2
// and analogously for T11 & (T121 | T122) & T12 <: T2
// `&' types to the left of <: are problematic, because
// we have to choose one constraint set or another, which might cut off
Expand Down Expand Up @@ -1982,6 +1982,13 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
else op2
end necessaryEither

inline def rollbackConstraintsUnless(inline op: Boolean): Boolean =
val saved = constraint
var result = false
try result = ctx.gadtState.rollbackGadtUnless(op)
finally if !result then constraint = saved
result

/** Decompose into conjunction of types each of which has only a single refinement */
def decomposeRefinements(tp: Type, refines: List[(Name, Type)]): Type = tp match
case RefinedType(parent, rname, rinfo) =>
Expand Down
6 changes: 5 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4194,7 +4194,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val funExpected = functionExpected
val arity =
if funExpected then
defn.functionArity(ptNorm)
if !isFullyDefined(pt, ForceDegree.none) && isFullyDefined(wtp, ForceDegree.none) then
// if method type is fully defined, but expected type is not,
// prioritize method parameter types as parameter types of the eta-expanded closure
0
else defn.functionArity(ptNorm)
else
val nparams = wtp.paramInfos.length
if nparams > 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class SemanticdbTests:
|inspect with:
| diff $expect ${expect.resolveSibling("" + expect.getFileName + ".out")}
|Or else update all expect files with
| sbt 'scala3-compiler-bootstrapped/test:runMain dotty.tools.dotc.semanticdb.updateExpect'""".stripMargin)
| sbt 'scala3-compiler-bootstrapped/Test/runMain dotty.tools.dotc.semanticdb.updateExpect'""".stripMargin)
Files.walk(target).sorted(Comparator.reverseOrder).forEach(Files.delete)
if errors.nonEmpty then
fail(s"${errors.size} errors in expect test.")
Expand Down
11 changes: 11 additions & 0 deletions tests/neg/i18453.min.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Slightly nicer version of i18453
// which uses a non-abstract type Foo instead
trait Box[T]

trait Foo

class Test:
def meth[A](func: A => A & Foo)(using boxA: Box[A]): Unit = ???
def test[B] (using boxB: Box[B]): Unit =
def nest(p: B): B & Foo = ???
meth(nest) // error
6 changes: 5 additions & 1 deletion tests/pos/i18453.scala → tests/neg/i18453.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
// Would be nice if this compiled
// but it doesn't
// because of how we constrain `A`
// and then try to "minimise" its instantiation
trait Box[T]

class Test:
def f[A, B](c: A => A & B)(using ba: Box[A]): Unit = ???

def g[X, Y](using bx: Box[X]): Unit =
def d(t: X): X & Y = t.asInstanceOf[X & Y]
f(d)
f(d) // error
2 changes: 1 addition & 1 deletion tests/neg/i5976.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
object Test {
def f(i: => Int) = i + i
val res = List(42).map(f)
val res = List(42).map(f) // error

val g: (=> Int) => Int = f
val h: Int => Int = g // error
Expand Down
33 changes: 33 additions & 0 deletions tests/pos/i18453.zio.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Minimised from zio's ZLayer ++

// In an attempt to fix i18453
// this would break zio's ZLayer
// in the "would-error" cases
class Cov[+W]:
def add[X >: W, Y](y: Cov[Y]): Cov[X & Y] = ???
def pre[Y >: W, X](x: Cov[X]): Cov[X & Y] = ???

class Test:
def a1[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A & B & C] = a.add(b).add(c)
def a2[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A with B with C] = a.add(b).add(c) // would-error

def b1[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A & (B & C)] = a.add(b).add(c) // would-error (a2)
def b2[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[(A & B) & C] = a.add(b).add(c)
def b3[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A & (B & C)] = a.add(b.add(c))
def b4[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[(A & B) & C] = a.add(b.add(c))


def c3[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A & B & C] = a.pre(b).pre(c)
def c4[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A with B with C] = a.pre(b).pre(c) // would-error

def d1[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A & (B & C)] = a.pre(b).pre(c) // would-error (c4)
def d2[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[(A & B) & C] = a.pre(b).pre(c)
def d3[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A & (B & C)] = a.pre(b.pre(c))
def d4[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[(A & B) & C] = a.pre(b.pre(c))


def add[X, Y](x: Cov[X], y: Cov[Y]): Cov[X & Y] = ???
def e1[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A & (B & C)] = add(add(a, b), c) // alt assoc: ok!
def e2[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[(A & B) & C] = add(add(a, b), c) // reg assoc: ok
def e3[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[A & (B & C)] = add(a, add(b, c)) // reg assoc: ok
def e4[A, B, C](a: Cov[A], b: Cov[B], c: Cov[C]): Cov[(A & B) & C] = add(a, add(b, c)) // alt assoc: ok!
19 changes: 19 additions & 0 deletions tests/pos/i19001.case1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import java.util.concurrent.CompletionStage
import scala.concurrent.Future

trait ActorRef[-T]:
def ask[Res](replyTo: ActorRef[Res] => T): Future[Res] = ???

implicit final class FutureOps[T](private val f: Future[T]) extends AnyVal:
def asJava: CompletionStage[T] = ???

class AskPattern[Req, Res]:
val actor: ActorRef[Req] = ???
val messageFactory: ActorRef[Res] => Req = ???

def failing(): CompletionStage[Res] = actor.ask(messageFactory.apply).asJava
def workaround1(): CompletionStage[Res] = actor.ask[Res](messageFactory.apply).asJava
def workaround2(): CompletionStage[Res] = actor.ask(messageFactory).asJava

val jMessageFactory: java.util.function.Function[ActorRef[Res], Req] = ???
def originalFailingCase(): CompletionStage[Res] = actor.ask(jMessageFactory.apply).asJava
16 changes: 16 additions & 0 deletions tests/pos/i19001.case2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import scala.util.{Try, Success, Failure}

trait ActorRef[-T]
trait ActorContext[T]:
def ask[Req, Res](target: ActorRef[Req], createRequest: ActorRef[Res] => Req)(mapResponse: Try[Res] => T): Unit

@main def Test =
val context: ActorContext[Int] = ???
val askMeRef: ActorRef[Request] = ???

case class Request(replyTo: ActorRef[Int])

context.ask(askMeRef, Request.apply) {
case Success(res) => res // error: expected Int, got Any
case Failure(ex) => throw ex
}
12 changes: 12 additions & 0 deletions tests/pos/i19001.case3.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
trait IO[A]:
def map[B](f: A => B): IO[B] = ???

trait RenderResult[T]:
def value: T

def IOasync[T](f: (Either[Throwable, T] => Unit) => Unit): IO[T] = ???

def render[T]: IO[T] = {
def register(cb: Either[Throwable, RenderResult[T]] => Unit): Unit = ???
IOasync(register).map(_.value) // map should take RenderResult[T], but uses Any
}
18 changes: 18 additions & 0 deletions tests/pos/i19009.case1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
trait Player[+P]
trait RatingPeriod[P]:
def games: Map[P, Vector[ScoreVsPlayer[P]]]

trait ScoreVsPlayer[+P]

def updated[P](playerID: P, matchResults: IndexedSeq[ScoreVsPlayer[P]], lookup: P => Option[Player[P]]): Player[P] = ???

trait Leaderboard[P]:
def playersByIdInNoParticularOrder: Map[P, Player[P]]

def after[P2 >: P](ratingPeriod: RatingPeriod[? <: P]): Leaderboard[P2] =
val competingPlayers = ratingPeriod.games.iterator.map { (id, matchResults) =>
updated(id, matchResults, playersByIdInNoParticularOrder.get) // error
// workaround:
updated[P](id, matchResults, playersByIdInNoParticularOrder.get)
}
???
10 changes: 10 additions & 0 deletions tests/pos/i19009.case2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
object NodeOrdering:
def postOrderNumbering[NodeType](cfgEntry: NodeType, expand: NodeType => Iterator[NodeType]): Map[NodeType, Int] = ???

trait CfgNode
trait Method extends CfgNode

def postOrder =
def method: Method = ???
def expand(x: CfgNode): Iterator[CfgNode] = ???
NodeOrdering.postOrderNumbering(method, expand)
31 changes: 31 additions & 0 deletions tests/pos/i19009.case3.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
trait Bound[+E]

trait SegmentT[E, +S]
object SegmentT:
trait WithPrev[E, +S] extends SegmentT[E, S]

trait SegmentSeqT[E, +S]:
def getSegmentForBound(bound: Bound[E]): SegmentT[E, S] with S

abstract class AbstractSegmentSeq[E, +S] extends SegmentSeqT[E, S]

trait MappedSegmentBase[E, S]

type MappedSegment[E, S] = AbstractMappedSegmentSeq.MappedSegment[E, S]

object AbstractMappedSegmentSeq:
type MappedSegment[E, S] = SegmentT[E, MappedSegmentBase[E, S]] with MappedSegmentBase[E, S]

abstract class AbstractMappedSegmentSeq[E, S]
extends AbstractSegmentSeq[E, MappedSegmentBase[E, S]]:
def originalSeq: SegmentSeqT[E, S]

final override def getSegmentForBound(bound: Bound[E]): MappedSegment[E, S] =
searchFrontMapper(frontMapperGeneral, originalSeq.getSegmentForBound(bound))

protected final def frontMapperGeneral(original: SegmentT[E, S]): MappedSegment[E, S] = ???

protected def searchFrontMapper[Seg >: SegmentT.WithPrev[E, S] <: SegmentT[E, S], R](
mapper: Seg => R,
original: Seg
): R = ???
9 changes: 9 additions & 0 deletions tests/pos/i19009.min3.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
trait Foo[A]
trait Bar[B] extends Foo[B]

class Test[C]:
def put[X >: Bar[C]](fn: X => Unit, x1: X): Unit = ()
def id(foo: Foo[C]): Foo[C] = foo

def t1(foo2: Foo[C]): Unit =
put(id, foo2) // was: error: exp: Bar[C], got (foo2 : Foo[C])
5 changes: 3 additions & 2 deletions tests/semanticdb/metac.expect
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,7 @@ Language => Scala
Symbols => 181 entries
Occurrences => 159 entries
Diagnostics => 1 entries
Synthetics => 5 entries
Synthetics => 6 entries

Symbols:
_empty_/Enums. => final object Enums extends Object { self: Enums.type => +30 decls }
Expand Down Expand Up @@ -1277,7 +1277,7 @@ _empty_/Enums.unwrap().(ev) => implicit given param ev: <:<[A, Option[B]]
_empty_/Enums.unwrap().(opt) => param opt: Option[A]
_empty_/Enums.unwrap().[A] => typeparam A
_empty_/Enums.unwrap().[B] => typeparam B
local0 => param x: A
local0 => param x: Option[B]

Occurrences:
[0:7..0:12): Enums <- _empty_/Enums.
Expand Down Expand Up @@ -1445,6 +1445,7 @@ Diagnostics:

Synthetics:
[52:9..52:13):Refl => *.unapply[Option[B]]
[52:31..52:50):identity[Option[B]] => *[Function1[A, Option[B]]]
[54:14..54:18):Some => *.apply[Some[Int]]
[54:14..54:34):Some(Some(1)).unwrap => *(given_<:<_T_T[Option[Int]])
[54:19..54:23):Some => *.apply[Int]
Expand Down