diff --git a/compiler/src/dotty/tools/dotc/core/Constraint.scala b/compiler/src/dotty/tools/dotc/core/Constraint.scala index 91bedf35948b..b40b806c85bb 100644 --- a/compiler/src/dotty/tools/dotc/core/Constraint.scala +++ b/compiler/src/dotty/tools/dotc/core/Constraint.scala @@ -45,6 +45,18 @@ abstract class Constraint extends Showable { /** The parameters that are known to be greater wrt <: than `param` */ def upper(param: TypeParamRef): List[TypeParamRef] + /** The lower dominator set. + * + * This is like `lower`, except that each parameter returned is no smaller than every other returned parameter. + */ + def minLower(param: TypeParamRef): List[TypeParamRef] + + /** The upper dominator set. + * + * This is like `upper`, except that each parameter returned is no greater than every other returned parameter. + */ + def minUpper(param: TypeParamRef): List[TypeParamRef] + /** lower(param) \ lower(butNot) */ def exclusiveLower(param: TypeParamRef, butNot: TypeParamRef): List[TypeParamRef] @@ -58,15 +70,6 @@ abstract class Constraint extends Showable { */ def nonParamBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds - /** The lower bound of `param` including all known-to-be-smaller parameters */ - def fullLowerBound(param: TypeParamRef)(implicit ctx: Context): Type - - /** The upper bound of `param` including all known-to-be-greater parameters */ - def fullUpperBound(param: TypeParamRef)(implicit ctx: Context): Type - - /** The bounds of `param` including all known-to-be-smaller and -greater parameters */ - def fullBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds - /** A new constraint which is derived from this constraint by adding * entries for all type parameters of `poly`. * @param tvars A list of type variables associated with the params, diff --git a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala index 0560866a3e6e..4afd55efdefb 100644 --- a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala +++ b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala @@ -2,10 +2,13 @@ package dotty.tools package dotc package core -import Types._, Contexts._, Symbols._ +import Types._ +import Contexts._ +import Symbols._ import Decorators._ import config.Config import config.Printers.{constr, typr} +import dotty.tools.dotc.reporting.trace /** Methods for adding constraints and solving them. * @@ -66,6 +69,22 @@ trait ConstraintHandling[AbstractContext] { case tp => tp } + def nonParamBounds(param: TypeParamRef)(implicit actx: AbstractContext): TypeBounds = constraint.nonParamBounds(param) + + def fullLowerBound(param: TypeParamRef)(implicit actx: AbstractContext): Type = + (nonParamBounds(param).lo /: constraint.minLower(param))(_ | _) + + def fullUpperBound(param: TypeParamRef)(implicit actx: AbstractContext): Type = + (nonParamBounds(param).hi /: constraint.minUpper(param))(_ & _) + + /** Full bounds of `param`, including other lower/upper params. + * + * Note that underlying operations perform subtype checks - for this reason, recursing on `fullBounds` + * of some param when comparing types might lead to infinite recursion. Consider `bounds` instead. + */ + def fullBounds(param: TypeParamRef)(implicit actx: AbstractContext): TypeBounds = + nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param)) + protected def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(implicit actx: AbstractContext): Boolean = !constraint.contains(param) || { def occursIn(bound: Type): Boolean = { @@ -262,7 +281,7 @@ trait ConstraintHandling[AbstractContext] { } constraint.entry(param) match { case _: TypeBounds => - val bound = if (fromBelow) constraint.fullLowerBound(param) else constraint.fullUpperBound(param) + val bound = if (fromBelow) fullLowerBound(param) else fullUpperBound(param) val inst = avoidParam(bound) typr_println(s"approx ${param.show}, from below = $fromBelow, bound = ${bound.show}, inst = ${inst.show}") inst diff --git a/compiler/src/dotty/tools/dotc/core/Contexts.scala b/compiler/src/dotty/tools/dotc/core/Contexts.scala index 6e91741ae444..cbc3a12f0f8b 100644 --- a/compiler/src/dotty/tools/dotc/core/Contexts.scala +++ b/compiler/src/dotty/tools/dotc/core/Contexts.scala @@ -139,9 +139,9 @@ object Contexts { final def importInfo: ImportInfo = _importInfo /** The current bounds in force for type parameters appearing in a GADT */ - private[this] var _gadt: GADTMap = _ - protected def gadt_=(gadt: GADTMap): Unit = _gadt = gadt - final def gadt: GADTMap = _gadt + private[this] var _gadt: GadtConstraint = _ + protected def gadt_=(gadt: GadtConstraint): Unit = _gadt = gadt + final def gadt: GadtConstraint = _gadt /** The history of implicit searches that are currently active */ private[this] var _searchHistory: SearchHistory = null @@ -534,7 +534,7 @@ object Contexts { def setTypeAssigner(typeAssigner: TypeAssigner): this.type = { this.typeAssigner = typeAssigner; this } def setTyper(typer: Typer): this.type = { this.scope = typer.scope; setTypeAssigner(typer) } def setImportInfo(importInfo: ImportInfo): this.type = { this.importInfo = importInfo; this } - def setGadt(gadt: GADTMap): this.type = { this.gadt = gadt; this } + def setGadt(gadt: GadtConstraint): this.type = { this.gadt = gadt; this } def setFreshGADTBounds: this.type = setGadt(gadt.fresh) def setSearchHistory(searchHistory: SearchHistory): this.type = { this.searchHistory = searchHistory; this } def setSource(source: SourceFile): this.type = { this.source = source; this } @@ -617,7 +617,7 @@ object Contexts { store = initialStore.updated(settingsStateLoc, settingsGroup.defaultState) typeComparer = new TypeComparer(this) searchHistory = new SearchRoot - gadt = EmptyGADTMap + gadt = EmptyGadtConstraint } @sharable object NoContext extends Context(null) { @@ -774,233 +774,4 @@ object Contexts { if (thread == null) thread = Thread.currentThread() else assert(thread == Thread.currentThread(), "illegal multithreaded access to ContextBase") } - - sealed abstract class GADTMap { - def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit - def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean - def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds - def contains(sym: Symbol)(implicit ctx: Context): Boolean - def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type - def debugBoundsDescription(implicit ctx: Context): String - def fresh: GADTMap - def restore(other: GADTMap): Unit - def isEmpty: Boolean - } - - final class SmartGADTMap private ( - private var myConstraint: Constraint, - private var mapping: SimpleIdentityMap[Symbol, TypeVar], - private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol], - private var boundCache: SimpleIdentityMap[Symbol, TypeBounds] - ) extends GADTMap with ConstraintHandling[Context] { - import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} - - def this() = this( - myConstraint = new OrderingConstraint(SimpleIdentityMap.Empty, SimpleIdentityMap.Empty, SimpleIdentityMap.Empty), - mapping = SimpleIdentityMap.Empty, - reverseMapping = SimpleIdentityMap.Empty, - boundCache = SimpleIdentityMap.Empty - ) - - implicit override def ctx(implicit ctx: Context): Context = ctx - - override protected def constraint = myConstraint - override protected def constraint_=(c: Constraint) = myConstraint = c - - override def isSubType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSubType(tp1, tp2) - override def isSameType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSameType(tp1, tp2) - - override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = tvar(sym) - - override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = try { - boundCache = SimpleIdentityMap.Empty - boundAdditionInProgress = true - @annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match { - case tv: TypeVar => - val inst = instType(tv) - if (inst.exists) stripInternalTypeVar(inst) else tv - case _ => tp - } - - def externalizedSubtype(tp1: Type, tp2: Type, isSubtype: Boolean): Boolean = { - val externalizedTp1 = removeTypeVars(tp1) - val externalizedTp2 = removeTypeVars(tp2) - - ( - if (isSubtype) externalizedTp1 frozen_<:< externalizedTp2 - else externalizedTp2 frozen_<:< externalizedTp1 - ).reporting({ res => - val descr = i"$externalizedTp1 frozen_${if (isSubtype) "<:<" else ">:>"} $externalizedTp2" - i"$descr = $res" - }, gadts) - } - - val symTvar: TypeVar = stripInternalTypeVar(tvar(sym)) match { - case tv: TypeVar => tv - case inst => - val externalizedInst = removeTypeVars(inst) - gadts.println(i"instantiated: $sym -> $externalizedInst") - return if (isUpper) isSubType(externalizedInst , bound) else isSubType(bound, externalizedInst) - } - - val internalizedBound = insertTypeVars(bound) - ( - stripInternalTypeVar(internalizedBound) match { - case boundTvar: TypeVar => - if (boundTvar eq symTvar) true - else if (isUpper) addLess(symTvar.origin, boundTvar.origin) - else addLess(boundTvar.origin, symTvar.origin) - case bound => - if (externalizedSubtype(symTvar, bound, isSubtype = !isUpper)) { - gadts.println(i"manually unifying $symTvar with $bound") - constraint = constraint.updateEntry(symTvar.origin, bound) - true - } - else if (isUpper) addUpperBound(symTvar.origin, bound) - else addLowerBound(symTvar.origin, bound) - } - ).reporting({ res => - val descr = if (isUpper) "upper" else "lower" - val op = if (isUpper) "<:" else ">:" - i"adding $descr bound $sym $op $bound = $res\t( $symTvar $op $internalizedBound )" - }, gadts) - } finally boundAdditionInProgress = false - - override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = { - mapping(sym) match { - case null => null - case tv => - def retrieveBounds: TypeBounds = { - val tb = constraint.fullBounds(tv.origin) - removeTypeVars(tb).asInstanceOf[TypeBounds] - } - ( - if (boundAdditionInProgress || ctx.mode.is(Mode.GADTflexible)) retrieveBounds - else boundCache(sym) match { - case tb: TypeBounds => tb - case null => - val bounds = retrieveBounds - boundCache = boundCache.updated(sym, bounds) - bounds - } - ).reporting({ res => - // i"gadt bounds $sym: $res" - "" - }, gadts) - } - } - - override def contains(sym: Symbol)(implicit ctx: Context): Boolean = mapping(sym) ne null - - override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = { - val res = removeTypeVars(approximation(tvar(sym).origin, fromBelow = fromBelow)) - gadts.println(i"approximating $sym ~> $res") - res - } - - override def fresh: GADTMap = new SmartGADTMap( - myConstraint, - mapping, - reverseMapping, - boundCache - ) - - def restore(other: GADTMap): Unit = other match { - case other: SmartGADTMap => - this.myConstraint = other.myConstraint - this.mapping = other.mapping - this.reverseMapping = other.reverseMapping - this.boundCache = other.boundCache - case _ => ; - } - - override def isEmpty: Boolean = mapping.size == 0 - - // ---- Private ---------------------------------------------------------- - - private[this] def tvar(sym: Symbol)(implicit ctx: Context): TypeVar = { - mapping(sym) match { - case tv: TypeVar => - tv - case null => - val res = { - import NameKinds.DepParamName - // avoid registering the TypeVar with TyperState / TyperState#constraint - // - we don't want TyperState instantiating these TypeVars - // - we don't want TypeComparer constraining these TypeVars - val poly = PolyType(DepParamName.fresh(sym.name.toTypeName) :: Nil)( - pt => (sym.info match { - case tb @ TypeBounds(_, hi) if hi.isLambdaSub => tb - case _ => TypeBounds.empty - }) :: Nil, - pt => defn.AnyType) - new TypeVar(poly.paramRefs.head, creatorState = null) - } - gadts.println(i"GADTMap: created tvar $sym -> $res") - constraint = constraint.add(res.origin.binder, res :: Nil) - mapping = mapping.updated(sym, res) - reverseMapping = reverseMapping.updated(res.origin, sym) - res - } - } - - private def insertTypeVars(tp: Type, map: TypeMap = null)(implicit ctx: Context) = tp match { - case tp: TypeRef => - val sym = tp.typeSymbol - if (contains(sym)) tvar(sym) else tp - case _ => - (if (map != null) map else new TypeVarInsertingMap()).mapOver(tp) - } - private final class TypeVarInsertingMap(implicit ctx: Context) extends TypeMap { - override def apply(tp: Type): Type = insertTypeVars(tp, this) - } - - private def removeTypeVars(tp: Type, map: TypeMap = null)(implicit ctx: Context) = tp match { - case tpr: TypeParamRef => - reverseMapping(tpr) match { - case null => tpr - case sym => sym.typeRef - } - case tv: TypeVar => - reverseMapping(tv.origin) match { - case null => tv - case sym => sym.typeRef - } - case _ => - (if (map != null) map else new TypeVarRemovingMap()).mapOver(tp) - } - private final class TypeVarRemovingMap(implicit ctx: Context) extends TypeMap { - override def apply(tp: Type): Type = removeTypeVars(tp, this) - } - - private[this] var boundAdditionInProgress = false - - // ---- Debug ------------------------------------------------------------ - - override def constr_println(msg: => String): Unit = gadtsConstr.println(msg) - - override def debugBoundsDescription(implicit ctx: Context): String = { - val sb = new mutable.StringBuilder - sb ++= constraint.show - sb += '\n' - mapping.foreachBinding { case (sym, _) => - sb ++= i"$sym: ${bounds(sym)}\n" - } - sb.result - } - } - - @sharable object EmptyGADTMap extends GADTMap { - override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = unsupported("EmptyGADTMap.addEmptyBounds") - override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = unsupported("EmptyGADTMap.addBound") - override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null - override def contains(sym: Symbol)(implicit ctx: Context) = false - override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = unsupported("EmptyGADTMap.approximation") - override def debugBoundsDescription(implicit ctx: Context): String = "EmptyGADTMap" - override def fresh = new SmartGADTMap - override def restore(other: GADTMap): Unit = { - if (!other.isEmpty) sys.error("cannot restore a non-empty GADTMap") - } - override def isEmpty: Boolean = true - } } diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala new file mode 100644 index 000000000000..11b869b8f995 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -0,0 +1,329 @@ +package dotty.tools +package dotc +package core + +import Decorators._ +import Contexts._ +import Types._ +import Symbols._ +import util.SimpleIdentityMap +import collection.mutable +import printing._ + +import scala.annotation.internal.sharable + +/** Represents GADT constraints currently in scope */ +sealed abstract class GadtConstraint extends Showable { + /** Immediate bounds of `sym`. Does not contain lower/upper symbols (see [[fullBounds]]). */ + def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds + + /** Full bounds of `sym`, including TypeRefs to other lower/upper symbols. + * + * @note this performs subtype checks between ordered symbols. + * Using this in isSubType can lead to infinite recursion. Consider `bounds` instead. + */ + def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds + + /** Is `sym1` ordered to be less than `sym2`? */ + def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean + + /** Add symbols to constraint, correctly handling inter-dependencies. + * + * @see [[ConstraintHandling.addToConstraint]] + */ + def addToConstraint(syms: List[Symbol])(implicit ctx: Context): Boolean + def addToConstraint(sym: Symbol)(implicit ctx: Context): Boolean = addToConstraint(sym :: Nil) + + /** Further constrain a symbol already present in the constraint. */ + def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean + + /** Is the symbol registered in the constraint? + * + * @note this is true even if the symbol is constrained to be equal to another type, unlike [[Constraint.contains]]. + */ + def contains(sym: Symbol)(implicit ctx: Context): Boolean + + def isEmpty: Boolean + + /** See [[ConstraintHandling.approximation]] */ + def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type + + def fresh: GadtConstraint + + /** Restore the state from other [[GadtConstraint]], probably copied using [[fresh]] */ + def restore(other: GadtConstraint): Unit + + def debugBoundsDescription(implicit ctx: Context): String +} + +final class ProperGadtConstraint private( + private var myConstraint: Constraint, + private var mapping: SimpleIdentityMap[Symbol, TypeVar], + private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol], +) extends GadtConstraint with ConstraintHandling[Context] { + import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} + + def this() = this( + myConstraint = new OrderingConstraint(SimpleIdentityMap.Empty, SimpleIdentityMap.Empty, SimpleIdentityMap.Empty), + mapping = SimpleIdentityMap.Empty, + reverseMapping = SimpleIdentityMap.Empty + ) + + /** Exposes ConstraintHandling.subsumes */ + def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(implicit ctx: Context): Boolean = { + def extractConstraint(g: GadtConstraint) = g match { + case s: ProperGadtConstraint => s.constraint + case EmptyGadtConstraint => OrderingConstraint.empty + } + subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre)) + } + + override def addToConstraint(params: List[Symbol])(implicit ctx: Context): Boolean = { + import NameKinds.DepParamName + + val poly1 = PolyType(params.map { sym => DepParamName.fresh(sym.name.toTypeName) })( + pt => params.map { param => + // In bound type `tp`, replace the symbols in dependent positions with their internal TypeParamRefs. + // The replaced symbols will be later picked up in `ConstraintHandling#addToConstraint` + // and used as orderings. + def substDependentSyms(tp: Type, isUpper: Boolean)(implicit ctx: Context): Type = { + def loop(tp: Type) = substDependentSyms(tp, isUpper) + tp match { + case tp @ AndType(tp1, tp2) if !isUpper => + tp.derivedAndType(loop(tp1), loop(tp2)) + case tp @ OrType(tp1, tp2) if isUpper => + tp.derivedOrType(loop(tp1), loop(tp2)) + case tp: NamedType => + params.indexOf(tp.symbol) match { + case -1 => + mapping(tp.symbol) match { + case tv: TypeVar => tv.origin + case null => tp + } + case i => pt.paramRefs(i) + } + case tp => tp + } + } + + val tb = param.info.bounds + tb.derivedTypeBounds( + lo = substDependentSyms(tb.lo, isUpper = false), + hi = substDependentSyms(tb.hi, isUpper = true) + ) + }, + pt => defn.AnyType + ) + + val tvars = (params, poly1.paramRefs).zipped.map { (sym, paramRef) => + val tv = new TypeVar(paramRef, creatorState = null) + mapping = mapping.updated(sym, tv) + reverseMapping = reverseMapping.updated(tv.origin, sym) + tv + } + + // The replaced symbols are picked up here. + addToConstraint(poly1, tvars).reporting({ _ => + i"added to constraint: $params%, %\n$debugBoundsDescription" + }, gadts) + } + + override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = { + @annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match { + case tv: TypeVar => + val inst = instType(tv) + if (inst.exists) stripInternalTypeVar(inst) else tv + case _ => tp + } + + val symTvar: TypeVar = stripInternalTypeVar(tvarOrError(sym)) match { + case tv: TypeVar => tv + case inst => + gadts.println(i"instantiated: $sym -> $inst") + return if (isUpper) isSubType(inst , bound) else isSubType(bound, inst) + } + + val internalizedBound = bound match { + case nt: NamedType => + val ntTvar = mapping(nt.symbol) + if (ntTvar ne null) stripInternalTypeVar(ntTvar) else bound + case _ => bound + } + ( + internalizedBound match { + case boundTvar: TypeVar => + if (boundTvar eq symTvar) true + else if (isUpper) addLess(symTvar.origin, boundTvar.origin) + else addLess(boundTvar.origin, symTvar.origin) + case bound => + val oldUpperBound = bounds(symTvar.origin) + // If we have bounds: + // F >: [t] => List[t] <: [t] => Any + // and we want to record that: + // F <: [+A] => List[A] + // we need to adapt the variance and instead record that: + // F <: [A] => List[A] + // We cannot record the original bound, since it is false that: + // [t] => List[t] <: [+A] => List[A] + // + // Note that the following code is accepted: + // class Foo[F[t] >: List[t]] + // type T = Foo[List] + // precisely because Foo[List] is desugared to Foo[[A] => List[A]]. + // + // Ideally we'd adapt the bound in ConstraintHandling#addOneBound, + // but doing it there actually interferes with type inference. + val bound1 = bound.adaptHkVariances(oldUpperBound) + if (isUpper) addUpperBound(symTvar.origin, bound1) + else addLowerBound(symTvar.origin, bound1) + } + ).reporting({ res => + val descr = if (isUpper) "upper" else "lower" + val op = if (isUpper) "<:" else ">:" + i"adding $descr bound $sym $op $bound = $res" + }, gadts) + } + + override def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean = + constraint.isLess(tvarOrError(sym1).origin, tvarOrError(sym2).origin) + + override def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds = + mapping(sym) match { + case null => null + case tv => + fullBounds(tv.origin) + .ensuring(containsNoInternalTypes(_)) + } + + override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = { + mapping(sym) match { + case null => null + case tv => + def retrieveBounds: TypeBounds = + bounds(tv.origin) match { + case TypeAlias(tpr: TypeParamRef) if reverseMapping.contains(tpr) => + TypeAlias(reverseMapping(tpr).typeRef) + case tb => tb + } + retrieveBounds + //.reporting({ res => i"gadt bounds $sym: $res" }, gadts) + //.ensuring(containsNoInternalTypes(_)) + } + } + + override def contains(sym: Symbol)(implicit ctx: Context): Boolean = mapping(sym) ne null + + override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = { + val res = approximation(tvarOrError(sym).origin, fromBelow = fromBelow) + gadts.println(i"approximating $sym ~> $res") + res + } + + override def fresh: GadtConstraint = new ProperGadtConstraint( + myConstraint, + mapping, + reverseMapping + ) + + def restore(other: GadtConstraint): Unit = other match { + case other: ProperGadtConstraint => + this.myConstraint = other.myConstraint + this.mapping = other.mapping + this.reverseMapping = other.reverseMapping + case _ => ; + } + + override def isEmpty: Boolean = mapping.size == 0 + + // ---- Protected/internal ----------------------------------------------- + + implicit override def ctx(implicit ctx: Context): Context = ctx + + override protected def constraint = myConstraint + override protected def constraint_=(c: Constraint) = myConstraint = c + + override def isSubType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSubType(tp1, tp2) + override def isSameType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSameType(tp1, tp2) + + override def nonParamBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds = + constraint.nonParamBounds(param) match { + case TypeAlias(tpr: TypeParamRef) => TypeAlias(externalize(tpr)) + case tb => tb + } + + override def fullLowerBound(param: TypeParamRef)(implicit ctx: Context): Type = + (nonParamBounds(param).lo /: constraint.minLower(param)) { + (t, u) => t | externalize(u) + } + + override def fullUpperBound(param: TypeParamRef)(implicit ctx: Context): Type = + (nonParamBounds(param).hi /: constraint.minUpper(param)) { + (t, u) => t & externalize(u) + } + + // ---- Private ---------------------------------------------------------- + + private[this] def externalize(param: TypeParamRef)(implicit ctx: Context): Type = + reverseMapping(param) match { + case sym: Symbol => sym.typeRef + case null => param + } + + private[this] def tvarOrError(sym: Symbol)(implicit ctx: Context): TypeVar = + mapping(sym).ensuring(_ ne null, i"not a constrainable symbol: $sym") + + private[this] def containsNoInternalTypes( + tp: Type, + acc: TypeAccumulator[Boolean] = null + )(implicit ctx: Context): Boolean = tp match { + case tpr: TypeParamRef => !reverseMapping.contains(tpr) + case tv: TypeVar => !reverseMapping.contains(tv.origin) + case tp => + (if (acc ne null) acc else new ContainsNoInternalTypesAccumulator()).foldOver(true, tp) + } + + private[this] class ContainsNoInternalTypesAccumulator(implicit ctx: Context) extends TypeAccumulator[Boolean] { + override def apply(x: Boolean, tp: Type): Boolean = x && containsNoInternalTypes(tp) + } + + // ---- Debug ------------------------------------------------------------ + + override def constr_println(msg: => String): Unit = gadtsConstr.println(msg) + + override def toText(printer: Printer): Texts.Text = constraint.toText(printer) + + override def debugBoundsDescription(implicit ctx: Context): String = { + val sb = new mutable.StringBuilder + sb ++= constraint.show + sb += '\n' + mapping.foreachBinding { case (sym, _) => + sb ++= i"$sym: ${fullBounds(sym)}\n" + } + sb.result + } +} + +@sharable object EmptyGadtConstraint extends GadtConstraint { + override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null + override def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null + + override def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean = unsupported("EmptyGadtConstraint.isLess") + + override def isEmpty: Boolean = true + + override def contains(sym: Symbol)(implicit ctx: Context) = false + + override def addToConstraint(params: List[Symbol])(implicit ctx: Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint") + override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = unsupported("EmptyGadtConstraint.addBound") + + override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = unsupported("EmptyGadtConstraint.approximation") + + override def fresh = new ProperGadtConstraint + override def restore(other: GadtConstraint): Unit = { + if (!other.isEmpty) sys.error("cannot restore a non-empty GADTMap") + } + + override def debugBoundsDescription(implicit ctx: Context): String = "EmptyGadtConstraint" + + override def toText(printer: Printer): Texts.Text = "EmptyGadtConstraint" +} diff --git a/compiler/src/dotty/tools/dotc/core/Mode.scala b/compiler/src/dotty/tools/dotc/core/Mode.scala index 430d0b062c84..81b9fc5ea5c4 100644 --- a/compiler/src/dotty/tools/dotc/core/Mode.scala +++ b/compiler/src/dotty/tools/dotc/core/Mode.scala @@ -49,7 +49,7 @@ object Mode { /** We are in a pattern alternative */ val InPatternAlternative: Mode = newMode(7, "InPatternAlternative") - /** Allow GADTFlexType labelled types to have their bounds adjusted */ + /** Infer GADT constraints during type comparisons `A <:< B` */ val GADTflexible: Mode = newMode(8, "GADTflexible") /** Assume -language:strictEquality */ diff --git a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala index 2f568dfe7750..869c8330a5a3 100644 --- a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala @@ -196,15 +196,6 @@ class OrderingConstraint(private val boundsMap: ParamBounds, def nonParamBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds = entry(param).bounds - def fullLowerBound(param: TypeParamRef)(implicit ctx: Context): Type = - (nonParamBounds(param).lo /: minLower(param))(_ | _) - - def fullUpperBound(param: TypeParamRef)(implicit ctx: Context): Type = - (nonParamBounds(param).hi /: minUpper(param))(_ & _) - - def fullBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds = - nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param)) - def typeVarOfParam(param: TypeParamRef): Type = { val entries = boundsMap(param.binder) if (entries == null) NoType diff --git a/compiler/src/dotty/tools/dotc/core/Symbols.scala b/compiler/src/dotty/tools/dotc/core/Symbols.scala index ef81b8bb3bf9..b3d41754dc87 100644 --- a/compiler/src/dotty/tools/dotc/core/Symbols.scala +++ b/compiler/src/dotty/tools/dotc/core/Symbols.scala @@ -209,16 +209,10 @@ trait Symbols { this: Context => modFlags | PackageCreationFlags, clsFlags | PackageCreationFlags, Nil, decls) - /** Define a new symbol associated with a Bind or pattern wildcard and - * make it gadt narrowable. - */ - def newPatternBoundSymbol(name: Name, info: Type, span: Span): Symbol = { + /** Define a new symbol associated with a Bind or pattern wildcard and, by default, make it gadt narrowable. */ + def newPatternBoundSymbol(name: Name, info: Type, span: Span, addToGadt: Boolean = true): Symbol = { val sym = newSymbol(owner, name, Case, info, coord = span) - if (name.isTypeName) { - val bounds = info.bounds - gadt.addBound(sym, bounds.lo, isUpper = false) - gadt.addBound(sym, bounds.hi, isUpper = true) - } + if (addToGadt && name.isTypeName) gadt.addToConstraint(sym) sym } diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 8a69184f0846..86776c4fda00 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -442,8 +442,18 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { val gbounds2 = gadtBounds(tp2.symbol) (gbounds2 != null) && (isSubTypeWhenFrozen(tp1, gbounds2.lo) || + (tp1 match { + case tp1: NamedType if ctx.gadt.contains(tp1.symbol) => + // Note: since we approximate constrained types only with their non-param bounds, + // we need to manually handle the case when we're comparing two constrained types, + // one of which is constrained to be a subtype of another. + // We do not need similar code in fourthTry, since we only need to care about + // comparing two constrained types, and that case will be handled here first. + ctx.gadt.isLess(tp1.symbol, tp2.symbol) && GADTusage(tp1.symbol) && GADTusage(tp2.symbol) + case _ => false + }) || narrowGADTBounds(tp2, tp1, approx, isUpper = false)) && - GADTusage(tp2.symbol) + { tp1.isRef(NothingClass) || GADTusage(tp2.symbol) } } isSubApproxHi(tp1, info2.lo) || compareGADT || fourthTry @@ -702,7 +712,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { (gbounds1 != null) && (isSubTypeWhenFrozen(gbounds1.hi, tp2) || narrowGADTBounds(tp1, tp2, approx, isUpper = true)) && - GADTusage(tp1.symbol) + { tp2.isRef(AnyClass) || GADTusage(tp1.symbol) } } isSubType(hi1, tp2, approx.addLow) || compareGADT case _ => @@ -1209,6 +1219,17 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { fix(tp) } + /** Returns true iff the result of evaluating either `op1` or `op2` is true and approximates resulting constraints. + * + * If we're _not_ in GADTFlexible mode, we try to keep the smaller of the two constraints. + * If we're _in_ GADTFlexible mode, we keep the smaller constraint if any, or no constraint at all. + * + * @see [[sufficientEither]] for the normal case + * @see [[necessaryEither]] for the GADTFlexible case + */ + private def either(op1: => Boolean, op2: => Boolean): Boolean = + if (ctx.mode.is(Mode.GADTflexible)) necessaryEither(op1, op2) else sufficientEither(op1, op2) + /** Returns true iff the result of evaluating either `op1` or `op2` is true, * trying at the same time to keep the constraint as wide as possible. * E.g, if @@ -1237,8 +1258,14 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { * * Here, each precondition leads to a different constraint, and neither of * the two post-constraints subsumes the other. + * + * Note that to be complete when it comes to typechecking, we would instead need to backtrack + * and attempt to typecheck with the other constraint. + * + * Method name comes from the notion that we are keeping a constraint which is sufficient to satisfy + * one of subtyping relationships. */ - private def either(op1: => Boolean, op2: => Boolean): Boolean = { + private def sufficientEither(op1: => Boolean, op2: => Boolean): Boolean = { val preConstraint = constraint op1 && { val leftConstraint = constraint @@ -1252,6 +1279,90 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { } || op2 } + /** Returns true iff the result of evaluating either `op1` or `op2` is true, keeping the smaller constraint if any. + * E.g., if + * + * tp11 <:< tp12 = true with constraint c1 and GADT constraint g1 + * tp12 <:< tp22 = true with constraint c2 and GADT constraint g2 + * + * We keep: + * - (c1, g1) if c2 subsumes c1 and g2 subsumes g1 + * - (c2, g2) if c1 subsumes c2 and g1 subsumes g2 + * - neither constraint pair otherwise. + * + * Like [[sufficientEither]], this method is used to approximate a solution in one of the following cases: + * + * T1 & T2 <:< T3 + * T1 <:< T2 | T3 + * + * Unlike [[sufficientEither]], this method is used in GADTFlexible mode, when we are attempting to infer GADT + * constraints that necessarily follow from the subtyping relationship. For instance, if we have + * + * enum Expr[T] { + * case IntExpr(i: Int) extends Expr[Int] + * case StrExpr(s: String) extends Expr[String] + * } + * + * and `A` is an abstract type and we know that + * + * Expr[A] <: IntExpr | StrExpr + * + * (the case with &-type is analogous) then this may follow either from + * + * Expr[A] <: IntExpr or Expr[A] <: StrExpr + * + * Since we don't know which branch is true, we need to give up and not keep either constraint. OTOH, if one + * constraint pair is subsumed by the other, we know that it is necessary for both cases and therefore we can + * keep it. + * + * Like [[sufficientEither]], this method is not complete because sometimes, the necessary constraint + * is neither of the pairs. For instance, if + * + * g1 = { A = Int, B = String } + * g2 = { A = Int, B = Int } + * + * then the necessary constraint is { A = Int }, but correctly inferring that is, as far as we know, too expensive. + * + * Method name comes from the notion that we are keeping the constraint which is necessary to satisfy both + * subtyping relationships. + */ + private def necessaryEither(op1: => Boolean, op2: => Boolean): Boolean = { + val preConstraint = constraint + + val preGadt = ctx.gadt.fresh + // if GADTflexible mode is on, we expect to always have a ProperGadtConstraint + val pre = preGadt.asInstanceOf[ProperGadtConstraint] + if (op1) { + val leftConstraint = constraint + val leftGadt = ctx.gadt.fresh + constraint = preConstraint + ctx.gadt.restore(preGadt) + if (op2) { + if (pre.subsumes(leftGadt, ctx.gadt, preGadt) && subsumes(leftConstraint, constraint, preConstraint)) { + gadts.println(i"GADT CUT - prefer ${ctx.gadt} over $leftGadt") + constr.println(i"CUT - prefer $constraint over $leftConstraint") + true + } else if (pre.subsumes(ctx.gadt, leftGadt, preGadt) && subsumes(constraint, leftConstraint, preConstraint)) { + gadts.println(i"GADT CUT - prefer $leftGadt over ${ctx.gadt}") + constr.println(i"CUT - prefer $leftConstraint over $constraint") + constraint = leftConstraint + ctx.gadt.restore(leftGadt) + true + } else { + gadts.println(i"GADT CUT - no constraint is preferable, reverting to $preGadt") + constr.println(i"CUT - no constraint is preferable, reverting to $preConstraint") + constraint = preConstraint + ctx.gadt.restore(preGadt) + true + } + } else { + constraint = leftConstraint + ctx.gadt.restore(leftGadt) + true + } + } else op2 + } + /** Does type `tp1` have a member with name `name` whose normalized type is a subtype of * the normalized type of the refinement `tp2`? * Normalization is as follows: If `tp2` contains a skolem to its refinement type, @@ -1364,7 +1475,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] { * Test that the resulting bounds are still satisfiable. */ private def narrowGADTBounds(tr: NamedType, bound: Type, approx: ApproxState, isUpper: Boolean): Boolean = { - val boundImprecise = if (isUpper) approx.high else approx.low + val boundImprecise = approx.high || approx.low ctx.mode.is(Mode.GADTflexible) && !frozenConstraint && !boundImprecise && { val tparam = tr.symbol gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}") diff --git a/compiler/src/dotty/tools/dotc/core/TypeOps.scala b/compiler/src/dotty/tools/dotc/core/TypeOps.scala index f41a710dcde5..664016b99f38 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeOps.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeOps.scala @@ -387,7 +387,7 @@ trait TypeOps { this: Context => // TODO: Make standalone object. val bound1 = massage(bound) if (bound1 ne bound) { if (checkCtx eq ctx) checkCtx = ctx.fresh.setFreshGADTBounds - if (!checkCtx.gadt.contains(sym)) checkCtx.gadt.addEmptyBounds(sym) + if (!checkCtx.gadt.contains(sym)) checkCtx.gadt.addToConstraint(sym) checkCtx.gadt.addBound(sym, bound1, fromBelow) typr.println("install GADT bound $bound1 for when checking F-bounded $sym") } diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 49908bacdcd6..33efe736b3c8 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -3704,7 +3704,12 @@ object Types { // ----- Skolem types ----------------------------------------------- - /** A skolem type reference with underlying type `info` */ + /** A skolem type reference with underlying type `info`. + * + * For Dotty, a skolem type is a singleton type of some unknown value of type `info`. + * Note that care is needed when creating them, since not all types need to be inhabited. + * A skolem is equal to itself and no other type. + */ case class SkolemType(info: Type) extends UncachedProxyType with ValueType with SingletonType { override def underlying(implicit ctx: Context): Type = info def derivedSkolemType(info: Type)(implicit ctx: Context): SkolemType = @@ -3863,10 +3868,10 @@ object Types { def contextInfo(tp: Type): Type = tp match { case tp: TypeParamRef => val constraint = ctx.typerState.constraint - if (constraint.entry(tp).exists) constraint.fullBounds(tp) + if (constraint.entry(tp).exists) ctx.typeComparer.fullBounds(tp) else NoType case tp: TypeRef => - val bounds = ctx.gadt.bounds(tp.symbol) + val bounds = ctx.gadt.fullBounds(tp.symbol) if (bounds == null) NoType else bounds case tp: TypeVar => tp.underlying diff --git a/compiler/src/dotty/tools/dotc/printing/Formatting.scala b/compiler/src/dotty/tools/dotc/printing/Formatting.scala index 6b6a6845565a..408d369d84a4 100644 --- a/compiler/src/dotty/tools/dotc/printing/Formatting.scala +++ b/compiler/src/dotty/tools/dotc/printing/Formatting.scala @@ -170,7 +170,7 @@ object Formatting { case sym: Symbol => val info = if (ctx.gadt.contains(sym)) - sym.info & ctx.gadt.bounds(sym) + sym.info & ctx.gadt.fullBounds(sym) else sym.info s"is a ${ctx.printer.kindString(sym)}${sym.showExtendedLocation}${addendum("bounds", info)}" @@ -190,7 +190,7 @@ object Formatting { case param: TermParamRef => false case skolem: SkolemType => true case sym: Symbol => - ctx.gadt.contains(sym) && ctx.gadt.bounds(sym) != TypeBounds.empty + ctx.gadt.contains(sym) && ctx.gadt.fullBounds(sym) != TypeBounds.empty case _ => assert(false, "unreachable") false diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index 9efd80c1424c..ffa4aa9a68a2 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -208,7 +208,7 @@ class PlainPrinter(_ctx: Context) extends Printer { else { val constr = ctx.typerState.constraint val bounds = - if (constr.contains(tp)) constr.fullBounds(tp.origin)(ctx.addMode(Mode.Printing)) + if (constr.contains(tp)) ctx.addMode(Mode.Printing).typeComparer.fullBounds(tp.origin) else TypeBounds.empty if (bounds.isTypeAlias) toText(bounds.lo) ~ (Str("^") provided ctx.settings.YprintDebug.value) else if (ctx.settings.YshowVarBounds.value) "(" ~ toText(tp.origin) ~ "?" ~ toText(bounds) ~ ")" diff --git a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala index e4d71a68488a..375d11bedc30 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala @@ -401,9 +401,9 @@ class TreeChecker extends Phase with SymTransformer { } } - override def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type, gadtSyms: Set[Symbol])(implicit ctx: Context): CaseDef = { + override def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type)(implicit ctx: Context): CaseDef = { withPatSyms(tpd.patVars(tree.pat.asInstanceOf[tpd.Tree])) { - super.typedCase(tree, selType, pt, gadtSyms) + super.typedCase(tree, selType, pt) } } diff --git a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala index 03b67a7a91fe..bd07958cbb91 100644 --- a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala +++ b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala @@ -128,8 +128,8 @@ object ErrorReporting { case tp: TypeParamRef => constraint.entry(tp) match { case bounds: TypeBounds => - if (variance < 0) apply(constraint.fullUpperBound(tp)) - else if (variance > 0) apply(constraint.fullLowerBound(tp)) + if (variance < 0) apply(ctx.typeComparer.fullUpperBound(tp)) + else if (variance > 0) apply(ctx.typeComparer.fullLowerBound(tp)) else tp case NoType => tp case instType => apply(instType) diff --git a/compiler/src/dotty/tools/dotc/typer/Implicits.scala b/compiler/src/dotty/tools/dotc/typer/Implicits.scala index 1fa53e79582a..61647f5a1f8b 100644 --- a/compiler/src/dotty/tools/dotc/typer/Implicits.scala +++ b/compiler/src/dotty/tools/dotc/typer/Implicits.scala @@ -345,7 +345,7 @@ object Implicits { * @param level The level where the reference was found * @param tstate The typer state to be committed if this alternative is chosen */ - case class SearchSuccess(tree: Tree, ref: TermRef, level: Int)(val tstate: TyperState, val gstate: GADTMap) extends SearchResult with Showable + case class SearchSuccess(tree: Tree, ref: TermRef, level: Int)(val tstate: TyperState, val gstate: GadtConstraint) extends SearchResult with Showable /** A failed search */ case class SearchFailure(tree: Tree) extends SearchResult { @@ -397,21 +397,27 @@ object Implicits { * what was expected */ override def clarify(tp: Type)(implicit ctx: Context): Type = { - val map = new TypeMap { - def apply(t: Type): Type = t match { - case t: TypeParamRef => - constraint.entry(t) match { - case NoType => t - case bounds: TypeBounds => constraint.fullBounds(t) - case t1 => t1 - } - case t: TypeVar => - t.instanceOpt.orElse(apply(t.origin)) - case _ => - mapOver(t) + def replace(implicit ctx: Context): Type = { + val map = new TypeMap { + def apply(t: Type): Type = t match { + case t: TypeParamRef => + constraint.entry(t) match { + case NoType => t + case bounds: TypeBounds => ctx.typeComparer.fullBounds(t) + case t1 => t1 + } + case t: TypeVar => + t.instanceOpt.orElse(apply(t.origin)) + case _ => + mapOver(t) + } } + map(tp) } - map(tp) + + val ctx1 = ctx.fresh.setExploreTyperState() + ctx1.typerState.constraint = constraint + replace(ctx1) } def explanation(implicit ctx: Context): String = diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index eb0674802cf6..0902964edea7 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -184,8 +184,6 @@ object Inferencing { * * Invariant refinement can be assumed if `PatternType`'s class(es) are final or * case classes (because of `RefChecks#checkCaseClassInheritanceInvariant`). - * - * TODO: Update so that GADT symbols can be variant, and we special case final class types in patterns */ def constrainPatternType(tp: Type, pt: Type)(implicit ctx: Context): Boolean = { def refinementIsInvariant(tp: Type): Boolean = tp match { @@ -209,8 +207,9 @@ object Inferencing { } val widePt = if (ctx.scala2Mode || refinementIsInvariant(tp)) pt else widenVariantParams(pt) - trace(i"constraining pattern type $tp <:< $widePt", gadts, res => s"$res\n${ctx.gadt.debugBoundsDescription}") { - tp <:< widePt + val narrowTp = SkolemType(tp) + trace(i"constraining pattern type $narrowTp <:< $widePt", gadts, res => s"$res\n${ctx.gadt.debugBoundsDescription}") { + narrowTp <:< widePt } } @@ -263,7 +262,7 @@ object Inferencing { * 0 if unconstrained, or constraint is from below and above. */ private def instDirection(param: TypeParamRef)(implicit ctx: Context): Int = { - val constrained = ctx.typerState.constraint.fullBounds(param) + val constrained = ctx.typeComparer.fullBounds(param) val original = param.binder.paramInfos(param.paramNum) val cmp = ctx.typeComparer val approxBelow = @@ -298,17 +297,22 @@ object Inferencing { if (v == 1) tvar.instantiate(fromBelow = false) else if (v == -1) tvar.instantiate(fromBelow = true) else { - val bounds = ctx.typerState.constraint.fullBounds(tvar.origin) + val bounds = ctx.typeComparer.fullBounds(tvar.origin) if (bounds.hi <:< bounds.lo || bounds.hi.classSymbol.is(Final) || fromScala2x) tvar.instantiate(fromBelow = false) else { - val wildCard = ctx.newPatternBoundSymbol(UniqueName.fresh(tvar.origin.paramName), bounds, span) + // We do not add the created symbols to GADT constraint immediately, since they may have inter-dependencies. + // Instead, we simultaneously add them later on. + val wildCard = ctx.newPatternBoundSymbol(UniqueName.fresh(tvar.origin.paramName), bounds, span, addToGadt = false) tvar.instantiateWith(wildCard.typeRef) patternBound += wildCard } } } - patternBound.toList + val res = patternBound.toList + // We add the created symbols to GADT constraint here. + if (res.nonEmpty) ctx.gadt.addToConstraint(res) + res } type VarianceMap = SimpleIdentityMap[TypeVar, Integer] diff --git a/compiler/src/dotty/tools/dotc/typer/Inliner.scala b/compiler/src/dotty/tools/dotc/typer/Inliner.scala index 193a92f0000a..f1ff0070469c 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inliner.scala @@ -534,6 +534,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { /** An extractor for terms equivalent to `new C(args)`, returning the class `C`, * a list of bindings, and the arguments `args`. Can see inside blocks and Inlined nodes and can * follow a reference to an inline value binding to its right hand side. + * * @return optionally, a triple consisting of * - the class `C` * - the arguments `args` @@ -729,7 +730,6 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { def reduceInlineMatch(scrutinee: Tree, scrutType: Type, cases: List[CaseDef], typer: Typer)(implicit ctx: Context): MatchRedux = { val isImplicit = scrutinee.isEmpty - val gadtSyms = typer.gadtSyms(scrutType) /** Try to match pattern `pat` against scrutinee reference `scrut`. If successful add * bindings for variables bound in this pattern to `caseBindingMap`. @@ -821,11 +821,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { } def registerAsGadtSyms(typeBinds: TypeBindsMap)(implicit ctx: Context): Unit = - typeBinds.foreachBinding { case (sym, _) => - val TypeBounds(lo, hi) = sym.info.bounds - ctx.gadt.addBound(sym, lo, isUpper = false) - ctx.gadt.addBound(sym, hi, isUpper = true) - } + if (typeBinds.size > 0) ctx.gadt.addToConstraint(typeBinds.keys) pat match { case Typed(pat1, tpt) => @@ -920,7 +916,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { } if (!isImplicit) caseBindingMap += ((NoSymbol, scrutineeBinding)) - val gadtCtx = typer.gadtContext(gadtSyms).addMode(Mode.GADTflexible) + val gadtCtx = ctx.fresh.setFreshGADTBounds.addMode(Mode.GADTflexible) if (reducePattern(caseBindingMap, scrutineeSym.termRef, cdef.pat)(gadtCtx)) { val (caseBindings, from, to) = substBindings(caseBindingMap.toList, mutable.ListBuffer(), Nil, Nil) val guardOK = cdef.guard.isEmpty || { diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index bbf1541fe222..ba38ea9b75bc 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -1335,8 +1335,15 @@ class Namer { typer: Typer => // it would be erased to BoxedUnit. def dealiasIfUnit(tp: Type) = if (tp.isRef(defn.UnitClass)) defn.UnitType else tp - var rhsCtx = ctx.addMode(Mode.InferringReturnType) + var rhsCtx = ctx.fresh.addMode(Mode.InferringReturnType) if (sym.isInlineMethod) rhsCtx = rhsCtx.addMode(Mode.InlineableBody) + if (typeParams.nonEmpty) { + // we'll be typing an expression from a polymorphic definition's body, + // so we must allow constraining its type parameters + // compare with typedDefDef, see tests/pos/gadt-inference.scala + rhsCtx.setFreshGADTBounds + rhsCtx.gadt.addToConstraint(typeParams) + } def rhsType = typedAheadExpr(mdef.rhs, (inherited orElse rhsProto).widenExpr)(rhsCtx).tpe // Approximate a type `tp` with a type that does not contain skolem types. diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index f07563ec49f3..6faac5b0614b 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1047,37 +1047,8 @@ class Typer extends Namer assignType(cpy.Match(tree)(sel, cases1), sel, cases1) } - /** gadtSyms = "all type parameters of enclosing methods that appear - * non-variantly in the selector type" todo: should typevars - * which appear with variances +1 and -1 (in different - * places) be considered as well? - */ - def gadtSyms(selType: Type)(implicit ctx: Context): Set[Symbol] = trace(i"GADT syms of $selType", gadts) { - val accu = new TypeAccumulator[Set[Symbol]] { - def apply(tsyms: Set[Symbol], t: Type): Set[Symbol] = { - val tsyms1 = t match { - case tr: TypeRef if (tr.symbol is TypeParam) && tr.symbol.owner.isTerm && variance == 0 => - tsyms + tr.symbol - case _ => - tsyms - } - foldOver(tsyms1, t) - } - } - accu(Set.empty, selType) - } - - /** Context with fresh GADT bounds for all gadtSyms */ - def gadtContext(gadtSyms: Set[Symbol])(implicit ctx: Context): Context = { - val gadtCtx = ctx.fresh.setFreshGADTBounds - for (sym <- gadtSyms) - if (!gadtCtx.gadt.contains(sym)) gadtCtx.gadt.addEmptyBounds(sym) - gadtCtx - } - def typedCases(cases: List[untpd.CaseDef], selType: Type, pt: Type)(implicit ctx: Context): List[CaseDef] = { - val gadts = gadtSyms(selType) - cases.mapconserve(typedCase(_, selType, pt, gadts)) + cases.mapconserve(typedCase(_, selType, pt)) } /** - strip all instantiated TypeVars from pattern types. @@ -1096,7 +1067,7 @@ class Typer extends Namer if (ctx.scope.lookup(b.name) == NoSymbol) ctx.enter(sym) else ctx.error(new DuplicateBind(b, cdef), b.sourcePos) if (!ctx.isAfterTyper) { - val bounds = ctx.gadt.bounds(sym) + val bounds = ctx.gadt.fullBounds(sym) if (bounds != null) sym.info = bounds } b @@ -1105,9 +1076,9 @@ class Typer extends Namer } /** Type a case. */ - def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type, gadtSyms: Set[Symbol])(implicit ctx: Context): CaseDef = track("typedCase") { + def typedCase(tree: untpd.CaseDef, selType: Type, pt: Type)(implicit ctx: Context): CaseDef = track("typedCase") { val originalCtx = ctx - val gadtCtx = gadtContext(gadtSyms) + val gadtCtx: Context = ctx.fresh.setFreshGADTBounds def caseRest(pat: Tree)(implicit ctx: Context) = { val pat1 = indexPattern(tree).transform(pat) @@ -1132,8 +1103,6 @@ class Typer extends Namer def typedTypeCase(cdef: untpd.CaseDef, selType: Type, pt: Type)(implicit ctx: Context): CaseDef = { def caseRest(implicit ctx: Context) = { val pat1 = checkSimpleKinded(typedType(cdef.pat)(ctx.addMode(Mode.Pattern))) - if (!ctx.isAfterTyper) - constrainPatternType(pat1.tpe, selType)(ctx.addMode(Mode.GADTflexible)) val pat2 = indexPattern(cdef).transform(pat1) val body1 = typedType(cdef.body, pt) assignType(cpy.CaseDef(cdef)(pat2, EmptyTree, body1), pat2, body1) @@ -1537,19 +1506,28 @@ class Typer extends Namer if (sym is ImplicitOrImplied) checkImplicitConversionDefOK(sym) val tpt1 = checkSimpleKinded(typedType(tpt)) - var rhsCtx = ctx - if (sym.isConstructor && !sym.isPrimaryConstructor && tparams1.nonEmpty) { - // for secondary constructors we need a context that "knows" - // that their type parameters are aliases of the class type parameters. - // See pos/i941.scala - rhsCtx = ctx.fresh.setFreshGADTBounds - (tparams1, sym.owner.typeParams).zipped.foreach { (tdef, tparam) => - val tr = tparam.typeRef - rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = false) - rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = true) + val rhsCtx = ctx.fresh + if (tparams1.nonEmpty) { + rhsCtx.setFreshGADTBounds + if (!sym.isConstructor) { + // we're typing a polymorphic definition's body, + // so we allow constraining all of its type parameters + // constructors are an exception as we don't allow constraining type params of classes + rhsCtx.gadt.addToConstraint(tparams1.map(_.symbol)) + } else if (!sym.isPrimaryConstructor) { + // otherwise, for secondary constructors we need a context that "knows" + // that their type parameters are aliases of the class type parameters. + // See pos/i941.scala + rhsCtx.gadt.addToConstraint(tparams1.map(_.symbol)) + (tparams1, sym.owner.typeParams).zipped.foreach { (tdef, tparam) => + val tr = tparam.typeRef + rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = false) + rhsCtx.gadt.addBound(tdef.symbol, tr, isUpper = true) + } } } - if (sym.isInlineMethod) rhsCtx = rhsCtx.addMode(Mode.InlineableBody) + + if (sym.isInlineMethod) rhsCtx.addMode(Mode.InlineableBody) val rhs1 = typedExpr(ddef.rhs, tpt1.tpe.widenExpr)(rhsCtx) if (sym.isInlineMethod) { diff --git a/tests/neg/classOf.check b/tests/neg/classOf.check index 7c761b8af5be..b8416e4007e3 100644 --- a/tests/neg/classOf.check +++ b/tests/neg/classOf.check @@ -2,5 +2,7 @@ Test.C{I = String} is not a class type [116..117] in classOf.scala T is not a class type + +where: T is a type in method f2 with bounds <: String [72..73] in classOf.scala T is not a class type diff --git a/tests/neg/creative-gadt-constraints.scala b/tests/neg/creative-gadt-constraints.scala new file mode 100644 index 000000000000..a8869de5f75d --- /dev/null +++ b/tests/neg/creative-gadt-constraints.scala @@ -0,0 +1,66 @@ +object buffer { + object EssaInt { + def unapply(i: Int): Some[Int] = Some(i) + } + + case class Inv[T](t: T) + + enum EQ[A, B] { case Refl[T]() extends EQ[T, T] } + enum SUB[A, +B] { case Refl[T]() extends SUB[T, T] } // A <: B + + def test_eq1[A, B](eq: EQ[A, B], a: A, b: B): B = + Inv(a) match { case Inv(_: Int) => // a >: Sko(Int) + Inv(a) match { case Inv(_: Int) => // a >: Sko(Int) | Sko(Int) + eq match { case EQ.Refl() => // a = b + val success: A = b + val fail: A = 0 // error + 0 // error + } + } + } + + def test_eq2[A, B](eq: EQ[A, B], a: A, b: B): B = + Inv(a) match { case Inv(_: Int) => // a >: Sko(Int) + Inv(b) match { case Inv(_: Int) => // b >: Sko(Int) + eq match { case EQ.Refl() => // a = b + val success: A = b + val fail: A = 0 // error + 0 // error + } + } + } + + def test_sub1[A, B](sub: SUB[A, B], a: A, b: B): B = + Inv(b) match { case Inv(_: Int) => // b >: Sko(Int) + Inv(b) match { case Inv(_: Int) => // b >: Sko(Int) | Sko(Int) + sub match { case SUB.Refl() => // b >: a + val success: B = a + val fail: A = 0 // error + 0 // error + } + } + } + + def test_sub2[A, B](sub: SUB[A, B], a: A, b: B): B = + Inv(a) match { case Inv(_: Int) => // a >: Sko(Int) + Inv(b) match { case Inv(_: Int) => // b >: Sko(Int) | Sko(Int) + sub match { case SUB.Refl() => // b >: a + val success: B = a + val fail: A = 0 // error + 0 // error + } + } + } + + + def test_sub_eq[A, B, C](sub: SUB[A|B, C], eqA: EQ[A, 5], eqB: EQ[B, 6]): C = + sub match { case SUB.Refl() => // C >: A | B + eqA match { case EQ.Refl() => // A = 5 + eqB match { case EQ.Refl() => // B = 6 + val fail1: A = 0 // error + val fail2: B = 0 // error + 0 // error + } + } + } +} diff --git a/tests/neg/gadt-no-approx.scala b/tests/neg/gadt-no-approx.scala new file mode 100644 index 000000000000..eef0d82cba21 --- /dev/null +++ b/tests/neg/gadt-no-approx.scala @@ -0,0 +1,10 @@ +object `gadt-no-approx` { + def fo[U](u: U): U = + (0 : Int) match { + case _: u.type => + val i: Int = (??? : U) // error + // potentially could compile + // val i2: Int = u + u + } +} diff --git a/tests/neg/int-extractor.scala b/tests/neg/int-extractor.scala new file mode 100644 index 000000000000..8534c5a1bc00 --- /dev/null +++ b/tests/neg/int-extractor.scala @@ -0,0 +1,31 @@ +object Test { + object EssaInt { + def unapply(i: Int): Some[Int] = Some(i) + } + + def foo1[T](t: T): T = t match { + case EssaInt(_) => + 0 // error + } + + def foo2[T](t: T): T = t match { + case EssaInt(_) => t match { + case EssaInt(_) => + 0 // error + } + } + + case class Inv[T](t: T) + + def bar1[T](t: T): T = Inv(t) match { + case Inv(EssaInt(_)) => + 0 // error + } + + def bar2[T](t: T): T = t match { + case Inv(EssaInt(_)) => t match { + case Inv(EssaInt(_)) => + 0 // error + } + } +} diff --git a/tests/neg/invariant-gadt.scala b/tests/neg/invariant-gadt.scala new file mode 100644 index 000000000000..ac335f57743f --- /dev/null +++ b/tests/neg/invariant-gadt.scala @@ -0,0 +1,27 @@ +object `invariant-gadt` { + case class Invariant[T](value: T) + + def unsound0[T](t: T): T = Invariant(t) match { + case Invariant(_: Int) => + (0: Any) // error + } + + def unsound1[T](t: T): T = Invariant(t) match { + case Invariant(_: Int) => + 0 // error + } + + def unsound2[T](t: T): T = Invariant(t) match { + case Invariant(value) => value match { + case _: Int => + 0 // error + } + } + + def unsoundTwice[T](t: T): T = Invariant(t) match { + case Invariant(_: Int) => Invariant(t) match { + case Invariant(_: Int) => + 0 // error + } + } +} diff --git a/tests/neg/typeclass-derivation2.scala b/tests/neg/typeclass-derivation2.scala index 33c64494e9c5..ddb6517fb869 100644 --- a/tests/neg/typeclass-derivation2.scala +++ b/tests/neg/typeclass-derivation2.scala @@ -111,6 +111,13 @@ object TypeLevel { * It informs that type `T` has shape `S` and also implements runtime reflection on `T`. */ abstract class Shaped[T, S <: Shape] extends Reflected[T] + + // substitute for erasedValue that allows precise matching + final abstract class Type[-A, +B] + type Subtype[t] = Type[_, t] + type Supertype[t] = Type[t, _] + type Exactly[t] = Type[t, t] + erased def typeOf[T]: Type[T, T] = ??? } // An algebraic datatype @@ -203,7 +210,7 @@ trait Show[T] { def show(x: T): String } object Show { - import scala.compiletime.erasedValue + import scala.compiletime.{erasedValue, error} import TypeLevel._ inline def tryShow[T](x: T): String = implicit match { @@ -229,9 +236,14 @@ object Show { inline def showCases[T, Alts <: Tuple](r: Reflected[T], x: T): String = inline erasedValue[Alts] match { case _: (Shape.Case[alt, elems] *: alts1) => - x match { - case x: `alt` => showCase[T, elems](r, x) - case _ => showCases[T, alts1](r, x) + inline typeOf[alt] match { + case _: Subtype[T] => + x match { + case x: `alt` => showCase[T, elems](r, x) + case _ => showCases[T, alts1](r, x) + } + case _ => + error("invalid call to showCases: one of Alts is not a subtype of T") } case _: Unit => throw new MatchError(x) diff --git a/tests/pos/gadt-accumulatable.scala b/tests/pos/gadt-accumulatable.scala new file mode 100644 index 000000000000..ce4cf347538d --- /dev/null +++ b/tests/pos/gadt-accumulatable.scala @@ -0,0 +1,37 @@ +object `gadt-accumulatable` { + sealed abstract class Or[+G,+B] extends Product with Serializable + final case class Good[+G](g: G) extends Or[G,Nothing] + final case class Bad[+B](b: B) extends Or[Nothing,B] + + sealed trait Validation[+E] extends Product with Serializable + case object Pass extends Validation[Nothing] + case class Fail[E](error: E) extends Validation[E] + + sealed abstract class Every[+T] protected (underlying: Vector[T]) extends /*PartialFunction[Int, T] with*/ Product with Serializable + final case class One[+T](loneElement: T) extends Every[T](Vector(loneElement)) + final case class Many[+T](firstElement: T, secondElement: T, otherElements: T*) extends Every[T](firstElement +: secondElement +: Vector(otherElements: _*)) + + class Accumulatable[G, ERR, EVERY[_]] { } + + def convertOrToAccumulatable[G, ERR, EVERY[b] <: Every[b]](accumulatable: G Or EVERY[ERR]): Accumulatable[G, ERR, EVERY] = { + new Accumulatable[G, ERR, EVERY] { + def when[OTHERERR >: ERR](validations: (G => Validation[OTHERERR])*): G Or Every[OTHERERR] = { + accumulatable match { + case Good(g) => + val results = validations flatMap (_(g) match { case Fail(x) => val z: OTHERERR = x; Seq(x); case Pass => Seq.empty}) + results.length match { + case 0 => Good(g) + case 1 => Bad(One(results.head)) + case _ => + val first = results.head + val tail = results.tail + val second = tail.head + val rest = tail.tail + Bad(Many(first, second, rest: _*)) + } + case Bad(myBad) => Bad(myBad) + } + } + } + } +} diff --git a/tests/pos/gadt-all-params.scala b/tests/pos/gadt-all-params.scala new file mode 100644 index 000000000000..b5d7baecc283 --- /dev/null +++ b/tests/pos/gadt-all-params.scala @@ -0,0 +1,9 @@ +object `gadt-all-params` { + enum Expr[T] { + case UnitLit extends Expr[Unit] + } + + def foo[T >: TT <: TT, TT](e: Expr[T]): T = e match { + case Expr.UnitLit => () + } +} diff --git a/tests/pos/gadt-inference.scala b/tests/pos/gadt-inference.scala new file mode 100644 index 000000000000..e625e4823dc0 --- /dev/null +++ b/tests/pos/gadt-inference.scala @@ -0,0 +1,44 @@ +object `gadt-inference` { + enum Expr[T] { + case StrLit(s: String) extends Expr[String] + case IntLit(i: Int) extends Expr[Int] + } + import Expr._ + + def eval[T](e: Expr[T]) = + e match { + case StrLit(s) => + val a = (??? : T) : String + s : T + case IntLit(i) => + val a = (??? : T) : Int + i : T + } + + def nested[T](o: Option[Expr[T]]) = + o match { + case Some(e) => e match { + case StrLit(s) => + val a = (??? : T) : String + s : T + case IntLit(i) => + val a = (??? : T) : Int + i : T + } + case None => ??? + } + + def local[T](e: Expr[T]) = { + def eval[T](e: Expr[T]) = + e match { + case StrLit(s) => + val a = (??? : T) : String + s : T + case IntLit(i) => + val a = (??? : T) : Int + i : T + } + + eval(e) : T + } +} diff --git a/tests/pos/precise-pattern-type.scala b/tests/pos/precise-pattern-type.scala new file mode 100644 index 000000000000..856672fafbf2 --- /dev/null +++ b/tests/pos/precise-pattern-type.scala @@ -0,0 +1,16 @@ +object `precise-pattern-type` { + class Type { + def isType: Boolean = true + } + + class Tree[-T >: Null] { + def tpe: T @annotation.unchecked.uncheckedVariance = ??? + } + + case class Select[-T >: Null](qual: Tree[T]) extends Tree[T] + + def test[T <: Tree[Type]](tree: T) = tree match { + case Select(q) => + q.tpe.isType + } +} diff --git a/tests/run-macros/tasty-extractors-3.check b/tests/run-macros/tasty-extractors-3.check index 2e3b9f23e983..35c88a7598f5 100644 --- a/tests/run-macros/tasty-extractors-3.check +++ b/tests/run-macros/tasty-extractors-3.check @@ -10,6 +10,8 @@ Type.SymRef(IsClassDefSymbol(), Type.ThisType(Type.SymRef(IsPackageDe Type.SymRef(IsTypeDefSymbol(), NoPrefix()) +Type.SymRef(IsTypeDefSymbol(), NoPrefix()) + TypeBounds(Type.SymRef(IsClassDefSymbol(), Type.SymRef(IsPackageDefSymbol(), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<>), NoPrefix())))), Type.SymRef(IsClassDefSymbol(), Type.SymRef(IsPackageDefSymbol(), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<>), NoPrefix()))))) Type.SymRef(IsClassDefSymbol(), Type.SymRef(IsPackageDefSymbol(), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<>), NoPrefix())))) diff --git a/tests/run/typeclass-derivation2.scala b/tests/run/typeclass-derivation2.scala index 8ac7cec4487c..f8812b461d48 100644 --- a/tests/run/typeclass-derivation2.scala +++ b/tests/run/typeclass-derivation2.scala @@ -113,6 +113,13 @@ object TypeLevel { * It informs that type `T` has shape `S` and also implements runtime reflection on `T`. */ abstract class Shaped[T, S <: Shape] extends Reflected[T] + + // substitute for erasedValue that allows precise matching + final abstract class Type[-A, +B] + type Subtype[t] = Type[_, t] + type Supertype[t] = Type[t, _] + type Exactly[t] = Type[t, t] + erased def typeOf[T]: Type[T, T] = ??? } // An algebraic datatype @@ -217,7 +224,7 @@ trait Eq[T] { } object Eq { - import scala.compiletime.erasedValue + import scala.compiletime.{erasedValue, error} import TypeLevel._ inline def tryEql[T](x: T, y: T) = implicit match { @@ -239,8 +246,13 @@ object Eq { inline def eqlCases[T, Alts <: Tuple](xm: Mirror, ym: Mirror, ordinal: Int, n: Int): Boolean = inline erasedValue[Alts] match { case _: (Shape.Case[alt, elems] *: alts1) => - if (n == ordinal) eqlElems[elems](xm, ym, 0) - else eqlCases[T, alts1](xm, ym, ordinal, n + 1) + inline typeOf[alt] match { + case _: Subtype[T] => + if (n == ordinal) eqlElems[elems](xm, ym, 0) + else eqlCases[T, alts1](xm, ym, ordinal, n + 1) + case _ => + error("invalid call to eqlCases: one of Alts is not a subtype of T") + } case _: Unit => false } @@ -271,7 +283,7 @@ trait Pickler[T] { } object Pickler { - import scala.compiletime.{erasedValue, constValue} + import scala.compiletime.{erasedValue, constValue, error} import TypeLevel._ def nextInt(buf: mutable.ListBuffer[Int]): Int = try buf.head finally buf.trimStart(1) @@ -294,12 +306,17 @@ object Pickler { inline def pickleCases[T, Alts <: Tuple](r: Reflected[T], buf: mutable.ListBuffer[Int], x: T, n: Int): Unit = inline erasedValue[Alts] match { case _: (Shape.Case[alt, elems] *: alts1) => - x match { - case x: `alt` => - buf += n - pickleCase[T, elems](r, buf, x) + inline typeOf[alt] match { + case _: Subtype[T] => + x match { + case x: `alt` => + buf += n + pickleCase[T, elems](r, buf, x) + case _ => + pickleCases[T, alts1](r, buf, x, n + 1) + } case _ => - pickleCases[T, alts1](r, buf, x, n + 1) + error("invalid pickleCases call: one of Alts is not a subtype of T") } case _: Unit => } @@ -362,7 +379,7 @@ trait Show[T] { def show(x: T): String } object Show { - import scala.compiletime.erasedValue + import scala.compiletime.{erasedValue, error} import TypeLevel._ inline def tryShow[T](x: T): String = implicit match { @@ -388,9 +405,15 @@ object Show { inline def showCases[T, Alts <: Tuple](r: Reflected[T], x: T): String = inline erasedValue[Alts] match { case _: (Shape.Case[alt, elems] *: alts1) => - x match { - case x: `alt` => showCase[T, elems](r, x) - case _ => showCases[T, alts1](r, x) + inline typeOf[alt] match { + case _: Subtype[T] => + x match { + case x: `alt` => + showCase[T, elems](r, x) + case _ => showCases[T, alts1](r, x) + } + case _ => + error("invalid call to showCases: one of Alts is not a subtype of T") } case _: Unit => throw new MatchError(x)