Skip to content

Commit f7bb6e1

Browse files
committed
infer: Use GADT constraints in maximiseType
Consider the GADT constraints during Inferencing's maximiseType to avoid instantiating type variables that lead to GADT casting inserting unsound casts.
1 parent aa7c59b commit f7bb6e1

File tree

5 files changed

+51
-7
lines changed

5 files changed

+51
-7
lines changed

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ sealed abstract class GadtConstraint extends Showable {
4949
/** See [[ConstraintHandling.approximation]] */
5050
def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type
5151

52+
def symbols: List[Symbol]
53+
5254
def fresh: GadtConstraint
5355

5456
/** Restore the state from other [[GadtConstraint]], probably copied using [[fresh]] */
@@ -209,6 +211,8 @@ final class ProperGadtConstraint private(
209211
res
210212
}
211213

214+
override def symbols: List[Symbol] = mapping.keys
215+
212216
override def fresh: GadtConstraint = new ProperGadtConstraint(
213217
myConstraint,
214218
mapping,
@@ -307,6 +311,8 @@ final class ProperGadtConstraint private(
307311

308312
override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = unsupported("EmptyGadtConstraint.approximation")
309313

314+
override def symbols: List[Symbol] = Nil
315+
310316
override def fresh = new ProperGadtConstraint
311317
override def restore(other: GadtConstraint): Unit =
312318
assert(!other.isNarrowing, "cannot restore a non-empty GADTMap")

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ import collection.mutable
1616

1717
import scala.annotation.internal.sharable
1818

19-
import config.Printers.gadts
20-
2119
object Inferencing {
2220

2321
import tpd._
@@ -411,10 +409,16 @@ object Inferencing {
411409
Stats.record("maximizeType")
412410
val vs = variances(tp)
413411
val patternBindings = new mutable.ListBuffer[(Symbol, TypeParamRef)]
412+
val gadtBounds = ctx.gadt.symbols.map(ctx.gadt.bounds(_).nn)
414413
vs foreachBinding { (tvar, v) =>
415414
if !tvar.isInstantiated then
416-
if (v == 1) tvar.instantiate(fromBelow = false)
417-
else if (v == -1) tvar.instantiate(fromBelow = true)
415+
// if the tvar is covariant/contravariant (v == 1/-1, respectively) in the input type tp
416+
// then check the tvar doesn't occur in the opposite GADT bound (lower/upper) within any of the GADT bounds
417+
// if it doesn't occur then it's safe to instantiate the tvar
418+
// Eg neg/i14983 the C in Node[+C] is in the GADT lower bound X >: List[C] so maximising to Node[Any] is unsound
419+
// Eg pos/precise-pattern-type the T in Tree[-T] is in no GADT upper bound so can maximise to Tree[Type]
420+
val safeToInstantiate = v != 0 && gadtBounds.forall(tb => !tvar.occursIn(if v == 1 then tb.lo else tb.hi))
421+
if safeToInstantiate then tvar.instantiate(fromBelow = v == -1)
418422
else {
419423
val bounds = TypeComparer.fullBounds(tvar.origin)
420424
if bounds.hi <:< bounds.lo || bounds.hi.classSymbol.is(Final) then

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3762,9 +3762,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
37623762
res
37633763
} =>
37643764
// Insert an explicit cast, so that -Ycheck in later phases succeeds.
3765-
// I suspect, but am not 100% sure that this might affect inferred types,
3766-
// if the expected type is a supertype of the GADT bound. It would be good to come
3767-
// up with a test case for this.
3765+
// The check "safeToInstantiate" in `maximizeType` works to prevent unsound GADT casts.
37683766
val target =
37693767
if tree.tpe.isSingleton then
37703768
val conj = AndType(tree.tpe, pt)

tests/neg/i14983.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
sealed trait Tree[+A]
2+
final case class Leaf[+B](b: B) extends Tree[B]
3+
final case class Node[+C](l: List[C]) extends Tree[List[C]]
4+
5+
// The original test case, minimised.
6+
object Test:
7+
def meth[X](tree: Tree[X]): X = tree match
8+
case Leaf(v) => v // ok: Tree[X] vs Leaf[B], PTC: X >: B, max: Leaf[B] => Leaf[X], x: X
9+
case Node(x) =>
10+
// tree: Tree[X] vs Node[C] aka Tree[List[C]]
11+
// PTC: X >: List[C]
12+
// max: Node[C] => Node[Any], instantiating C := Any, which makes X >: List[Any]
13+
// adapt: List[String] <: X = OKwithGADTUsed; insert GADT cast asInstanceOf[X]
14+
List("boom") // error: Found: List[String] Required: X where: X is a type in method meth with bounds >: List[C$1]
15+
// after fix:
16+
// max: Node[C] => Node[C$1], instantiating C := C$1, a new symbol, so X >: List[C$1]
17+
// adapt: List[String] <: X = Fail, because String !<: C$1
18+
19+
def main(args: Array[String]): Unit =
20+
val tree = Node(List(42))
21+
val res = meth(tree)
22+
assert(res.head == 42) // was: ClassCastException: class java.lang.String cannot be cast to class java.lang.Integer

tests/run/i14983.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
sealed trait Tree[+A]
2+
final case class Leaf[+B](b: B) extends Tree[B]
3+
final case class Node[+C](l: List[C]) extends Tree[List[C]]
4+
5+
// A version of the original test case that is sound so should typecheck.
6+
object Test:
7+
def meth[X](tree: Tree[X]): X = tree match
8+
case Leaf(v) => v // ok: Tree[X] vs Leaf[B], PTC: X >: B, max: Leaf[B] => Leaf[X], x: X <:< X
9+
case Node(x) => x // ok: Tree[X] vs Node[C], PTC: X >: List[C], max: Node[C] => Node[C$1], x: C$1 <:< X, w/ GADT cast
10+
11+
def main(args: Array[String]): Unit =
12+
val tree = Node(List(42))
13+
val res = meth(tree)
14+
assert(res.head == 42) // ok

0 commit comments

Comments
 (0)