Skip to content

Use GADT constraints in maximiseType #15544

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 19 additions & 23 deletions compiler/src/dotty/tools/dotc/core/GadtConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ sealed abstract class GadtConstraint extends Showable {
/** See [[ConstraintHandling.approximation]] */
def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type

def symbols: List[Symbol]

def fresh: GadtConstraint

/** Restore the state from other [[GadtConstraint]], probably copied using [[fresh]] */
Expand Down Expand Up @@ -193,12 +195,7 @@ final class ProperGadtConstraint private(
case null => null
// TODO: Improve flow typing so that ascription becomes redundant
case tv: TypeVar =>
def retrieveBounds: TypeBounds =
bounds(tv.origin) match {
case TypeAlias(tpr: TypeParamRef) if reverseMapping.contains(tpr) =>
TypeAlias(reverseMapping(tpr).nn.typeRef)
case tb => tb
}
def retrieveBounds: TypeBounds = externalize(bounds(tv.origin)).bounds
retrieveBounds
//.showing(i"gadt bounds $sym: $result", gadts)
//.ensuring(containsNoInternalTypes(_))
Expand All @@ -222,6 +219,8 @@ final class ProperGadtConstraint private(
res
}

override def symbols: List[Symbol] = mapping.keys

override def fresh: GadtConstraint = new ProperGadtConstraint(
myConstraint,
mapping,
Expand All @@ -247,13 +246,7 @@ final class ProperGadtConstraint private(
override protected def isSame(tp1: Type, tp2: Type)(using Context): Boolean = TypeComparer.isSameType(tp1, tp2)

override def nonParamBounds(param: TypeParamRef)(using Context): TypeBounds =
val externalizeMap = new TypeMap {
def apply(tp: Type): Type = tp match {
case tpr: TypeParamRef => externalize(tpr)
case tp => mapOver(tp)
}
}
externalizeMap(constraint.nonParamBounds(param)).bounds
externalize(constraint.nonParamBounds(param)).bounds

override def fullLowerBound(param: TypeParamRef)(using Context): Type =
constraint.minLower(param).foldLeft(nonParamBounds(param).lo) {
Expand All @@ -270,27 +263,28 @@ final class ProperGadtConstraint private(

// ---- Private ----------------------------------------------------------

private def externalize(param: TypeParamRef)(using Context): Type =
reverseMapping(param) match {
private def externalize(tp: Type, theMap: TypeMap | Null = null)(using Context): Type = tp match
case param: TypeParamRef => reverseMapping(param) match
case sym: Symbol => sym.typeRef
case null => param
}
case null => param
case tp: TypeAlias => tp.derivedAlias(externalize(tp.alias, theMap))
case tp => (if theMap == null then ExternalizeMap() else theMap).mapOver(tp)

private class ExternalizeMap(using Context) extends TypeMap:
def apply(tp: Type): Type = externalize(tp, this)(using mapCtx)

private def tvarOrError(sym: Symbol)(using Context): TypeVar =
mapping(sym).ensuring(_ != null, i"not a constrainable symbol: $sym").uncheckedNN

private def containsNoInternalTypes(
tp: Type,
acc: TypeAccumulator[Boolean] | Null = null
)(using Context): Boolean = tp match {
private def containsNoInternalTypes(tp: Type, theAcc: TypeAccumulator[Boolean] | Null = null)(using Context): Boolean = tp match {
case tpr: TypeParamRef => !reverseMapping.contains(tpr)
case tv: TypeVar => !reverseMapping.contains(tv.origin)
case tp =>
(if (acc != null) acc else new ContainsNoInternalTypesAccumulator()).foldOver(true, tp)
(if (theAcc != null) theAcc else new ContainsNoInternalTypesAccumulator()).foldOver(true, tp)
}

private class ContainsNoInternalTypesAccumulator(using Context) extends TypeAccumulator[Boolean] {
override def apply(x: Boolean, tp: Type): Boolean = x && containsNoInternalTypes(tp)
override def apply(x: Boolean, tp: Type): Boolean = x && containsNoInternalTypes(tp, this)
}

// ---- Debug ------------------------------------------------------------
Expand Down Expand Up @@ -325,6 +319,8 @@ final class ProperGadtConstraint private(

override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = unsupported("EmptyGadtConstraint.approximation")

override def symbols: List[Symbol] = Nil

override def fresh = new ProperGadtConstraint
override def restore(other: GadtConstraint): Unit =
assert(!other.isNarrowing, "cannot restore a non-empty GADTMap")
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1354,7 +1354,7 @@ trait Applications extends Compatibility {
// Constraining only fails if the pattern cannot possibly match,
// but useless pattern checks detect more such cases, so we simply rely on them instead.
withMode(Mode.GadtConstraintInference)(TypeComparer.constrainPatternType(unapplyArgType, selType))
val patternBound = maximizeType(unapplyArgType, tree.span)
val patternBound = maximizeType(unapplyArgType, unapplyFn.span.endPos)
if (patternBound.nonEmpty) unapplyFn = addBinders(unapplyFn, patternBound)
unapp.println(i"case 2 $unapplyArgType ${ctx.typerState.constraint}")
unapplyArgType
Expand Down
11 changes: 7 additions & 4 deletions compiler/src/dotty/tools/dotc/typer/Inferencing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ import collection.mutable

import scala.annotation.internal.sharable

import config.Printers.gadts

object Inferencing {

import tpd._
Expand Down Expand Up @@ -411,10 +409,15 @@ object Inferencing {
Stats.record("maximizeType")
val vs = variances(tp)
val patternBindings = new mutable.ListBuffer[(Symbol, TypeParamRef)]
val gadtBounds = ctx.gadt.symbols.map(ctx.gadt.bounds(_).nn)
vs foreachBinding { (tvar, v) =>
if !tvar.isInstantiated then
if (v == 1) tvar.instantiate(fromBelow = false)
else if (v == -1) tvar.instantiate(fromBelow = true)
// if the tvar is covariant/contravariant (v == 1/-1, respectively) in the input type tp
// then it is safe to instantiate if it doesn't occur in any of the GADT bounds.
// Eg neg/i14983 the C in Node[+C] occurs in GADT bound X >: List[C] so maximising to Node[Any] is unsound
// Eg pos/precise-pattern-type the T in Tree[-T] doesn't occur in any GADT bound so can maximise to Tree[Type]
val safeToInstantiate = v != 0 && gadtBounds.forall(!tvar.occursIn(_))
if safeToInstantiate then tvar.instantiate(fromBelow = v == -1)
else {
val bounds = TypeComparer.fullBounds(tvar.origin)
if bounds.hi <:< bounds.lo || bounds.hi.classSymbol.is(Final) then
Expand Down
4 changes: 1 addition & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3764,9 +3764,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
res
} =>
// Insert an explicit cast, so that -Ycheck in later phases succeeds.
// I suspect, but am not 100% sure that this might affect inferred types,
// if the expected type is a supertype of the GADT bound. It would be good to come
// up with a test case for this.
// The check "safeToInstantiate" in `maximizeType` works to prevent unsound GADT casts.
val target =
if tree.tpe.isSingleton then
val conj = AndType(tree.tpe, pt)
Expand Down
23 changes: 23 additions & 0 deletions tests/neg/i14983.co-contra.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
case class Showing[-C](show: C => String)

sealed trait Tree[+A]
final case class Leaf[+B](b: B) extends Tree[B]
final case class Node[-C](l: Showing[C]) extends Tree[Showing[C]]

object Test:
def meth[X](tree: Tree[X]): X = tree match
case Leaf(v) => v
case Node(x) =>
// tree: Tree[X] vs Node[C] aka Tree[Showing[C]]
// PTC: X >: Showing[C]
// max: Node[C] to Node[Nothing], instantiating C := Nothing, which makes X >: Showing[Nothing]
// adapt: Showing[String] <: X = OKwithGADTUsed; insert GADT cast asInstanceOf[X]
Showing[String](_ + " boom!") // error: Found: Showing[String] Required: X where: X is a type in method meth with bounds >: Showing[C$1]
// after fix:
// max: Node[C] => Node[C$1], instantiating C := C$1, a new symbol, so X >: Showing[C$1]
// adapt: Showing[String] <: X = Fail, because String !<: C$1

def main(args: Array[String]): Unit =
val tree = Node(Showing[Int](_.toString))
val res = meth(tree)
println(res.show(42)) // was: ClassCastException: class java.lang.Integer cannot be cast to class java.lang.String
15 changes: 15 additions & 0 deletions tests/neg/i14983.contra.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
sealed trait Show[-A]
final case class Pure[-B](showB: B => String) extends Show[B]
final case class Many[-C](showL: List[C] => String) extends Show[List[C]]

object Test:
def meth[X](show: Show[X]): X => String = show match
case Pure(showB) => showB
case Many(showL) =>
val res = (xs: List[String]) => xs.head.length.toString
res // error: Found: List[String] => String Required: X => String where: X is a type in method meth with bounds <: List[C$1]

def main(args: Array[String]): Unit =
val show = Many((is: List[Int]) => (is.head + 1).toString)
val fn = meth(show)
assert(fn(List(42)) == "43") // was: ClassCastException: class java.lang.Integer cannot be cast to class java.lang.String
22 changes: 22 additions & 0 deletions tests/neg/i14983.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
sealed trait Tree[+A]
final case class Leaf[+B](b: B) extends Tree[B]
final case class Node[+C](l: List[C]) extends Tree[List[C]]

// The original test case, minimised.
object Test:
def meth[X](tree: Tree[X]): X = tree match
case Leaf(v) => v // ok: Tree[X] vs Leaf[B], PTC: X >: B, max: Leaf[B] => Leaf[X], x: X
case Node(x) =>
// tree: Tree[X] vs Node[C] aka Tree[List[C]]
// PTC: X >: List[C]
// max: Node[C] => Node[Any], instantiating C := Any, which makes X >: List[Any]
// adapt: List[String] <: X = OKwithGADTUsed; insert GADT cast asInstanceOf[X]
List("boom") // error: Found: List[String] Required: X where: X is a type in method meth with bounds >: List[C$1]
// after fix:
// max: Node[C] => Node[C$1], instantiating C := C$1, a new symbol, so X >: List[C$1]
// adapt: List[String] <: X = Fail, because String !<: C$1

def main(args: Array[String]): Unit =
val tree = Node(List(42))
val res = meth(tree)
assert(res.head == 42) // was: ClassCastException: class java.lang.String cannot be cast to class java.lang.Integer
14 changes: 14 additions & 0 deletions tests/run/i14983.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
sealed trait Tree[+A]
final case class Leaf[+B](b: B) extends Tree[B]
final case class Node[+C](l: List[C]) extends Tree[List[C]]

// A version of the original test case that is sound so should typecheck.
object Test:
def meth[X](tree: Tree[X]): X = tree match
case Leaf(v) => v // ok: Tree[X] vs Leaf[B], PTC: X >: B, max: Leaf[B] => Leaf[X], x: X <:< X
case Node(x) => x // ok: Tree[X] vs Node[C], PTC: X >: List[C], max: Node[C] => Node[C$1], x: C$1 <:< X, w/ GADT cast

def main(args: Array[String]): Unit =
val tree = Node(List(42))
val res = meth(tree)
assert(res.head == 42) // ok