Skip to content

Commit e15d3b9

Browse files
committed
Refix constraining two intersections
In trying to fix i18453 I changed how a method was eta-expanded, making it infer differently. That difference in type inference affected a lot of code. So I went back to the problem and tried to fix the over-constraining that was occurring previously.
1 parent 4d45087 commit e15d3b9

File tree

9 files changed

+134
-4
lines changed

9 files changed

+134
-4
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,20 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
383383
}
384384
compareSuper
385385
case AndType(tp21, tp22) =>
386-
recur(tp1, tp21) && recur(tp1, tp22)
386+
tp1 match
387+
case AndType(tp11, tp12) =>
388+
// In A1 & B2 <:< B1 & A2
389+
// Avoid over-constraining by attempting to satisfy the subtype
390+
// with either combination of the intersection parts
391+
// (1s and 2s, or As and Bs).
392+
// We have to make sure to rollback failed part attempts.
393+
either(
394+
rollbackConstraintsUnless(recur(tp11, tp21) && recur(tp12, tp22)),
395+
rollbackConstraintsUnless(recur(tp11, tp22) && recur(tp12, tp21)),
396+
) ||
397+
recur(tp1, tp21) && recur(tp1, tp22)
398+
case _ =>
399+
recur(tp1, tp21) && recur(tp1, tp22)
387400
case OrType(tp21, tp22) =>
388401
if (tp21.stripTypeVar eq tp22.stripTypeVar) recur(tp1, tp21)
389402
else secondTry
@@ -1010,7 +1023,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
10101023
if (tp2a ne tp2) // Follow the alias; this might avoid truncating the search space in the either below
10111024
return recur(tp1, tp2a)
10121025

1013-
// Rewrite (T111 | T112) & T12 <: T2 to (T111 & T12) <: T2 and (T112 | T12) <: T2
1026+
// Rewrite (T111 | T112) & T12 <: T2 to (T111 & T12) <: T2 and (T112 & T12) <: T2
10141027
// and analogously for T11 & (T121 | T122) & T12 <: T2
10151028
// `&' types to the left of <: are problematic, because
10161029
// we have to choose one constraint set or another, which might cut off
@@ -1982,6 +1995,13 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
19821995
else op2
19831996
end necessaryEither
19841997

1998+
inline def rollbackConstraintsUnless(inline op: Boolean): Boolean =
1999+
val saved = constraint
2000+
var result = false
2001+
try result = ctx.gadtState.rollbackGadtUnless(op)
2002+
finally if !result then constraint = saved
2003+
result
2004+
19852005
/** Decompose into conjunction of types each of which has only a single refinement */
19862006
def decomposeRefinements(tp: Type, refines: List[(Name, Type)]): Type = tp match
19872007
case RefinedType(parent, rname, rinfo) =>

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4194,7 +4194,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
41944194
val funExpected = functionExpected
41954195
val arity =
41964196
if funExpected then
4197-
defn.functionArity(ptNorm)
4197+
if !isFullyDefined(pt, ForceDegree.none) && isFullyDefined(wtp, ForceDegree.none) then
4198+
// if method type is fully defined, but expected type is not,
4199+
// prioritize method parameter types as parameter types of the eta-expanded closure
4200+
0
4201+
else defn.functionArity(ptNorm)
41984202
else
41994203
val nparams = wtp.paramInfos.length
42004204
if nparams > 1

tests/neg/i5976.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
object Test {
22
def f(i: => Int) = i + i
3-
val res = List(42).map(f)
3+
val res = List(42).map(f) // error
44

55
val g: (=> Int) => Int = f
66
val h: Int => Int = g // error

tests/pos/i19001.case1.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import java.util.concurrent.CompletionStage
2+
import scala.concurrent.Future
3+
4+
trait ActorRef[-T]:
5+
def ask[Res](replyTo: ActorRef[Res] => T): Future[Res] = ???
6+
7+
implicit final class FutureOps[T](private val f: Future[T]) extends AnyVal:
8+
def asJava: CompletionStage[T] = ???
9+
10+
class AskPattern[Req, Res]:
11+
val actor: ActorRef[Req] = ???
12+
val messageFactory: ActorRef[Res] => Req = ???
13+
14+
def failing(): CompletionStage[Res] = actor.ask(messageFactory.apply).asJava
15+
def workaround1(): CompletionStage[Res] = actor.ask[Res](messageFactory.apply).asJava
16+
def workaround2(): CompletionStage[Res] = actor.ask(messageFactory).asJava
17+
18+
val jMessageFactory: java.util.function.Function[ActorRef[Res], Req] = ???
19+
def originalFailingCase(): CompletionStage[Res] = actor.ask(jMessageFactory.apply).asJava

tests/pos/i19001.case2.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import scala.util.{Try, Success, Failure}
2+
3+
trait ActorRef[-T]
4+
trait ActorContext[T]:
5+
def ask[Req, Res](target: ActorRef[Req], createRequest: ActorRef[Res] => Req)(mapResponse: Try[Res] => T): Unit
6+
7+
@main def Test =
8+
val context: ActorContext[Int] = ???
9+
val askMeRef: ActorRef[Request] = ???
10+
11+
case class Request(replyTo: ActorRef[Int])
12+
13+
context.ask(askMeRef, Request.apply) {
14+
case Success(res) => res // error: expected Int, got Any
15+
case Failure(ex) => throw ex
16+
}

tests/pos/i19001.case3.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
trait IO[A]:
2+
def map[B](f: A => B): IO[B] = ???
3+
4+
trait RenderResult[T]:
5+
def value: T
6+
7+
def IOasync[T](f: (Either[Throwable, T] => Unit) => Unit): IO[T] = ???
8+
9+
def render[T]: IO[T] = {
10+
def register(cb: Either[Throwable, RenderResult[T]] => Unit): Unit = ???
11+
IOasync(register).map(_.value) // map should take RenderResult[T], but uses Any
12+
}

tests/pos/i19009.case1.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
trait Player[+P]
2+
trait RatingPeriod[P]:
3+
def games: Map[P, Vector[ScoreVsPlayer[P]]]
4+
5+
trait ScoreVsPlayer[+P]
6+
7+
def updated[P](playerID: P, matchResults: IndexedSeq[ScoreVsPlayer[P]], lookup: P => Option[Player[P]]): Player[P] = ???
8+
9+
trait Leaderboard[P]:
10+
def playersByIdInNoParticularOrder: Map[P, Player[P]]
11+
12+
def after[P2 >: P](ratingPeriod: RatingPeriod[? <: P]): Leaderboard[P2] =
13+
val competingPlayers = ratingPeriod.games.iterator.map { (id, matchResults) =>
14+
updated(id, matchResults, playersByIdInNoParticularOrder.get) // error
15+
// workaround:
16+
updated[P](id, matchResults, playersByIdInNoParticularOrder.get)
17+
}
18+
???

tests/pos/i19009.case2.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
object NodeOrdering:
2+
def postOrderNumbering[NodeType](cfgEntry: NodeType, expand: NodeType => Iterator[NodeType]): Map[NodeType, Int] = ???
3+
4+
trait CfgNode
5+
trait Method extends CfgNode
6+
7+
def postOrder =
8+
def method: Method = ???
9+
def expand(x: CfgNode): Iterator[CfgNode] = ???
10+
NodeOrdering.postOrderNumbering(method, expand)

tests/pos/i19009.case3.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
trait Bound[+E]
2+
3+
trait SegmentT[E, +S]
4+
object SegmentT:
5+
trait WithPrev[E, +S] extends SegmentT[E, S]
6+
7+
trait SegmentSeqT[E, +S]:
8+
def getSegmentForBound(bound: Bound[E]): SegmentT[E, S] with S
9+
10+
abstract class AbstractSegmentSeq[E, +S] extends SegmentSeqT[E, S]
11+
12+
trait MappedSegmentBase[E, S]
13+
14+
type MappedSegment[E, S] = AbstractMappedSegmentSeq.MappedSegment[E, S]
15+
16+
object AbstractMappedSegmentSeq:
17+
type MappedSegment[E, S] = SegmentT[E, MappedSegmentBase[E, S]] with MappedSegmentBase[E, S]
18+
19+
abstract class AbstractMappedSegmentSeq[E, S]
20+
extends AbstractSegmentSeq[E, MappedSegmentBase[E, S]]:
21+
def originalSeq: SegmentSeqT[E, S]
22+
23+
final override def getSegmentForBound(bound: Bound[E]): MappedSegment[E, S] =
24+
searchFrontMapper(frontMapperGeneral, originalSeq.getSegmentForBound(bound))
25+
26+
protected final def frontMapperGeneral(original: SegmentT[E, S]): MappedSegment[E, S] = ???
27+
28+
protected def searchFrontMapper[Seg >: SegmentT.WithPrev[E, S] <: SegmentT[E, S], R](
29+
mapper: Seg => R,
30+
original: Seg
31+
): R = ???

0 commit comments

Comments
 (0)