diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index d8e1c5276ab6..f6fdd5eff28b 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -11,12 +11,24 @@ import collection.mutable import printing._ import scala.annotation.internal.sharable +import Denotations.{Denotation, SingleDenotation} +import SymDenotations.NoDenotation + +/** Types that represent a path. Can either be a TermRef or a SkolemType. */ +type PathType = TermRef | SkolemType /** Represents GADT constraints currently in scope */ sealed abstract class GadtConstraint extends Showable { /** Immediate bounds of `sym`. Does not contain lower/upper symbols (see [[fullBounds]]). */ def bounds(sym: Symbol)(using Context): TypeBounds | Null + /** Immediate bounds of a path-dependent type. + * This variant of bounds will ONLY try to retrieve path-dependent GADT bounds. */ + def bounds(path: PathType, sym: Symbol)(using Context): TypeBounds | Null + + /** Immediate bounds of path-dependent type tp. */ + def bounds(tp: TypeRef)(using Context): TypeBounds | Null + /** Full bounds of `sym`, including TypeRefs to other lower/upper symbols. * * @note this performs subtype checks between ordered symbols. @@ -24,9 +36,18 @@ sealed abstract class GadtConstraint extends Showable { */ def fullBounds(sym: Symbol)(using Context): TypeBounds | Null + /** Full bounds of path dependent type path.sym. */ + def fullBounds(path: PathType, sym: Symbol)(using Context): TypeBounds | Null + + /** Full bounds of path-dependent type tp. */ + def fullBounds(tp: TypeRef)(using Context): TypeBounds | Null + /** Is `sym1` ordered to be less than `sym2`? */ def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean + /** Is `tp1` ordered to be less than `tp2`? */ + def isLess(tp1: NamedType, tp2: NamedType)(using Context): Boolean + /** Add symbols to constraint, correctly handling inter-dependencies. * * @see [[ConstraintHandling.addToConstraint]] @@ -34,15 +55,63 @@ sealed abstract class GadtConstraint extends Showable { def addToConstraint(syms: List[Symbol])(using Context): Boolean def addToConstraint(sym: Symbol)(using Context): Boolean = addToConstraint(sym :: Nil) + /** Add path to constraint, registering all its abstract type members. */ + def addToConstraint(path: PathType)(using Context): Boolean + /** Further constrain a symbol already present in the constraint. */ def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean + /** Further constrain a path-dependent type already present in the constraint. */ + def addBound(p: PathType, sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean + + /** Record the aliasing relationship between two singleton types. */ + def recordPathAliasing(p: PathType, q: PathType)(using Context): Unit + + /** Check whether two paths are equivalent via path aliasing. */ + def isAliasingPath(p: PathType, q: PathType): Boolean + + /** Scrutinee path of the current pattern matching that is being typed. + * + * See `constrainTypeMembers` in `PatternTypeConstrainer`. + */ + def scrutineePath: TermRef | Null + + /** Reset scrutinee path to null. */ + def resetScrutineePath(): Unit + + /** Set the scrutinee path. */ + def withScrutineePath[T](path: TermRef | Null)(op: => T): T + + /** Supply the real pattern path. + * + * See `constrainTypeMembers` in `PatternTypeConstrainer`. + */ + def supplyPatternPath(path: TermRef)(using Context): Unit + + /** Create a skolem type for pattern and save it in the constraint handler. + * + * See `constrainTypeMembers` in `PatternTypeConstrainer`. + */ + def createPatternSkolem(pat: Type): SkolemType + /** Is the symbol registered in the constraint? * * @note this is true even if the symbol is constrained to be equal to another type, unlike [[Constraint.contains]]. */ def contains(sym: Symbol)(using Context): Boolean + /** Checks whether a path is registered. */ + def contains(path: PathType)(using Context): Boolean + + /** Checks whether a path-dependent type is registered in the handler. */ + def contains(path: PathType, sym: Symbol)(using Context): Boolean + + /** Checks whether a given path-dependent type is constrainable. */ + def isConstrainablePDT(path: PathType, sym: Symbol)(using Context): Boolean + + /** Get all type members registered in the constraint handler for this path. */ + def registeredTypeMembers(path: PathType): List[Symbol] + /** GADT constraint narrows bounds of at least one variable */ def isNarrowing: Boolean @@ -63,7 +132,12 @@ final class ProperGadtConstraint private( private var myConstraint: Constraint, private var mapping: SimpleIdentityMap[Symbol, TypeVar], private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol], - private var wasConstrained: Boolean + private var pathDepMapping: SimpleIdentityMap[PathType, SimpleIdentityMap[Symbol, TypeVar]], + private var pathDepReverseMapping: SimpleIdentityMap[TypeParamRef, TypeRef], + private var wasConstrained: Boolean, + private var myScrutineePath: TermRef | Null, + private var pathAliasingMapping: SimpleIdentityMap[PathType, PathType], + private var myPatternSkolem: SkolemType | Null, ) extends GadtConstraint with ConstraintHandling { import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} @@ -71,23 +145,265 @@ final class ProperGadtConstraint private( myConstraint = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentitySet.empty), mapping = SimpleIdentityMap.empty, reverseMapping = SimpleIdentityMap.empty, - wasConstrained = false + pathDepMapping = SimpleIdentityMap.empty, + pathDepReverseMapping = SimpleIdentityMap.empty, + wasConstrained = false, + myScrutineePath = null, + pathAliasingMapping = SimpleIdentityMap.empty, + myPatternSkolem = null, ) - /** Exposes ConstraintHandling.subsumes */ - def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(using Context): Boolean = { - def extractConstraint(g: GadtConstraint) = g match { - case s: ProperGadtConstraint => s.constraint - case EmptyGadtConstraint => OrderingConstraint.empty + /** Whether `left` subsumes `right`? + * + * `left` and `right` branches from `pre` with different constraint reasoning + * performed. During this, new path-dependent types could be registered in `left` + * and `right`. + * + * This function considers the cases where both `left` and `right` register a + * new path-depepdent type p.T. The problem is that in this case p.T will have + * two different internal representation in `left` and `right`: + * - p.T registered in `left` and has the internal representation T(param)$1 + * - p.T registered in `right` and has the internal representation T(param)$2 + * In `subsumes`, we have to recognize the fact that both T(param)$1 and T(param)$2 + * represents the same path-dependent type, to give the correct result. + */ + def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(using Context): Boolean = + def checkSubsumes(left: ProperGadtConstraint, right: ProperGadtConstraint, pre: ProperGadtConstraint): Boolean = { + def getRightToLeftMapping: Option[TypeParamRef => TypeParamRef] = { + val preParams = pre.constraint.domainParams.toSet + val mapping = { + var res: SimpleIdentityMap[TypeParamRef, TypeParamRef] = SimpleIdentityMap.empty + var hasNull: Boolean = false + + right.constraint.domainParams.foreach { p2 => + left.tvarOf(right.externalize(p2)) match { + case null => + hasNull = true + case tv: TypeVar => + res = res.updated(p2, tv.origin) + } + } + + if hasNull then None else Some(res) + } + + mapping map { mapping => + def func(p2: TypeParamRef) = + if pre.constraint.domainParams contains p2 then p2 + else mapping(p2).nn + func + } + } + + getRightToLeftMapping map { rightToLeft => + def checkParam(p2: TypeParamRef) = + val p1 = rightToLeft(p2) + left.constraint.entry(p1).exists + && right.constraint.upper(p1).map(rightToLeft).forall(left.constraint.isLess(p1, _)) + && isSubTypeWhenFrozen(left.constraint.nonParamBounds(p1), right.constraint.nonParamBounds(p2)) + def todos: Set[TypeParamRef] = + right.constraint.domainParams.toSet ++ pre.constraint.domainParams + + todos.forall(checkParam) + } getOrElse false + } + + (left, right, pre) match { + case (left: ProperGadtConstraint, right: ProperGadtConstraint, pre: ProperGadtConstraint) => + checkSubsumes(left, right, pre) + case (_, EmptyGadtConstraint, _) => true + case (EmptyGadtConstraint, _, _) => false + case (_, _, EmptyGadtConstraint) => false } - subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre)) - } override protected def legalBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Type = // GADT constraints never involve wildcards and are not propagated outside // the case where they're valid, so no approximating is needed. rawBound + /** Whether type members of the given path is constrainable? + * + * Package's and module's type members will not be constrained. + */ + private def isConstrainablePath(path: Type)(using Context): Boolean = path match + case path: TermRef + if !path.symbol.is(Flags.Package) + && !path.symbol.is(Flags.Module) + && !path.classSymbol.is(Flags.Package) + && !path.classSymbol.is(Flags.Module) + => true + case _: SkolemType + if !path.classSymbol.is(Flags.Package) + && !path.classSymbol.is(Flags.Module) + => true + case _ => false + + /** Check whether a type member is constrainable based on its denotation. + * + * A type member is considered constrainable if it is abstract, is not + * an opaque type, is not a class and is non-private. + */ + private def isConstrainableDenot(denot: Denotation)(using Context): Boolean = + denot.symbol.is(Flags.Deferred) + && !denot.symbol.is(Flags.Opaque) + && !denot.symbol.isClass + && !denot.isInstanceOf[NoDenotation.type] + + /** Find all constrainable type member denotations of the given type. + * + * Note that we return denotation here, since the bounds of the type member + * depend on the context (e.g. applied type parameters). + */ + private def constrainableTypeMembers(tp: Type)(using Context): List[Denotation] = + tp.typeMembers.toList filter { denot => + val denot1 = tp.nonPrivateMember(denot.name) + isConstrainableDenot(denot1) + } + + /** Check whether a type member of a path is constrainable. */ + private def isConstrainableTypeMember(path: PathType, sym: Symbol)(using Context): Boolean = + val mbr = path.nonPrivateMember(sym.name) + mbr.isInstanceOf[SingleDenotation] && { + val denot1 = path.nonPrivateMember(mbr.name) + isConstrainableDenot(denot1) + } + + /** Check whether a path-dependent type is constrainable. + * + * A path-dependent type p.A is constrainable if its path p and the type member A is + * constrainable. + */ + override def isConstrainablePDT(path: PathType, sym: Symbol)(using Context): Boolean = + isConstrainablePath(path) && isConstrainableTypeMember(path, sym) + + /** Get the internal type variable of the path-dependent type. Return null if it + * is not registered. + */ + private def tvarOf(path: PathType, sym: Symbol)(using Context): TypeVar | Null = + pathDepMapping(path) match + case null => null + case innerMapping => innerMapping.nn(sym) + + /** Try to retrieve type variable for some TypeRef. + * Both type parameters and path-dependent types are considered. + */ + private def tvarOf(tpr: TypeRef)(using Context): TypeVar | Null = + mapping(tpr.symbol) match + case null => + tpr match + case TypeRef(p: PathType, _) => tvarOf(p, tpr.symbol) + case _ => null + case tv: TypeVar => tv + + private def tvarOf(tp: Type)(using Context): TypeVar | Null = + tp match + case tp: TypeRef => tvarOf(tp) + case _ => null + + /** Register all constrainable path-dependent types addressed from the path. + * Returns whether the registration succeeds. It also checks whether the path + * itself is constrainable. + */ + override def addToConstraint(path: PathType)(using Context): Boolean = isConstrainablePath(path) && { + import NameKinds.DepParamName + val pathType = path.widen + val typeMembers = constrainableTypeMembers(path).filterNot(_.symbol eq NoSymbol) + + gadts.println { + val sb = new mutable.StringBuilder() + sb ++= i"* trying to add $path into constraint ...\n" + sb ++= i"** path.widen = $pathType\n" + sb ++= i"** type members =\n${debugShowTypeMembers(typeMembers)}\n" + sb.result() + } + + typeMembers.nonEmpty && { + val typeMemberSymbols: List[Symbol] = typeMembers map { x => x.symbol } + + val poly1 = PolyType(typeMembers map { d => DepParamName.fresh(d.name.toTypeName) })( + pt => typeMembers map { typeMember => + def substDependentSyms(tp: Type, isUpper: Boolean)(using Context): Type = { + def loop(tp: Type): Type = tp match + case tp @ AndType(tp1, tp2) if !isUpper => + tp.derivedAndOrType(loop(tp1), loop(tp2)) + case tp @ OrType(tp1, tp2) if isUpper => + tp.derivedOrType(loop(tp1), loop(tp2)) + case tp @ TypeRef(prefix, _) if prefix eq path => + typeMemberSymbols indexOf tp.symbol match + case -1 => tp + case idx => pt.paramRefs(idx) + case tp @ TypeRef(_: RecThis, _) => + typeMemberSymbols indexOf tp.symbol match + case -1 => tp + case idx => pt.paramRefs(idx) + case tp: TypeRef => + tvarOf(tp) match { + case tv: TypeVar => + val tp = stripInternalTypeVar(tv) + tp.match { + case tv1: TypeVar => stripTypeVarWhenDependent(tv1) + case tp => tp + } + case null => tp + } + case tp => tp + + loop(tp) + } + + val tb = typeMember.info.bounds + + def stripLazyRef(tp: Type): Type = tp match + case tp @ RefinedType(parent, name, tb) => + tp.derivedRefinedType(stripLazyRef(parent), name, stripLazyRef(tb)) + case tp: RecType => + tp.derivedRecType(stripLazyRef(tp.parent)) + case tb: TypeBounds => + tb.derivedTypeBounds(stripLazyRef(tb.lo), stripLazyRef(tb.hi)) + case ref: LazyRef => + ref.stripLazyRef + case _ => tp + + val tb1: TypeBounds = stripLazyRef(tb).asInstanceOf + + tb1.derivedTypeBounds( + lo = substDependentSyms(tb1.lo, isUpper = false), + hi = substDependentSyms(tb1.hi, isUpper = true) + ) + }, + pt => defn.AnyType + ) + + val tvars = typeMemberSymbols lazyZip poly1.paramRefs map { (sym, paramRef) => + val tv = TypeVar(paramRef, creatorState = null) + + val externalType = TypeRef(path, sym) + pathDepMapping = pathDepMapping.updated(path, { + val old: SimpleIdentityMap[Symbol, TypeVar] = pathDepMapping(path) match + case null => SimpleIdentityMap.empty + case m => m.nn + + old.updated(sym, tv) + }) + pathDepReverseMapping = pathDepReverseMapping.updated(tv.origin, externalType) + + tv + } + + addToConstraint(poly1, tvars) + .showing(i"added to constraint: [$poly1] $path, result = $result\n$debugBoundsDescription", gadts) + } + } + + private def debugShowTypeMembers(typeMembers: List[Denotation])(using Context): String = + val buf = new mutable.StringBuilder + buf ++= "{\n" + typeMembers foreach { denot => + buf ++= i" ${denot.symbol}: ${denot.info.bounds} [ ${denot.info.bounds.toString} ]\n" + } + buf ++= "}" + buf.result + override def addToConstraint(params: List[Symbol])(using Context): Boolean = { import NameKinds.DepParamName @@ -107,7 +423,12 @@ final class ProperGadtConstraint private( params.indexOf(tp.symbol) match { case -1 => mapping(tp.symbol) match { - case tv: TypeVar => tv.origin + case tv: TypeVar => + val tp = stripInternalTypeVar(tv) + tp.match { + case tv1: TypeVar => stripTypeVarWhenDependent(tv1) + case tp => tp + } case null => tp } case i => pt.paramRefs(i) @@ -137,24 +458,33 @@ final class ProperGadtConstraint private( .showing(i"added to constraint: [$poly1] $params%, %\n$debugBoundsDescription", gadts) } - override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = { - @annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match { - case tv: TypeVar => - val inst = constraint.instType(tv) - if (inst.exists) stripInternalTypeVar(inst) else tv - case _ => tp - } + @annotation.tailrec private def stripInternalTypeVar(tp: Type): Type = tp match { + case tv: TypeVar => + val inst = constraint.instType(tv) + if (inst.exists) stripInternalTypeVar(inst) else tv + case _ => tp + } + + private def stripTypeVarWhenDependent(tv: TypeVar): TypeParamRef | TypeVar = + val tpr = tv.origin + if constraint.contains(tpr) then + tpr + else + tv + end stripTypeVarWhenDependent + + private def addBoundForTvar(tvar: TypeVar, bound: Type, isUpper: Boolean, typeRepr: String)(using Context): Boolean = { - val symTvar: TypeVar = stripInternalTypeVar(tvarOrError(sym)) match { + val symTvar: TypeVar = stripInternalTypeVar(tvar) match { case tv: TypeVar => tv case inst => - gadts.println(i"instantiated: $sym -> $inst") + gadts.println(i"instantiated: $typeRepr -> $inst") return if (isUpper) isSub(inst, bound) else isSub(bound, inst) } val internalizedBound = bound match { - case nt: NamedType => - val ntTvar = mapping(nt.symbol) + case nt: TypeRef => + val ntTvar = tvarOf(nt) if (ntTvar != null) stripInternalTypeVar(ntTvar) else bound case _ => bound } @@ -171,16 +501,62 @@ final class ProperGadtConstraint private( gadts.println { val descr = if (isUpper) "upper" else "lower" val op = if (isUpper) "<:" else ">:" - i"adding $descr bound $sym $op $bound = $result" + i"adding $descr bound $typeRepr $op $bound = $result" } if constraint ne saved then wasConstrained = true result } + override def addBound(path: PathType, sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = { + val tvar = tvarOrError(path, sym) + val typeRepr = TypeRef(path, sym).show + addBoundForTvar(tvar, bound, isUpper, typeRepr) + } + + override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = { + val tvar = tvarOrError(sym) + val typeRepr = sym.typeRef.show + addBoundForTvar(tvar, bound, isUpper, typeRepr) + } + + private def lookupPath(p: PathType): PathType | Null = + def recur(p: PathType): PathType | Null = pathAliasingMapping(p) match + case null => null + case q: PathType if q eq p => q + case q: PathType => + recur(q) + + recur(p) + + override def recordPathAliasing(p: PathType, q: PathType)(using Context): Unit = + val pRep: PathType | Null = lookupPath(p) + val qRep: PathType | Null = lookupPath(q) + + val newRep = (pRep, qRep) match + case (null, null) => p + case (null, r: PathType) => r + case (r: PathType, null) => r + case (r1: PathType, r2: PathType) => + pathAliasingMapping = pathAliasingMapping.updated(r2, r1) + r1 + + pathAliasingMapping = pathAliasingMapping.updated(p, newRep) + pathAliasingMapping = pathAliasingMapping.updated(q, newRep) + + override def isAliasingPath(p: PathType, q: PathType): Boolean = + lookupPath(p) match + case null => false + case p0: PathType => lookupPath(q) match + case null => false + case q0: PathType => p0 eq q0 + override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = constraint.isLess(tvarOrError(sym1).origin, tvarOrError(sym2).origin) + override def isLess(tp1: NamedType, tp2: NamedType)(using Context): Boolean = + constraint.isLess(tvarOrError(tp1).origin, tvarOrError(tp2).origin) + override def fullBounds(sym: Symbol)(using Context): TypeBounds | Null = mapping(sym) match { case null => null @@ -190,6 +566,18 @@ final class ProperGadtConstraint private( // .ensuring(containsNoInternalTypes(_)) } + override def fullBounds(p: PathType, sym: Symbol)(using Context): TypeBounds | Null = + tvarOf(p, sym) match { + case null => null + case tv => fullBounds(tv.nn.origin) + } + + override def fullBounds(tp: TypeRef)(using Context): TypeBounds | Null = + tp match { + case TypeRef(p: PathType, _) => fullBounds(p, tp.symbol) + case _ => null + } + override def bounds(sym: Symbol)(using Context): TypeBounds | Null = mapping(sym) match { case null => null @@ -201,8 +589,37 @@ final class ProperGadtConstraint private( //.ensuring(containsNoInternalTypes(_)) } + override def bounds(path: PathType, sym: Symbol)(using Context): TypeBounds | Null = + tvarOf(path, sym) match { + case null => null + case tv: TypeVar => + def retrieveBounds: TypeBounds = + bounds(tv.origin) match { + case TypeAlias(tpr: TypeParamRef) if reverseMapping.contains(tpr) => + TypeAlias(reverseMapping(tpr).nn.typeRef) + case TypeAlias(tpr: TypeParamRef) if pathDepReverseMapping.contains(tpr) => + TypeAlias(pathDepReverseMapping(tpr).nn) + case tb => tb + } + retrieveBounds + } + + override def bounds(tp: TypeRef)(using Context): TypeBounds | Null = + tp match { + case TypeRef(p: PathType, _) => bounds(p, tp.symbol) + case _ => null + } + override def contains(sym: Symbol)(using Context): Boolean = mapping(sym) != null + override def contains(path: PathType)(using Context): Boolean = pathDepMapping(path) != null + + override def contains(path: PathType, sym: Symbol)(using Context): Boolean = pathDepMapping(path) match + case null => false + case innerMapping => innerMapping.nn(sym) != null + + override def registeredTypeMembers(path: PathType): List[Symbol] = pathDepMapping(path).nn.keys + def isNarrowing: Boolean = wasConstrained override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = { @@ -225,7 +642,12 @@ final class ProperGadtConstraint private( myConstraint, mapping, reverseMapping, - wasConstrained + pathDepMapping, + pathDepReverseMapping, + wasConstrained, + myScrutineePath, + pathAliasingMapping, + myPatternSkolem, ) def restore(other: GadtConstraint): Unit = other match { @@ -233,10 +655,63 @@ final class ProperGadtConstraint private( this.myConstraint = other.myConstraint this.mapping = other.mapping this.reverseMapping = other.reverseMapping + this.pathDepMapping = other.pathDepMapping + this.pathDepReverseMapping = other.pathDepReverseMapping this.wasConstrained = other.wasConstrained + this.myScrutineePath = other.myScrutineePath + this.pathAliasingMapping = other.pathAliasingMapping + this.myPatternSkolem = other.myPatternSkolem case _ => ; } + override def scrutineePath: TermRef | Null = myScrutineePath + + override def resetScrutineePath(): Unit = myScrutineePath = null + + override def supplyPatternPath(path: TermRef)(using Context): Unit = + if myPatternSkolem eq null then + () + else + def updateMappings() = + pathDepMapping(myPatternSkolem.nn) match { + case null => + case m: SimpleIdentityMap[Symbol, TypeVar] => + pathDepMapping = pathDepMapping.updated(path, m) + m foreachBinding { (sym, tvar) => + val tpr = tvar.origin + pathDepReverseMapping = pathDepReverseMapping.updated(tpr, TypeRef(path, sym)) + } + } + + def updateUnionFind() = + pathAliasingMapping(myPatternSkolem.nn) match { + case null => + case repr: PathType => + pathAliasingMapping = pathAliasingMapping.updated(path, repr) + } + + updateMappings() + updateUnionFind() + myPatternSkolem = null + end supplyPatternPath + + override def createPatternSkolem(pat: Type): SkolemType = + if myPatternSkolem ne null then + SkolemType(pat) + else + myPatternSkolem = SkolemType(pat) + myPatternSkolem.nn + end createPatternSkolem + + override def withScrutineePath[T](path: TermRef | Null)(op: => T): T = + val saved = this.myScrutineePath + this.myScrutineePath = path + + val result = op + + this.myScrutineePath = saved + result + // ---- Protected/internal ----------------------------------------------- override protected def constraint = myConstraint @@ -263,19 +738,34 @@ final class ProperGadtConstraint private( // ---- Private ---------------------------------------------------------- - private def externalize(tp: Type, theMap: TypeMap | Null = null)(using Context): Type = tp match + /** Externalize the internal TypeParamRefs in the given type. + * + * We declare the method as `protected` instead of `private`, because declaring it as + * private will break a pickling test. This is a temporary workaround. + */ + protected def externalize(tp: Type, theMap: TypeMap | Null = null)(using Context): Type = tp match case param: TypeParamRef => reverseMapping(param) match case sym: Symbol => sym.typeRef - case null => param + case null => + pathDepReverseMapping(param) match + case tp: TypeRef => tp + case null => param case tp: TypeAlias => tp.derivedAlias(externalize(tp.alias, theMap)) case tp => (if theMap == null then ExternalizeMap() else theMap).mapOver(tp) private class ExternalizeMap(using Context) extends TypeMap: def apply(tp: Type): Type = externalize(tp, this)(using mapCtx) + private def tvarOrError(sym: Symbol)(using Context): TypeVar = mapping(sym).ensuring(_ != null, i"not a constrainable symbol: $sym").uncheckedNN + private def tvarOrError(path: PathType, sym: Symbol)(using Context): TypeVar = + tvarOf(path, sym).ensuring(_ != null, i"not a constrainable type: $path.$sym").uncheckedNN + + private def tvarOrError(ntp: NamedType)(using Context): TypeVar = + tvarOf(ntp).ensuring(_ != null, i"not a constrainable type: $ntp").uncheckedNN + private def containsNoInternalTypes(tp: Type, theAcc: TypeAccumulator[Boolean] | Null = null)(using Context): Boolean = tp match { case tpr: TypeParamRef => !reverseMapping.contains(tpr) case tv: TypeVar => !reverseMapping.contains(tv.origin) @@ -296,10 +786,25 @@ final class ProperGadtConstraint private( override def debugBoundsDescription(using Context): String = { val sb = new mutable.StringBuilder sb ++= constraint.show - sb += '\n' + sb ++= "\nType parameter bounds:\n" mapping.foreachBinding { case (sym, _) => sb ++= i"$sym: ${fullBounds(sym)}\n" } + sb ++= "\nPath-dependent type bounds:\n" + pathDepMapping foreachBinding { case (path, m) => + m foreachBinding { case (sym, _) => + sb ++= i"$path.$sym: ${fullBounds(TypeRef(path, sym))}\n" + } + } + sb ++= "\nSingleton equalities:\n" + pathAliasingMapping foreachBinding { case (path, _) => + val repr = lookupPath(path) + repr match + case repr: PathType if repr ne path => + sb ++= i"$path.type: $repr.type\n" + case _ => + } + sb.result } } @@ -308,19 +813,49 @@ final class ProperGadtConstraint private( override def bounds(sym: Symbol)(using Context): TypeBounds | Null = null override def fullBounds(sym: Symbol)(using Context): TypeBounds | Null = null + override def bounds(p: PathType, sym: Symbol)(using Context): TypeBounds | Null = null + override def fullBounds(p: PathType, sym: Symbol)(using Context): TypeBounds | Null = null + override def bounds(tp: TypeRef)(using Context): TypeBounds | Null = null + override def fullBounds(tp: TypeRef)(using Context): TypeBounds | Null = null + override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = unsupported("EmptyGadtConstraint.isLess") + override def isLess(tp1: NamedType, tp2: NamedType)(using Context): Boolean = unsupported("EmptyGadtConstraint.isLess") override def isNarrowing: Boolean = false override def contains(sym: Symbol)(using Context) = false + override def contains(path: PathType)(using Context) = false + + override def contains(path: PathType, symbol: Symbol)(using Context) = false + + override def isConstrainablePDT(path: PathType, symbol: Symbol)(using Context) = false + + override def registeredTypeMembers(path: PathType): List[Symbol] = Nil + override def addToConstraint(params: List[Symbol])(using Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint") + override def addToConstraint(path: PathType)(using Context): Boolean = false override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound") + override def addBound(path: PathType, sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound") + + override def recordPathAliasing(p: PathType, q: PathType)(using Context) = () + + override def isAliasingPath(p: PathType, q: PathType) = false override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = unsupported("EmptyGadtConstraint.approximation") override def symbols: List[Symbol] = Nil + override def scrutineePath: TermRef | Null = null + + override def resetScrutineePath(): Unit = () + + override def withScrutineePath[T](path: TermRef | Null)(op: => T): T = op + + override def supplyPatternPath(path: TermRef)(using Context): Unit = () + + override def createPatternSkolem(pat: Type): SkolemType = unsupported("EmptyGadtConstraint.createPatternSkolem") + override def fresh = new ProperGadtConstraint override def restore(other: GadtConstraint): Unit = assert(!other.isNarrowing, "cannot restore a non-empty GADTMap") diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index c5f126580df5..9dba2c81fac0 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -10,6 +10,7 @@ import Contexts.ctx import dotty.tools.dotc.reporting.trace import config.Feature.migrateTo3 import config.Printers._ +import dotty.tools.dotc.core.SymDenotations.NoDenotation trait PatternTypeConstrainer { self: TypeComparer => @@ -73,129 +74,261 @@ trait PatternTypeConstrainer { self: TypeComparer => * scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables, * in which case the subtyping relationship "heals" the type. */ - def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false): Boolean = trace(i"constrainPatternType($scrut, $pat)", gadts) { - - def classesMayBeCompatible: Boolean = { - import Flags._ - val patCls = pat.classSymbol - val scrCls = scrut.classSymbol - !patCls.exists || !scrCls.exists || { - if (patCls.is(Final)) patCls.derivesFrom(scrCls) - else if (scrCls.is(Final)) scrCls.derivesFrom(patCls) - else if (!patCls.is(Flags.Trait) && !scrCls.is(Flags.Trait)) - patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls) - else true + def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false): Boolean = trace(i"constrainPatternType(${scrutRepr(scrut)}, $pat)", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") { + def recur(pat: Type, scrut: Type): Boolean = { + def classesMayBeCompatible: Boolean = { + import Flags._ + val patCls = pat.classSymbol + val scrCls = scrut.classSymbol + !patCls.exists || !scrCls.exists || { + if (patCls.is(Final)) patCls.derivesFrom(scrCls) + else if (scrCls.is(Final)) scrCls.derivesFrom(patCls) + else if (!patCls.is(Flags.Trait) && !scrCls.is(Flags.Trait)) + patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls) + else true + } } - } - def stripRefinement(tp: Type): Type = tp match { - case tp: RefinedOrRecType => stripRefinement(tp.parent) - case tp => tp - } + def stripRefinement(tp: Type): Type = tp match { + case tp: RefinedOrRecType => stripRefinement(tp.parent) + case tp => tp + } - def tryConstrainSimplePatternType(pat: Type, scrut: Type) = { - val patCls = pat.classSymbol - val scrCls = scrut.classSymbol - patCls.exists && scrCls.exists - && (patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls)) - && constrainSimplePatternType(pat, scrut, forceInvariantRefinement) - } + def tryConstrainSimplePatternType(pat: Type, scrut: Type) = { + val patCls = pat.classSymbol + val scrCls = scrut.classSymbol + patCls.exists && scrCls.exists + && (patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls)) + && constrainSimplePatternType(pat, scrut, forceInvariantRefinement) + } + + def constrainUpcasted(scrut: Type): Boolean = trace(i"constrainUpcasted($scrut)", gadts) { + // Fold a list of types into an AndType + def buildAndType(xs: List[Type]): Type = { + @annotation.tailrec def recur(acc: Type, rem: List[Type]): Type = rem match { + case Nil => acc + case x :: rem => recur(AndType(acc, x), rem) + } + xs match { + case Nil => NoType + case x :: xs => recur(x, xs) + } + } + + scrut match { + case scrut: TypeRef if scrut.symbol.isClass => + // consider all parents + val parents = scrut.parents + val andType = buildAndType(parents) + !andType.exists || recur(pat, andType) + case scrut @ AppliedType(tycon: TypeRef, _) if tycon.symbol.isClass => + val patCls = pat.classSymbol + // find all shared parents in the inheritance hierarchy between pat and scrut + def allParentsSharedWithPat(tp: Type, tpClassSym: ClassSymbol): List[Symbol] = { + var parents = tpClassSym.info.parents + if parents.nonEmpty && parents.head.classSymbol == defn.ObjectClass then + parents = parents.tail + parents flatMap { tp => + val sym = tp.classSymbol.asClass + if patCls.derivesFrom(sym) then List(sym) + else allParentsSharedWithPat(tp, sym) + } + } + val allSyms = allParentsSharedWithPat(tycon, tycon.symbol.asClass) + val baseClasses = allSyms map scrut.baseType + val andType = buildAndType(baseClasses) + !andType.exists || recur(pat, andType) + case _ => + def tryGadtBounds = scrut match { + case scrut: TypeRef => + ctx.gadt.bounds(scrut.symbol) match { + case tb: TypeBounds => + val hi = tb.hi + recur(pat, hi) + case null => true + } + case _ => true + } + + def trySuperType = + val upcasted: Type = scrut match { + case scrut: TypeProxy => + scrut.superType + case _ => NoType + } + if (upcasted.exists) + tryConstrainSimplePatternType(pat, upcasted) || constrainUpcasted(upcasted) + else true - def constrainUpcasted(scrut: Type): Boolean = trace(i"constrainUpcasted($scrut)", gadts) { - // Fold a list of types into an AndType - def buildAndType(xs: List[Type]): Type = { - @annotation.tailrec def recur(acc: Type, rem: List[Type]): Type = rem match { - case Nil => acc - case x :: rem => recur(AndType(acc, x), rem) + tryGadtBounds && trySuperType } - xs match { - case Nil => NoType - case x :: xs => recur(x, xs) + } + + def dealiasDropNonmoduleRefs(tp: Type) = tp.dealias match { + case tp: TermRef => + // we drop TermRefs that don't have a class symbol, as they can't + // meaningfully participate in GADT reasoning and just get in the way. + // Their info could, for an example, be an AndType. One example where + // this is important is an enum case that extends its parent and an + // additional trait - argument-less enum cases desugar to vals. + // See run/enum-Tree.scala. + if tp.classSymbol.exists then tp else tp.info + case tp => tp + } + + dealiasDropNonmoduleRefs(scrut) match { + case OrType(scrut1, scrut2) => + either(recur(pat, scrut1), recur(pat, scrut2)) + case AndType(scrut1, scrut2) => + recur(pat, scrut1) && recur(pat, scrut2) + case scrut: RefinedOrRecType => + recur(pat, stripRefinement(scrut)) + case scrut => dealiasDropNonmoduleRefs(pat) match { + case OrType(pat1, pat2) => + either(recur(pat1, scrut), recur(pat2, scrut)) + case AndType(pat1, pat2) => + recur(pat1, scrut) && recur(pat2, scrut) + case pat: RefinedOrRecType => + recur(stripRefinement(pat), scrut) + case pat => + tryConstrainSimplePatternType(pat, scrut) + || classesMayBeCompatible && constrainUpcasted(scrut) } } + } + + /** Reconstruct subtype constraints for type members of the scrutinee and the pattern. + * + * To inference SR constraints for the type members from the scrutinee `p` and the pattern `q`, + * we first find all the abstract type members of `p`: A₁, A₂, ⋯, Aₖ. + * If these path-dependent types are not registered in the handler, we will register them. + * + * Then, for each Aᵢ, if `q` also has a type member labaled `Aᵢ`, we inference SR constraints by calling + * TypeComparer on the relation `p.Aᵢ <:< q.Aᵢ`. + * We derive SR constraints for type members of the pattern path `q` similarly. + * + * Specially, if for some `Aᵢ`, `p.Aᵢ` is abstract while `q.Aᵢ` is not, we will extract constraints + * for both directions of the subtype relations (i.e. both `p.Aᵢ <:< q.Aᵢ` and `q.Aᵢ <:< p.Aᵢ`). + * + * How we find out and handle the path (`TermRef`) of the scrutinee and pattern: + * + * - The path of scrutinee is not directly available in `constrainPatternType`, since the scrutinee type passed to this function is widened. + * To have access to the scrutinee path here, we save the scrutinee path in `Typer.typedCase` with `GadtConstraint.withScrutineePath`, + * and the scrutinee path will be accessible as `ctx.gadt.scrutineePath`. + * Note that we have to reset the saved scrutinee path to `null` after using by calling `ctx.gadt.resetScrutineePath()`. + * This is because `constrainPatternType` may be called multiple times for one nested pattern. For example: + * + * e match + * case A(B(a), C(b)) => // ... + * + * We have to reset the scrutinee path after constraining `e` against the top level pattern `A(...)`. + * + * - The path of pattern is not available when calling the function, and the symbol of the pattern will only be created after GADT reasoning. + * Therefore, we will create `SkolemType` acting as a placeholder for the pattern path, and substitute it with the real pattern path + * when it is available later in the typer. We call `ctx.gadt.supplyPatternPath` to do the substitution. + */ + def constrainTypeMembers = trace(i"constrainTypeMembers(${scrutRepr(scrut)}, $pat)", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") { + import NameKinds.DepParamName + val realScrutineePath = ctx.gadt.scrutineePath + + // We reset scrutinee path so that the path will only be used at top level. + ctx.gadt.resetScrutineePath() + + val saved = state.nn.constraint + val savedGadt = ctx.gadt.fresh + + val scrutineePath: TermRef | SkolemType = realScrutineePath match + case null => SkolemType(scrut) + case _ => realScrutineePath + val patternPath: SkolemType = ctx.gadt.createPatternSkolem(pat) - scrut match { - case scrut: TypeRef if scrut.symbol.isClass => - // consider all parents - val parents = scrut.parents - val andType = buildAndType(parents) - !andType.exists || constrainPatternType(pat, andType) - case scrut @ AppliedType(tycon: TypeRef, _) if tycon.symbol.isClass => - val patCls = pat.classSymbol - // find all shared parents in the inheritance hierarchy between pat and scrut - def allParentsSharedWithPat(tp: Type, tpClassSym: ClassSymbol): List[Symbol] = { - var parents = tpClassSym.info.parents - if parents.nonEmpty && parents.head.classSymbol == defn.ObjectClass then - parents = parents.tail - parents flatMap { tp => - val sym = tp.classSymbol.asClass - if patCls.derivesFrom(sym) then List(sym) - else allParentsSharedWithPat(tp, sym) + val registerScrutinee = ctx.gadt.contains(scrutineePath) || ctx.gadt.addToConstraint(scrutineePath) + // Pattern path is a freshly-created skolem, + // so it will always be un-registered at this point + val registerPattern = ctx.gadt.addToConstraint(patternPath) + + /** Reconstruct subtype constraints for a path `p`, given that `p` and `q` + are cohabitated. + + When do SR for each type member (denoted as `T`) of path p, there are the + following three cases: + + (1) q does not have number T. In this case we should simply return true. + + (2) q.T is a registered type member. We do SR on p.T <:< q.T, but not + q.T <:< p.T, since if q.T is also registered then + `constrainTypeMember(q, p, T)` will also be called, during which + q.T <:< p.T will be handled. + + (3) q.T is unregistered. We will do SR on p.T <:< q.T and q.T <:< p.T. + */ + def reconstructSubType(p: PathType, q: PathType) = + def processMember(sym: Symbol): Boolean = + q.member(sym.name).isInstanceOf[NoDenotation.type] || { + val pType = TypeRef(p, sym) + val qType = TypeRef(q, sym) + + trace(i"constrainTypeMember $pType >:< $qType", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") { + if ctx.gadt.contains(q, sym) then + isSubType(pType, qType) + else + isSubType(pType, qType) && isSubType(qType, pType) } } - val allSyms = allParentsSharedWithPat(tycon, tycon.symbol.asClass) - val baseClasses = allSyms map scrut.baseType - val andType = buildAndType(baseClasses) - !andType.exists || constrainPatternType(pat, andType) - case _ => - def tryGadtBounds = scrut match { - case scrut: TypeRef => - ctx.gadt.bounds(scrut.symbol) match { - case tb: TypeBounds => - val hi = tb.hi - constrainPatternType(pat, hi) - case null => true - } - case _ => true - } - def trySuperType = - val upcasted: Type = scrut match { - case scrut: TypeProxy => - scrut.superType - case _ => NoType - } - if (upcasted.exists) - tryConstrainSimplePatternType(pat, upcasted) || constrainUpcasted(upcasted) - else true + ctx.gadt.registeredTypeMembers(p) forall { sym => processMember(sym) } - tryGadtBounds && trySuperType + /** Reconstruct subtype from the cohabitation between the scrutinee and the + pattern. */ + def constrainPattern: Boolean = { + ctx.gadt.recordPathAliasing(scrutineePath, patternPath) + + (!registerPattern || reconstructSubType(patternPath, scrutineePath)) + && (!registerScrutinee || reconstructSubType(scrutineePath, patternPath)) } - } - def dealiasDropNonmoduleRefs(tp: Type) = tp.dealias match { - case tp: TermRef => - // we drop TermRefs that don't have a class symbol, as they can't - // meaningfully participate in GADT reasoning and just get in the way. - // Their info could, for an example, be an AndType. One example where - // this is important is an enum case that extends its parent and an - // additional trait - argument-less enum cases desugar to vals. - // See run/enum-Tree.scala. - if tp.classSymbol.exists then tp else tp.info - case tp => tp - } + /** Reconstruct subtype when the pattern is an alias to another path. - dealiasDropNonmoduleRefs(scrut) match { - case OrType(scrut1, scrut2) => - either(constrainPatternType(pat, scrut1), constrainPatternType(pat, scrut2)) - case AndType(scrut1, scrut2) => - constrainPatternType(pat, scrut1) && constrainPatternType(pat, scrut2) - case scrut: RefinedOrRecType => - constrainPatternType(pat, stripRefinement(scrut)) - case scrut => dealiasDropNonmoduleRefs(pat) match { - case OrType(pat1, pat2) => - either(constrainPatternType(pat1, scrut), constrainPatternType(pat2, scrut)) - case AndType(pat1, pat2) => - constrainPatternType(pat1, scrut) && constrainPatternType(pat2, scrut) - case pat: RefinedOrRecType => - constrainPatternType(stripRefinement(pat), scrut) - case pat => - tryConstrainSimplePatternType(pat, scrut) - || classesMayBeCompatible && constrainUpcasted(scrut) + For example, consider the following pattern match: + + p match + case q: r.type => + + Then we can also reconstruct subtype from the cohabitation of p and r. + */ + def maybeConstrainPatternAlias: Boolean = pat match { + case ptPath: TermRef => + val registerPtPath = ctx.gadt.contains(ptPath) || ctx.gadt.addToConstraint(ptPath) + + val result = + (!registerPtPath || reconstructSubType(ptPath, scrutineePath)) + && (!registerScrutinee || reconstructSubType(scrutineePath, ptPath)) + + ctx.gadt.recordPathAliasing(scrutineePath, ptPath) + + result + case _ => + true } + + val res = constrainPattern && maybeConstrainPatternAlias + + if !res then + constraint = saved + ctx.gadt.restore(savedGadt) + + res } + + recur(pat, scrut) && constrainTypeMembers } + /** Show the scrutinee. Will show the path if available. */ + private def scrutRepr(scrut: Type): String = + ctx.gadt.scrutineePath match + case null => scrut.show + case p: PathType => p.show + /** Constrain "simple" patterns (see `constrainPatternType`). * * This function expects to receive two types (scrutinee and pattern), both diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index adce363dc3f4..3d66057eefa7 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -118,7 +118,22 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling private def isBottom(tp: Type) = tp.widen.isRef(NothingClass) protected def gadtBounds(sym: Symbol)(using Context) = ctx.gadt.bounds(sym) - protected def gadtAddBound(sym: Symbol, b: Type, isUpper: Boolean): Boolean = ctx.gadt.addBound(sym, b, isUpper) + + /** This variant of gadtBounds works for all named types. + * It queries GADT bounds for both type parameters and path-dependent types. + */ + protected def gadtBounds(tp: NamedType)(using Context) = + ctx.gadt.bounds(tp.symbol) match + case null => + tp match + case TypeRef(p: PathType, _) => ctx.gadt.bounds(p, tp.symbol) + case _ => null + case tb => tb + + protected def gadtAddBound(sym: Symbol, b: Type, isUpper: Boolean): Boolean = ctx.gadt.addBound(sym, b, isUpper = isUpper) + + protected def gadtAddLowerBound(path: PathType, sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(path, sym, b, isUpper = false) + protected def gadtAddUpperBound(path: PathType, sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(path, sym, b, isUpper = true) protected def typeVarInstance(tvar: TypeVar)(using Context): Type = tvar.underlying @@ -193,6 +208,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling val bounds = gadtBounds(sym) bounds != null && op(bounds) + extension (tp: NamedType) + private inline def onGadtBounds(inline op: TypeBounds => Boolean): Boolean = + val bounds = gadtBounds(tp) + bounds != null && op(bounds) + private inline def comparingTypeLambdas(tl1: TypeLambda, tl2: TypeLambda)(op: => Boolean): Boolean = val saved = comparedTypeLambdas comparedTypeLambdas += tl1 @@ -426,7 +446,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case _ => false } || isSubTypeWhenFrozen(bounds(tp1).hi.boxed, tp2) || { - if (canConstrain(tp1) && !approx.high) + if canConstrain(tp1) && isPreciseBound(fromBelow = false) then addConstraint(tp1, tp2, fromBelow = false) && flagNothingBound else thirdTry } @@ -540,19 +560,28 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling def thirdTryNamed(tp2: NamedType): Boolean = tp2.info match { case info2: TypeBounds => + /** Checks whether tp1 is registered. + * Both type parameters and path-dependent types are considered. + */ + def tpRegistered(tp: TypeRef) = ctx.gadt.contains(tp.symbol) || { + tp match + case tp @ TypeRef(p: PathType, _) => ctx.gadt.contains(p, tp.symbol) + case _ => false + } + def compareGADT: Boolean = - tp2.symbol.onGadtBounds(gbounds2 => - isSubTypeWhenFrozen(tp1, gbounds2.lo) - || tp1.match - case tp1: NamedType if ctx.gadt.contains(tp1.symbol) => - // Note: since we approximate constrained types only with their non-param bounds, - // we need to manually handle the case when we're comparing two constrained types, - // one of which is constrained to be a subtype of another. - // We do not need similar code in fourthTry, since we only need to care about - // comparing two constrained types, and that case will be handled here first. - ctx.gadt.isLess(tp1.symbol, tp2.symbol) && GADTusage(tp1.symbol) && GADTusage(tp2.symbol) - case _ => false - || narrowGADTBounds(tp2, tp1, approx, isUpper = false)) + { tp2.onGadtBounds(gbounds2 => + { isSubTypeWhenFrozen(tp1, gbounds2.lo) } + || tp1.match + case tp1: TypeRef if tpRegistered(tp1) => + // Note: since we approximate constrained types only with their non-param bounds, + // we need to manually handle the case when we're comparing two constrained types, + // one of which is constrained to be a subtype of another. + // We do not need similar code in fourthTry, since we only need to care about + // comparing two constrained types, and that case will be handled here first. + { ctx.gadt.isLess(tp1, tp2) } && GADTusage(tp1.symbol) && GADTusage(tp2.symbol) + case _ => false) + || narrowGADTBounds(tp2, tp1, approx, isUpper = false) } && (isBottom(tp1) || GADTusage(tp2.symbol)) isSubApproxHi(tp1, info2.lo.boxedIfTypeParam(tp2.symbol)) && (trustBounds || isSubApproxHi(tp1, info2.hi)) @@ -561,6 +590,13 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling || fourthTry case _ => + def compareSingletonGADT: Boolean = + (tp1, tp2) match { + case (tp1: TermRef, tp2: TermRef) => + ctx.gadt.isAliasingPath(tp1, tp2) && { GADTused = true; true } + case _ => false + } + val cls2 = tp2.symbol if (cls2.isClass) if (cls2.typeParams.isEmpty) { @@ -581,9 +617,18 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling } else if tp1.isLambdaSub && !tp1.isAnyKind then return recur(tp1, EtaExpansion(tp2)) + + if compareSingletonGADT then return true + fourthTry } + def isPreciseBound(fromBelow: Boolean): Boolean = + if ctx.mode.is(Mode.GadtConstraintInference) then + !(approx.low || approx.high) + else + if fromBelow then !approx.low else !approx.high + def compareTypeParamRef(tp2: TypeParamRef): Boolean = assumedTrue(tp2) || { val alwaysTrue = @@ -597,7 +642,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling if (frozenConstraint) recur(tp1, bounds(tp2).lo.boxed) else isSubTypeWhenFrozen(tp1, tp2) alwaysTrue || { - if (canConstrain(tp2) && !approx.low) + if canConstrain(tp2) && isPreciseBound(fromBelow = true) then addConstraint(tp2, tp1.widenExpr, fromBelow = true) else fourthTry } @@ -858,10 +903,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling tp1.info match { case info1 @ TypeBounds(lo1, hi1) => def compareGADT = - tp1.symbol.onGadtBounds(gbounds1 => - isSubTypeWhenFrozen(gbounds1.hi, tp2) - || narrowGADTBounds(tp1, tp2, approx, isUpper = true)) - && (tp2.isAny || GADTusage(tp1.symbol)) + { tp1.onGadtBounds(gbounds1 => + isSubTypeWhenFrozen(gbounds1.hi, tp2)) + || narrowGADTBounds(tp1, tp2, approx, isUpper = true) + } && (tp2.isAny || GADTusage(tp1.symbol)) (!caseLambda.exists || canWidenAbstract) && isSubType(hi1.boxedIfTypeParam(tp1.symbol), tp2, approx.addLow) && (trustBounds || isSubType(lo1, tp2, approx.addLow)) @@ -1179,10 +1224,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling var touchedGADTs = false var gadtIsInstantiated = false - extension (sym: Symbol) + extension (tp: TypeRef) inline def byGadtBounds(inline op: TypeBounds => Boolean): Boolean = touchedGADTs = true - sym.onGadtBounds( + tp.onGadtBounds( b => op(b) && { gadtIsInstantiated = b.isInstanceOf[TypeAlias]; true }) def byGadtOrdering: Boolean = @@ -1190,11 +1235,19 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling && ctx.gadt.contains(tycon2sym) && ctx.gadt.isLess(tycon1sym, tycon2sym) + def byPathDepGadtOrdering: Boolean = + (tycon1, tycon2) match + case (TypeRef(p1: PathType, _), TypeRef(p2: PathType, _)) => + ctx.gadt.contains(p1, tycon1sym) + && ctx.gadt.contains(p2, tycon2sym) + && ctx.gadt.isLess(tycon1, tycon2) + case _ => false + val res = ( tycon1sym == tycon2sym && isSubPrefix(tycon1.prefix, tycon2.prefix) - || tycon1sym.byGadtBounds(b => isSubTypeWhenFrozen(b.hi, tycon2)) - || tycon2sym.byGadtBounds(b => isSubTypeWhenFrozen(tycon1, b.lo)) - || byGadtOrdering + || tycon1.byGadtBounds(b => isSubTypeWhenFrozen(b.hi, tycon2)) + || tycon2.byGadtBounds(b => isSubTypeWhenFrozen(tycon1, b.lo)) + || byGadtOrdering || byPathDepGadtOrdering ) && { // There are two cases in which we can assume injectivity. // First we check if either sym is a class. @@ -1843,10 +1896,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling ctx.gadt.restore(preGadt) if op2 then if allSubsumes(op1Gadt, ctx.gadt, op1Constraint, constraint) then - gadts.println(i"GADT CUT - prefer ${ctx.gadt} over $op1Gadt") + gadts.println(i"GADT CUT - prefer op2 ${ctx.gadt} over $op1Gadt") constr.println(i"CUT - prefer $constraint over $op1Constraint") else if allSubsumes(ctx.gadt, op1Gadt, constraint, op1Constraint) then - gadts.println(i"GADT CUT - prefer $op1Gadt over ${ctx.gadt}") + gadts.println(i"GADT CUT - prefer op1 $op1Gadt over ${ctx.gadt}") constr.println(i"CUT - prefer $op1Constraint over $constraint") constraint = op1Constraint ctx.gadt.restore(op1Gadt) @@ -2017,6 +2070,13 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case _ => proto.isMatchedBy(tp, keepConstraint = true) } + private def rollbackGadtUnless(op: => Boolean): Boolean = + val savedGadt = ctx.gadt.fresh + val result = op + if !result then ctx.gadt.restore(savedGadt) + result + end rollbackGadtUnless + /** Narrow gadt.bounds for the type parameter referenced by `tr` to include * `bound` as an upper or lower bound (which depends on `isUpper`). * Test that the resulting bounds are still satisfiable. @@ -2024,14 +2084,66 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling private def narrowGADTBounds(tr: NamedType, bound: Type, approx: ApproxState, isUpper: Boolean): Boolean = { val boundImprecise = approx.high || approx.low ctx.mode.is(Mode.GadtConstraintInference) && !frozenGadt && !frozenConstraint && !boundImprecise && { - val tparam = tr.symbol - gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}") - if (bound.isRef(tparam)) false - else - val savedGadt = ctx.gadt.fresh - val success = gadtAddBound(tparam, bound, isUpper) - if !success then ctx.gadt.restore(savedGadt) - success + def tryRegisterBound: Boolean = bound.match { + case tr @ TypeRef(path: PathType, _) => + val sym = tr.symbol + + def register = + ctx.gadt.contains(path, sym) || ctx.gadt.contains(sym) || { + ctx.gadt.isConstrainablePDT(path, tr.symbol) && { + gadts.println(i"!!! registering path on the fly path=$path sym=$sym") + ctx.gadt.addToConstraint(path) && ctx.gadt.contains(path, sym) + } + } + + val result = register + + true + case _ => true + } + + def narrowTypeParams = ctx.gadt.contains(tr.symbol) && { + val tparam = tr.symbol + gadts.println(i"narrow gadt bound of tparam $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}") + if (bound.isRef(tparam)) false + else + rollbackGadtUnless { + if isUpper then + gadtAddBound(tparam, bound, isUpper = true) + else + gadtAddBound(tparam, bound, isUpper = false) + } + } + + def narrowPathDepType = tr match + case TypeRef(path: PathType, _) => + val sym = tr.symbol + + def isConstrainable: Boolean = + ctx.gadt.contains(path, sym) || { + ctx.gadt.isConstrainablePDT(path, tr.symbol) && { + gadts.println(i"!!! registering path on the fly path=$path sym=$sym") + ctx.gadt.addToConstraint(path) && ctx.gadt.contains(path, sym) + } + } + + def isRef: Boolean = bound match + case TypeRef(q: PathType, _) => (path eq q) && bound.isRef(sym) + case _ => false + + rollbackGadtUnless { + isConstrainable && { + gadts.println(i"narrow gadt bound of pdt $path -> ${sym}: from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${isRef}") + + if isRef then false + else if isUpper then gadtAddUpperBound(path, sym, bound) + else gadtAddLowerBound(path, sym, bound) + } + } + + case _ => false + + tryRegisterBound && narrowTypeParams || narrowPathDepType } } @@ -3047,6 +3159,21 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) { if (sym.exists) footprint += sym.typeRef super.gadtAddBound(sym, b, isUpper) + override def gadtBounds(tp: NamedType)(using Context): TypeBounds | Null = { + if (tp.symbol.exists) footprint += tp + super.gadtBounds(tp) + } + + override def gadtAddLowerBound(path: PathType, sym: Symbol, b: Type): Boolean = { + if (sym.exists) footprint += TypeRef(path, sym) + super.gadtAddLowerBound(path, sym, b) + } + + override def gadtAddUpperBound(path: PathType, sym: Symbol, b: Type): Boolean = { + if (sym.exists) footprint += TypeRef(path, sym) + super.gadtAddUpperBound(path, sym, b) + } + override def typeVarInstance(tvar: TypeVar)(using Context): Type = { footprint += tvar super.typeVarInstance(tvar) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index a73c73863606..65bc798865ec 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1750,7 +1750,43 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer assignType(cpy.CaseDef(tree)(pat1, guard1, body1), pat1, body1) } - val pat1 = typedPattern(tree.pat, wideSelType)(using gadtCtx) + val scrutineePath = + sel.tpe match { + case p: TermRef => + tree.pat match { + case _: (Trees.Typed[_] | Trees.Ident[_] | Trees.Apply[_] | Trees.Bind[_]) => + // We only record scrutinee path in the above cases, b/c recording + // it in all cases may lead to unsoundness. + // + // For example: + // + // def foo(e: (Expr, Expr)) = e match + // case (e1: Expr, e2: Expr) => + // + // Here the pattern is a tuple. `constrainPatternType` will be called + // on the two elements of the tuple directly, without constraining + // `e` and the whole tuple first. + // Therefore, recording the scrutinee path in this case can give + // us constraints like `e1.type == e.type`, which is not true. + p + case _ => + null + } + case _ => null + } + + // Save the scrutinee path and then type the pattern. + // The scrutinee path will be used in SR reasoning for path-dependent types. + // See `constrainTypeMembers` in `PatternTypeConstrainer`. + val pat1 = gadtCtx.gadt.withScrutineePath(scrutineePath) { + typedPattern(tree.pat, wideSelType)(using gadtCtx) + } + + if scrutineePath.ne(null) && pat1.symbol.isPatternBound then + // Subtitute the place holder with real pattern path in GADT constraints. + // See `constrainTypeMembers` in `PatternTypeConstrainer`. + gadtCtx.gadt.supplyPatternPath(pat1.symbol.termRef) + caseRest(pat1)( using Nullables.caseContext(sel, pat1)( using gadtCtx)) diff --git a/tests/neg/i15958.scala b/tests/neg/i15958.scala new file mode 100644 index 000000000000..b5a52b01e618 --- /dev/null +++ b/tests/neg/i15958.scala @@ -0,0 +1,29 @@ +sealed trait NatT { type This <: NatT } +case class Zero() extends NatT { + type This = Zero +} +case class Succ[N <: NatT](n: N) extends NatT { + type This = Succ[n.This] +} + +trait IsLessThan[+M <: NatT, N <: NatT] +object IsLessThan: + given base[M <: NatT]: IsLessThan[M, Succ[M]]() + given weakening[N <: NatT, M <: NatT] (using IsLessThan[N, M]): IsLessThan[N, Succ[M]]() + given reduction[N <: NatT, M <: NatT] (using IsLessThan[Succ[N], Succ[M]]): IsLessThan[N, M]() + +sealed trait UniformTuple[Length <: NatT, T]: + def apply[M <: NatT](m: M)(using IsLessThan[m.This, Length]): T + +case class Empty[T]() extends UniformTuple[Zero, T]: + def apply[M <: NatT](m: M)(using IsLessThan[m.This, Zero]): T = throw new AssertionError("Uncallable") + +case class Cons[N <: NatT, T](head: T, tail: UniformTuple[N, T]) extends UniformTuple[Succ[N], T]: + def apply[M <: NatT](m: M)(using proof: IsLessThan[m.This, Succ[N]]): T = m match + case Zero() => head + case m1: Succ[predM] => + val proof1: IsLessThan[m1.This, Succ[N]] = proof + + val res0 = tail(m1.n)(using IsLessThan.reduction(using proof)) // error // limitation + val res1 = tail(m1.n)(using IsLessThan.reduction(using proof1)) + res1 diff --git a/tests/neg/pdgadt-either.scala b/tests/neg/pdgadt-either.scala new file mode 100644 index 000000000000..5d53f0085ae8 --- /dev/null +++ b/tests/neg/pdgadt-either.scala @@ -0,0 +1,17 @@ +trait T1 +trait T2 extends T1 +// T2 <:< T1 + +trait Expr[+X] +case class Tag1() extends Expr[T1] +case class Tag2() extends Expr[T2] + +trait TypeTag { type A } + +def foo(p: TypeTag, e: Expr[p.A]) = e match + case _: (Tag2 | Tag1) => + // Tag1: (T2 <:) T1 <: p.A + // Tag2: T2 <: p.A + // should choose T2 <: p.A + val t1: p.A = ??? : T1 // error + val t2: p.A = ??? : T2 diff --git a/tests/neg/pdgadt-nested-pat-alias.scala b/tests/neg/pdgadt-nested-pat-alias.scala new file mode 100644 index 000000000000..3c33a367f40b --- /dev/null +++ b/tests/neg/pdgadt-nested-pat-alias.scala @@ -0,0 +1,12 @@ +trait Expr { type T } +case class B(b: Int) extends Expr { type T = Int } +case class C(c: Boolean) extends Expr { type T = Boolean } +case class A(a: Expr, b: Expr) extends Expr { type T = (a.T, b.T) } + +def foo(e: Expr) = e match + case e1 @ A(b1 @ B(b), c1 @ C(c)) => + val t1: e1.type = e + val t2: e.type = e1 + + val t3: b1.type = e // error + val t4: c1.type = e // error diff --git a/tests/neg/pdgadt-patalias.scala b/tests/neg/pdgadt-patalias.scala new file mode 100644 index 000000000000..1a6dbcc44ccb --- /dev/null +++ b/tests/neg/pdgadt-patalias.scala @@ -0,0 +1,38 @@ +trait Expr { type X } +trait IntExpr { type X = Int } + +val iexpr: Expr { type X = Int } = new Expr { type X = Int } + +object Zero { type X = 0 } + +def direct1(x: Expr) = x match { + case _: iexpr.type => + val x1: Int = ??? : x.X + val x2: x.X = ??? : Int + val x3: iexpr.type = x +} + +def direct2(x: Expr) = x match { + case _: Zero.type => + val x1: Int = ??? : x.X // limitation // error + val x2: x.X = 0 // limitation // error +} + +def indirect1(a: Expr, b: IntExpr) = a match { + case r: b.type => + val x1: Int = ??? : a.X + val x2: a.X = ??? : Int + val x3: a.type = b +} + +def indirect2(a: Expr, b: Expr) = a match { + case _: IntExpr => + // sanity check + val x1: Int = ??? : a.X + val x2: a.X = ??? : Int + b match { + case _: a.type => + val x1: Int = ??? : b.X + val x2: b.X = ??? : Int + } +} diff --git a/tests/neg/pdgadt-reuse.scala b/tests/neg/pdgadt-reuse.scala new file mode 100644 index 000000000000..67444f5ab661 --- /dev/null +++ b/tests/neg/pdgadt-reuse.scala @@ -0,0 +1,16 @@ +object test { + trait Result[A] + + def cached[A](f: (x: A) => Result[x.type], toCache: A): (x: A) => Result[x.type] = { + val cachedRes: Result[toCache.type] = f(toCache) + def resFunc(x: A): Result[x.type] = { + x match { + case r: toCache.type => cachedRes + case _ => + val res: Result[x.type] = cachedRes // error + f(x) + } + } + resFunc + } +} diff --git a/tests/neg/pdgadt-singletons.scala b/tests/neg/pdgadt-singletons.scala new file mode 100644 index 000000000000..24650807bd2e --- /dev/null +++ b/tests/neg/pdgadt-singletons.scala @@ -0,0 +1,30 @@ +object test { + trait Tag + + locally { + val p0: Tag = ??? + val p1: Tag = ??? + + p0 match { + case q0: Tag => + val x1: p0.type = q0 + val x2: q0.type = p0 + val x3: p0.type = p1 // error + + p1 match { + case q1: Tag => + val x1: p1.type = q1 + val x2: q1.type = p1 + val x3: p0.type = p1 // error + + q0 match { + case r: q1.type => + val x1: q0.type = q1 + val x2: p0.type = p1 + val x3: q0.type = p1 + val x4: q1.type = p0 + } + } + } + } +} diff --git a/tests/neg/structural-gadt.scala b/tests/neg/structural-gadt.scala index 9a14881b5804..c970f08709de 100644 --- a/tests/neg/structural-gadt.scala +++ b/tests/neg/structural-gadt.scala @@ -1,7 +1,5 @@ -// This file is part of tests for inferring GADT constraints from type members, -// which needed to be reverted because of soundness issues. -// -// Lines with "// limitation" are the ones that we could soundly allow. +// This file is part of tests for inferring GADT constraints from type members. +// They are now supported by path-dependent GADT reasoning. See #14754. object Test { trait Expr { type T } @@ -11,19 +9,19 @@ object Test { def foo[A](e: Expr { type T = A }) = e match { case _: IntLit => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: Expr { type T <: Int } => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IntExpr => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A case _: Expr { type T = Int } => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A } def bar[A](e: Expr { type T <: A }) = e match { @@ -36,11 +34,11 @@ object Test { val i: Int = ??? : A // error case _: IntExpr => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error case _: Expr { type T = Int } => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error } diff --git a/tests/neg/structural-recursive-both1-gadt.scala b/tests/neg/structural-recursive-both1-gadt.scala index 97df59a92bb5..ae17f01ff12a 100644 --- a/tests/neg/structural-recursive-both1-gadt.scala +++ b/tests/neg/structural-recursive-both1-gadt.scala @@ -1,5 +1,5 @@ -// This file is part of tests for inferring GADT constraints from type members, -// which needed to be reverted because of soundness issues. +// This file is part of tests for inferring GADT constraints from type members. +// They are now supported by path-dependent GADT reasoning. See #14754. // // Lines with "// limitation" are the ones that we could soundly allow. object Test { @@ -28,23 +28,23 @@ object Test { def foo[A](e: IndirectExprExact[A]) = e match { case _: AltIndirectIntLit => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: AltIndirectExprSub[Int] => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: AltIndirectExprSub2[Int] => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: AltIndirectIntExpr => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A case _: AltIndirectExprExact[Int] => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A } def bar[A](e: IndirectExprSub[A]) = e match { @@ -83,11 +83,11 @@ object Test { val i: Int = ??? : A // error case _: AltIndirectIntExpr => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error case _: AltIndirectExprExact[Int] => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error } } diff --git a/tests/neg/structural-recursive-both2-gadt.scala b/tests/neg/structural-recursive-both2-gadt.scala index b58e05f3ed43..08f93aa8469c 100644 --- a/tests/neg/structural-recursive-both2-gadt.scala +++ b/tests/neg/structural-recursive-both2-gadt.scala @@ -1,5 +1,5 @@ -// This file is part of tests for inferring GADT constraints from type members, -// which needed to be reverted because of soundness issues. +// This file is part of tests for inferring GADT constraints from type members. +// They are now supported by path-dependent GADT reasoning. See #14754. // // Lines with "// limitation" are the ones that we could soundly allow. object Test { @@ -28,23 +28,23 @@ object Test { def foo[A](e: AltIndirectExprExact[A]) = e match { case _: IndirectIntLit => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IndirectExprSub[Int] => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IndirectExprSub2[Int] => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IndirectIntExpr => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A case _: IndirectExprExact[Int] => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A } def bar[A](e: AltIndirectExprSub[A]) = e match { @@ -83,11 +83,11 @@ object Test { val i: Int = ??? : A // error case _: IndirectIntExpr => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error case _: IndirectExprExact[Int] => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error } } diff --git a/tests/neg/structural-recursive-pattern-gadt.scala b/tests/neg/structural-recursive-pattern-gadt.scala index ea7394b5b66b..a66f2b6f43f8 100644 --- a/tests/neg/structural-recursive-pattern-gadt.scala +++ b/tests/neg/structural-recursive-pattern-gadt.scala @@ -1,7 +1,5 @@ -// This file is part of tests for inferring GADT constraints from type members, -// which needed to be reverted because of soundness issues. -// -// Lines with "// limitation" are the ones that we could soundly allow. +// This file is part of tests for inferring GADT constraints from type members. +// They are now supported by path-dependent GADT reasoning. See #14754. object Test { trait Expr { type T } @@ -28,23 +26,23 @@ object Test { def foo[A](e: ExprExact[A]) = e match { case _: IndirectIntLit => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IndirectExprSub[Int] => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IndirectExprSub2[Int] => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IndirectIntExpr => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A case _: IndirectExprExact[Int] => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A } def bar[A](e: ExprSub[A]) = e match { @@ -61,11 +59,11 @@ object Test { val i: Int = ??? : A // error case _: IndirectIntExpr => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error case _: IndirectExprExact[Int] => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error } } diff --git a/tests/neg/structural-recursive-scrutinee-gadt.scala b/tests/neg/structural-recursive-scrutinee-gadt.scala index cd4e2376f49a..f5f8c29cc314 100644 --- a/tests/neg/structural-recursive-scrutinee-gadt.scala +++ b/tests/neg/structural-recursive-scrutinee-gadt.scala @@ -1,5 +1,5 @@ -// This file is part of tests for inferring GADT constraints from type members, -// which needed to be reverted because of soundness issues. +// This file is part of tests for inferring GADT constraints from type members. +// They are now supported by path-dependent GADT reasoning. See #14754. // // Lines with "// limitation" are the ones that we could soundly allow. object Test { @@ -28,19 +28,19 @@ object Test { def foo[A](e: IndirectExprExact[A]) = e match { case _: IntLit => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: ExprSub[Int] => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IntExpr => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A case _: ExprExact[Int] => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A } def bar[A](e: IndirectExprSub[A]) = e match { @@ -71,11 +71,11 @@ object Test { val i: Int = ??? : A // error case _: IntExpr => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error case _: ExprExact[Int] => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error } } diff --git a/tests/pos/gadt-dep-param.scala b/tests/pos/gadt-dep-param.scala new file mode 100644 index 000000000000..cbdd6db338a5 --- /dev/null +++ b/tests/pos/gadt-dep-param.scala @@ -0,0 +1,12 @@ +enum Tup[A, B]: + case Data[A, B]() extends Tup[A, B] + +def foo1[A, B](e: Tup[A, B]) = e.match { + case _: Tup.Data[a, b] => + def bar[C >: a <: b, D]() = + val t1: b = ??? : a +} + +def foo2[A, B, C >: A <: B]() = + val t1: B = ??? : A + diff --git a/tests/pos/pdgadt-asmember.scala b/tests/pos/pdgadt-asmember.scala new file mode 100644 index 000000000000..6a2157a0abdb --- /dev/null +++ b/tests/pos/pdgadt-asmember.scala @@ -0,0 +1,14 @@ +// Taken from https://github.com/lampepfl/dotty/pull/14754#issuecomment-1157427912. +trait T[X] +case object Foo extends T[Unit] + +trait AsMember { + type L + val tl: T[L] +} + +def testMember(am: AsMember): Unit = + am.tl match { + case Foo => println(summon[am.L =:= Unit]) + case _ => () + } diff --git a/tests/pos/pdgadt-expr.scala b/tests/pos/pdgadt-expr.scala new file mode 100644 index 000000000000..7bf017c8a312 --- /dev/null +++ b/tests/pos/pdgadt-expr.scala @@ -0,0 +1,13 @@ +type typed[E <: Expr, V] = E & { type T = V } + +trait Expr { type T } +case class LitInt(x: Int) extends Expr { type T = Int } +case class Add(e1: Expr typed Int, e2: Expr typed Int) extends Expr { type T = Int } +case class LitBool(x: Boolean) extends Expr { type T = Boolean } +case class Or(e1: Expr typed Boolean, e2: Expr typed Boolean) extends Expr { type T = Boolean } + +def eval(e: Expr): e.T = e match + case LitInt(x) => x + case Add(e1, e2) => eval(e1) + eval(e2) + case LitBool(b) => b + case Or(e1, e2) => eval(e1) || eval(e2) diff --git a/tests/pos/pdgadt-hkt-bounds.scala b/tests/pos/pdgadt-hkt-bounds.scala new file mode 100644 index 000000000000..4a46a617a33f --- /dev/null +++ b/tests/pos/pdgadt-hkt-bounds.scala @@ -0,0 +1,11 @@ +type Const = [X] =>> Int + +trait Expr { type F[_] } +case class ConstExprHi() extends Expr { type F[a] <: Const[a] } +case class ConstExprLo() extends Expr { type F[a] >: Const[a] } + +def foo[A](e: Expr) = e match + case _: ConstExprHi => + val i: Int = (??? : e.F[A]) : Const[A] + case _: ConstExprLo => + val i: Const[A] = ??? : Int diff --git a/tests/pos/pdgadt-hkt-ordering.scala b/tests/pos/pdgadt-hkt-ordering.scala new file mode 100644 index 000000000000..ea3879e4db14 --- /dev/null +++ b/tests/pos/pdgadt-hkt-ordering.scala @@ -0,0 +1,12 @@ +object test { + trait Tag { type F[_] } + + enum SubK[-A[_], +B[_]]: + case Refl[F[_]]() extends SubK[F, F] + + def foo(p: Tag, q: Tag, e: SubK[p.F, q.F]) = p match + case _: Tag => q match + case _: Tag => e match + case SubK.Refl() => + val t: q.F[Int] = ??? : p.F[Int] + } diff --git a/tests/pos/pdgadt-hkt-usage.scala b/tests/pos/pdgadt-hkt-usage.scala new file mode 100644 index 000000000000..97d9c6e07bc7 --- /dev/null +++ b/tests/pos/pdgadt-hkt-usage.scala @@ -0,0 +1,14 @@ +object test { + class Foo[A] + class Inv { type F[_] } + class InvFoo extends Inv { type F[x] = Foo[x] } + + object Test { + def foo(x: Inv) = x match { + case x: InvFoo => + val z1: x.F[Int] = ??? : Foo[Int] + val z2: Foo[Int] = ??? : x.F[Int] + case _ => + } + } +} diff --git a/tests/pos/pdgadt-nat-simpleadd.scala b/tests/pos/pdgadt-nat-simpleadd.scala new file mode 100644 index 000000000000..f87e5a644767 --- /dev/null +++ b/tests/pos/pdgadt-nat-simpleadd.scala @@ -0,0 +1,40 @@ +trait Z +trait S[N] + +type of[A, B] = A & { type T = B } + +trait Nat { type T } +case class Zero() extends Nat { type T = Z } +case class Succ[P](prec: Nat of P) extends Nat { type T = S[P] } + + +type :=:[T, X] = T & { type R = X } + +trait :+:[A, B] { + type R +} + +object Addition { + given AddZero[N]: :+:[Z, N] with { + type R = N + } + + given AddSucc[M, N, X](using e: (M :+: N) :=: X): :+:[S[M], N] with { + val e0: (M :+: N) :=: X = e + type R = S[X] + } +} + +object Proof { + import Addition.given + + def zeroAddN[N]: (Z :+: N) :=: N = summon + + def nAddZero[N](n: Nat of N): (N :+: Z) :=: N = n match { + case Zero() => zeroAddN[Z] + case p: Succ[pn] => + val e0: (pn :+: Z) :=: pn = nAddZero[pn](p.prec) + AddSucc(using e0) + } +} + diff --git a/tests/pos/pdgadt-path.scala b/tests/pos/pdgadt-path.scala new file mode 100644 index 000000000000..1b9b4e8fc9c9 --- /dev/null +++ b/tests/pos/pdgadt-path.scala @@ -0,0 +1,8 @@ +trait Expr +case class IntLit(b: Int) extends Expr + +def foo[M <: Expr](e: M) = e match { + case e1: IntLit => + val t0: e1.type = e + val t1: e.type = e1 +} diff --git a/tests/pos/pdgadt-sub.scala b/tests/pos/pdgadt-sub.scala new file mode 100644 index 000000000000..6b29ac0bb306 --- /dev/null +++ b/tests/pos/pdgadt-sub.scala @@ -0,0 +1,21 @@ +trait SubBase { + type A0 + type B0 + type C >: A0 <: B0 +} +trait Tag { type A } + +type Sub[A, B] = SubBase { type A0 = A; type B0 = B } + +def foo(x: Tag, y: Tag, e: Sub[x.A, y.A]) = e.match { + case _: Object => + val t1: y.A = ??? : x.A +} + +def bar(x: Tag, e: Sub[Int, x.A]): x.A = e match { + case _: Object => 0 +} + +def baz(x: Tag, e: Sub[x.A, Int]): Int = e match { + case _: Object => ??? : x.A +} diff --git a/tests/pos/pdgadt-wildcard.scala b/tests/pos/pdgadt-wildcard.scala new file mode 100644 index 000000000000..908ef4ff4fea --- /dev/null +++ b/tests/pos/pdgadt-wildcard.scala @@ -0,0 +1,8 @@ +trait Expr { type T } +case class Inv[X](x: X) extends Expr { type T = X } +case class Inv2[X](x: X) extends Expr { type T >: X } + +def eval(e: Expr): e.T = e match + case Inv(x) => x + case Inv2(x) => x +