From 876910f7a4c0c585d7d2ecea5570021d99b4acea Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Mon, 23 Aug 2021 01:04:08 +0800 Subject: [PATCH 01/15] Change to use TypeRef to represent type parameters --- .../tools/dotc/core/GadtConstraint.scala | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 7d84b9892057..e3574de66d7c 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -59,8 +59,8 @@ sealed abstract class GadtConstraint extends Showable { final class ProperGadtConstraint private( private var myConstraint: Constraint, - private var mapping: SimpleIdentityMap[Symbol, TypeVar], - private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol], + private var mapping: SimpleIdentityMap[TypeRef, TypeVar], + private var reverseMapping: SimpleIdentityMap[TypeParamRef, TypeRef], ) extends GadtConstraint with ConstraintHandling { import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} @@ -97,7 +97,7 @@ final class ProperGadtConstraint private( case tp: NamedType => params.indexOf(tp.symbol) match { case -1 => - mapping(tp.symbol) match { + mapping(tp.symbol.typeRef) match { case tv: TypeVar => tv.origin case null => tp } @@ -118,8 +118,8 @@ final class ProperGadtConstraint private( val tvars = params.lazyZip(poly1.paramRefs).map { (sym, paramRef) => val tv = TypeVar(paramRef, creatorState = null) - mapping = mapping.updated(sym, tv) - reverseMapping = reverseMapping.updated(tv.origin, sym) + mapping = mapping.updated(sym.typeRef, tv) + reverseMapping = reverseMapping.updated(tv.origin, sym.typeRef) tv } @@ -145,7 +145,7 @@ final class ProperGadtConstraint private( val internalizedBound = bound match { case nt: NamedType => - val ntTvar = mapping(nt.symbol) + val ntTvar = mapping(nt.symbol.typeRef) if (ntTvar ne null) stripInternalTypeVar(ntTvar) else bound case _ => bound } @@ -169,7 +169,7 @@ final class ProperGadtConstraint private( constraint.isLess(tvarOrError(sym1).origin, tvarOrError(sym2).origin) override def fullBounds(sym: Symbol)(using Context): TypeBounds = - mapping(sym) match { + mapping(sym.typeRef) match { case null => null case tv => fullBounds(tv.origin) @@ -177,13 +177,13 @@ final class ProperGadtConstraint private( } override def bounds(sym: Symbol)(using Context): TypeBounds = - mapping(sym) match { + mapping(sym.typeRef) match { case null => null case tv => def retrieveBounds: TypeBounds = bounds(tv.origin) match { case TypeAlias(tpr: TypeParamRef) if reverseMapping.contains(tpr) => - TypeAlias(reverseMapping(tpr).typeRef) + TypeAlias(reverseMapping(tpr)) case tb => tb } retrieveBounds @@ -191,7 +191,7 @@ final class ProperGadtConstraint private( //.ensuring(containsNoInternalTypes(_)) } - override def contains(sym: Symbol)(using Context): Boolean = mapping(sym) ne null + override def contains(sym: Symbol)(using Context): Boolean = mapping(sym.typeRef) ne null override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = { val res = approximation(tvarOrError(sym).origin, fromBelow = fromBelow) @@ -249,12 +249,12 @@ final class ProperGadtConstraint private( private def externalize(param: TypeParamRef)(using Context): Type = reverseMapping(param) match { - case sym: Symbol => sym.typeRef + case tpr: TypeRef => tpr case null => param } private def tvarOrError(sym: Symbol)(using Context): TypeVar = - mapping(sym).ensuring(_ ne null, i"not a constrainable symbol: $sym") + mapping(sym.typeRef).ensuring(_ ne null, i"not a constrainable symbol: $sym") private def containsNoInternalTypes( tp: Type, @@ -280,8 +280,8 @@ final class ProperGadtConstraint private( val sb = new mutable.StringBuilder sb ++= constraint.show sb += '\n' - mapping.foreachBinding { case (sym, _) => - sb ++= i"$sym: ${fullBounds(sym)}\n" + mapping.foreachBinding { case (tpr, _) => + sb ++= i"$tpr: ${fullBounds(tpr.symbol)}\n" } sb.result } From cd9d872527db65bfb11492676e79c297824a491f Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Mon, 23 Aug 2021 01:04:22 +0800 Subject: [PATCH 02/15] Add function to retrieve all type member symbols --- compiler/src/dotty/tools/dotc/core/TypeOps.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/compiler/src/dotty/tools/dotc/core/TypeOps.scala b/compiler/src/dotty/tools/dotc/core/TypeOps.scala index 75a5816c3164..c4128341878a 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeOps.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeOps.scala @@ -818,4 +818,7 @@ object TypeOps: def nestedPairs(ts: List[Type])(using Context): Type = ts.foldRight(defn.EmptyTupleModule.termRef: Type)(defn.PairClass.typeRef.appliedTo(_, _)) + def abstractTypeMemberSymbols(tp: Type)(using Context): List[Symbol] = + tp.abstractTypeMembers.toList map (_.symbol) + end TypeOps From 8ceed62e8f9ae3c16d3044e08d69fe965269decb Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Mon, 23 Aug 2021 22:14:48 +0800 Subject: [PATCH 03/15] Add GADT logic for handling general typerefs --- .../tools/dotc/core/GadtConstraint.scala | 140 ++++++++++++++++-- .../dotty/tools/dotc/core/TypeComparer.scala | 3 + 2 files changed, 130 insertions(+), 13 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index e3574de66d7c..7d9b99d7fd6d 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -9,6 +9,7 @@ import Symbols._ import util.SimpleIdentityMap import collection.mutable import printing._ +import TypeOps.abstractTypeMemberSymbols import scala.annotation.internal.sharable @@ -16,6 +17,7 @@ import scala.annotation.internal.sharable sealed abstract class GadtConstraint extends Showable { /** Immediate bounds of `sym`. Does not contain lower/upper symbols (see [[fullBounds]]). */ def bounds(sym: Symbol)(using Context): TypeBounds + def bounds(tp: TypeRef)(using Context): TypeBounds /** Full bounds of `sym`, including TypeRefs to other lower/upper symbols. * @@ -23,6 +25,7 @@ sealed abstract class GadtConstraint extends Showable { * Using this in isSubType can lead to infinite recursion. Consider `bounds` instead. */ def fullBounds(sym: Symbol)(using Context): TypeBounds + def fullBounds(tp: TypeRef)(using Context): TypeBounds /** Is `sym1` ordered to be less than `sym2`? */ def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean @@ -36,12 +39,18 @@ sealed abstract class GadtConstraint extends Showable { /** Further constrain a symbol already present in the constraint. */ def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean + def addBound(tpr: TypeRef, bound: Type, isUpper: Boolean)(using Context): Boolean /** Is the symbol registered in the constraint? * * @note this is true even if the symbol is constrained to be equal to another type, unlike [[Constraint.contains]]. */ def contains(sym: Symbol)(using Context): Boolean + def contains(tp: TypeRef)(using Context): Boolean + + /** Is the type a constrainable path-dependent type? + */ + def isConstrainablePDT(tp: Type)(using Context): Boolean def isEmpty: Boolean final def nonEmpty: Boolean = !isEmpty @@ -61,13 +70,15 @@ final class ProperGadtConstraint private( private var myConstraint: Constraint, private var mapping: SimpleIdentityMap[TypeRef, TypeVar], private var reverseMapping: SimpleIdentityMap[TypeParamRef, TypeRef], + private var tempMapping: SimpleIdentityMap[Symbol, TypeVar] ) extends GadtConstraint with ConstraintHandling { import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} def this() = this( myConstraint = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty), mapping = SimpleIdentityMap.empty, - reverseMapping = SimpleIdentityMap.empty + reverseMapping = SimpleIdentityMap.empty, + tempMapping = SimpleIdentityMap.empty ) /** Exposes ConstraintHandling.subsumes */ @@ -79,6 +90,92 @@ final class ProperGadtConstraint private( subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre)) } + override def isConstrainablePDT(tp: Type)(using Context): Boolean = tp match + case tp @ TypeRef(prefix, des) => isConstrainablePath(prefix) && ! tp.symbol.is(Flags.Opaque) + case _ => false + + /** Whether type members of the given path is constrainable? + * + * Package's and module's type members will not be constrained. + */ + private def isConstrainablePath(path: Type)(using Context): Boolean = path match + case path: TermRef if !path.symbol.is(Flags.Package) && !path.symbol.is(Flags.Module) => true + case _ => false + + /** Find all constrainable type member symbols of the given type. + * + * All abstract but not opaque type members are returned. + */ + private def constrainableTypeMemberSymbols(tp: Type)(using Context) = + abstractTypeMemberSymbols(tp) filterNot (_.is(Flags.Opaque)) + + private def addTypeMembersOf(path: Type, isUnamedPattern: Boolean)(using Context): Option[Map[Symbol, TypeVar]] = + import NameKinds.DepParamName + + /** Should not place constraints on type members defined in modules. */ + if !isUnamedPattern && !isConstrainablePath(path) then return None + + val pathType = if isUnamedPattern then path else path.widen + val typeMembers = constrainableTypeMemberSymbols(pathType) + + val poly1 = PolyType(typeMembers map { s => DepParamName.fresh(s.name.toTypeName) })( + pt => typeMembers map { typeMember => + def substDependentSyms(tp: Type, isUpper: Boolean)(using Context): Type = { + def loop(tp: Type): Type = tp match + case tp @ AndType(tp1, tp2) if !isUpper => + tp.derivedAndOrType(loop(tp1), loop(tp2)) + case tp @ OrType(tp1, tp2) if isUpper => + tp.derivedOrType(loop(tp1), loop(tp2)) + case tp @ TypeRef(prefix, des) if prefix == pathType => + typeMembers indexOf tp.symbol match + case -1 => tp + case idx => pt.paramRefs(idx) + case tp @ TypeRef(_: RecThis, des) => + typeMembers indexOf tp.symbol match + case -1 => tp + case idx => pt.paramRefs(idx) + case tp: TypeRef => + mapping(tp) match { + case tv: TypeVar => tv.origin + case null => tp + } + case tp => tp + + loop(tp) + } + + val tb = typeMember.info.bounds + tb.derivedTypeBounds( + lo = substDependentSyms(tb.lo, isUpper = false), + hi = substDependentSyms(tb.hi, isUpper = true) + ) + }, + pt => defn.AnyType + ) + + val tvars = typeMembers lazyZip poly1.paramRefs map { (sym, paramRef) => + val tv = TypeVar(paramRef, creatorState = null) + + if isUnamedPattern then + tempMapping = tempMapping.updated(sym, tv) + else + val externalType = TypeRef(path, sym) + mapping = mapping.updated(externalType, tv) + reverseMapping = reverseMapping.updated(tv.origin, externalType) + + tv + } + + def register = + addToConstraint(poly1, tvars) + .showing(i"added to constraint: [$poly1] $typeMembers%, %\n$debugBoundsDescription", gadts) + + if register then + Some(Map.from(typeMembers lazyZip tvars)) + else + None + end addTypeMembersOf + override def addToConstraint(params: List[Symbol])(using Context): Boolean = { import NameKinds.DepParamName @@ -128,7 +225,7 @@ final class ProperGadtConstraint private( .showing(i"added to constraint: [$poly1] $params%, %\n$debugBoundsDescription", gadts) } - override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = { + override def addBound(tpr: TypeRef, bound: Type, isUpper: Boolean)(using Context): Boolean = { @annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match { case tv: TypeVar => val inst = constraint.instType(tv) @@ -136,10 +233,10 @@ final class ProperGadtConstraint private( case _ => tp } - val symTvar: TypeVar = stripInternalTypeVar(tvarOrError(sym)) match { + val symTvar: TypeVar = stripInternalTypeVar(tvarOrError(tpr)) match { case tv: TypeVar => tv case inst => - gadts.println(i"instantiated: $sym -> $inst") + gadts.println(i"instantiated: $tpr -> $inst") return if (isUpper) isSub(inst, bound) else isSub(bound, inst) } @@ -161,23 +258,28 @@ final class ProperGadtConstraint private( ).showing({ val descr = if (isUpper) "upper" else "lower" val op = if (isUpper) "<:" else ">:" - i"adding $descr bound $sym $op $bound = $result" + i"adding $descr bound $tpr $op $bound = $result" }, gadts) } + override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = + addBound(sym.typeRef, bound, isUpper) + override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = constraint.isLess(tvarOrError(sym1).origin, tvarOrError(sym2).origin) - override def fullBounds(sym: Symbol)(using Context): TypeBounds = - mapping(sym.typeRef) match { + override def fullBounds(tp: TypeRef)(using Context): TypeBounds = + mapping(tp) match { case null => null case tv => fullBounds(tv.origin) // .ensuring(containsNoInternalTypes(_)) } - override def bounds(sym: Symbol)(using Context): TypeBounds = - mapping(sym.typeRef) match { + override def fullBounds(sym: Symbol)(using Context): TypeBounds = fullBounds(sym.typeRef) + + override def bounds(tp: TypeRef)(using Context): TypeBounds = + mapping(tp) match { case null => null case tv => def retrieveBounds: TypeBounds = @@ -191,7 +293,10 @@ final class ProperGadtConstraint private( //.ensuring(containsNoInternalTypes(_)) } - override def contains(sym: Symbol)(using Context): Boolean = mapping(sym.typeRef) ne null + override def bounds(sym: Symbol)(using Context): TypeBounds = bounds(sym.typeRef) + + override def contains(tp: TypeRef)(using Context): Boolean = mapping(tp) ne null + override def contains(sym: Symbol)(using Context): Boolean = contains(sym.typeRef) override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = { val res = approximation(tvarOrError(sym).origin, fromBelow = fromBelow) @@ -202,7 +307,8 @@ final class ProperGadtConstraint private( override def fresh: GadtConstraint = new ProperGadtConstraint( myConstraint, mapping, - reverseMapping + reverseMapping, + tempMapping ) def restore(other: GadtConstraint): Unit = other match { @@ -210,6 +316,7 @@ final class ProperGadtConstraint private( this.myConstraint = other.myConstraint this.mapping = other.mapping this.reverseMapping = other.reverseMapping + this.tempMapping = other.tempMapping case _ => ; } @@ -253,8 +360,10 @@ final class ProperGadtConstraint private( case null => param } - private def tvarOrError(sym: Symbol)(using Context): TypeVar = - mapping(sym.typeRef).ensuring(_ ne null, i"not a constrainable symbol: $sym") + private def tvarOrError(tpr: TypeRef)(using Context): TypeVar = + mapping(tpr).ensuring(_ ne null, i"not a constrainable type: $tpr") + + private def tvarOrError(sym: Symbol)(using Context): TypeVar = tvarOrError(sym.typeRef) private def containsNoInternalTypes( tp: Type, @@ -290,14 +399,19 @@ final class ProperGadtConstraint private( @sharable object EmptyGadtConstraint extends GadtConstraint { override def bounds(sym: Symbol)(using Context): TypeBounds = null override def fullBounds(sym: Symbol)(using Context): TypeBounds = null + override def bounds(tp: TypeRef)(using Context): TypeBounds = null + override def fullBounds(tp: TypeRef)(using Context): TypeBounds = null override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = unsupported("EmptyGadtConstraint.isLess") override def isEmpty: Boolean = true override def contains(sym: Symbol)(using Context) = false + override def contains(tp: TypeRef)(using Context) = false + override def isConstrainablePDT(tp: Type)(using Context): Boolean = false override def addToConstraint(params: List[Symbol])(using Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint") + override def addBound(tpr: TypeRef, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound") override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound") override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = unsupported("EmptyGadtConstraint.approximation") diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index b568cb2c8af8..3c8680d81e49 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -117,6 +117,9 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling protected def gadtBounds(sym: Symbol)(using Context) = ctx.gadt.bounds(sym) protected def gadtAddLowerBound(sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(sym, b, isUpper = false) protected def gadtAddUpperBound(sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(sym, b, isUpper = true) + protected def gadtBounds(tpr: TypeRef)(using Context) = ctx.gadt.bounds(tpr) + protected def gadtAddLowerBound(tpr: TypeRef, b: Type): Boolean = ctx.gadt.addBound(tpr, b, isUpper = false) + protected def gadtAddUpperBound(tpr: TypeRef, b: Type): Boolean = ctx.gadt.addBound(tpr, b, isUpper = true) protected def typeVarInstance(tvar: TypeVar)(using Context): Type = tvar.underlying From e9150dc2650b53cf972b96c5f4d82cd4a96f0b68 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Tue, 24 Aug 2021 01:40:15 +0800 Subject: [PATCH 04/15] Implement GADT type inference for path-dependent types --- .../tools/dotc/core/GadtConstraint.scala | 15 +++++- .../dotty/tools/dotc/core/TypeComparer.scala | 53 +++++++++++++------ 2 files changed, 50 insertions(+), 18 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 7d9b99d7fd6d..878c1e1cbd67 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -52,6 +52,10 @@ sealed abstract class GadtConstraint extends Showable { */ def isConstrainablePDT(tp: Type)(using Context): Boolean + /** Add path-dependent type to constraint. + */ + def addPDT(tp: Type)(using Context): Boolean + def isEmpty: Boolean final def nonEmpty: Boolean = !isEmpty @@ -102,6 +106,12 @@ final class ProperGadtConstraint private( case path: TermRef if !path.symbol.is(Flags.Package) && !path.symbol.is(Flags.Module) => true case _ => false + override def addPDT(tp: Type)(using Context): Boolean = + assert(isConstrainablePDT(tp), i"Type $tp is not a constrainable path-dependent type.") + tp match + case TypeRef(prefix: TermRef, _) => addTypeMembersOf(prefix, isUnamedPattern = false).nonEmpty + case _ => false + /** Find all constrainable type member symbols of the given type. * * All abstract but not opaque type members are returned. @@ -118,6 +128,8 @@ final class ProperGadtConstraint private( val pathType = if isUnamedPattern then path else path.widen val typeMembers = constrainableTypeMemberSymbols(pathType) + if typeMembers.isEmpty then return Some(Map.empty) + val poly1 = PolyType(typeMembers map { s => DepParamName.fresh(s.name.toTypeName) })( pt => typeMembers map { typeMember => def substDependentSyms(tp: Type, isUpper: Boolean)(using Context): Type = { @@ -390,7 +402,7 @@ final class ProperGadtConstraint private( sb ++= constraint.show sb += '\n' mapping.foreachBinding { case (tpr, _) => - sb ++= i"$tpr: ${fullBounds(tpr.symbol)}\n" + sb ++= i"$tpr: ${fullBounds(tpr)}\n" } sb.result } @@ -409,6 +421,7 @@ final class ProperGadtConstraint private( override def contains(sym: Symbol)(using Context) = false override def contains(tp: TypeRef)(using Context) = false override def isConstrainablePDT(tp: Type)(using Context): Boolean = false + override def addPDT(tp: Type)(using Context): Boolean = false override def addToConstraint(params: List[Symbol])(using Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint") override def addBound(tpr: TypeRef, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound") diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 3c8680d81e49..92a16b872671 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -182,6 +182,22 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling private inline def inFrozenGadtAndConstraint[T](inline op: T): T = inFrozenGadtIf(true)(inFrozenConstraint(op)) + private def canRegisterPDT: Boolean = + ctx.mode.is(Mode.GadtConstraintInference) && !frozenGadt && !frozenConstraint + + private def tryRegisterPDT(tpr: TypeRef): Boolean = + canRegisterPDT + && ctx.gadt.isConstrainablePDT(tpr) + && ctx.gadt.addPDT(tpr) + + extension (tpr: TypeRef) + private inline def onGadtBounds(inline op: TypeBounds => Boolean): Boolean = + def useGadtBounds = + val bounds = gadtBounds(tpr) + bounds != null && op(bounds) + + useGadtBounds || { tryRegisterPDT(tpr) && useGadtBounds } + extension (sym: Symbol) private inline def onGadtBounds(inline op: TypeBounds => Boolean): Boolean = val bounds = gadtBounds(sym) @@ -503,20 +519,23 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling def thirdTryNamed(tp2: NamedType): Boolean = tp2.info match { case info2: TypeBounds => - def compareGADT: Boolean = - tp2.symbol.onGadtBounds(gbounds2 => - isSubTypeWhenFrozen(tp1, gbounds2.lo) - || tp1.match - case tp1: NamedType if ctx.gadt.contains(tp1.symbol) => - // Note: since we approximate constrained types only with their non-param bounds, - // we need to manually handle the case when we're comparing two constrained types, - // one of which is constrained to be a subtype of another. - // We do not need similar code in fourthTry, since we only need to care about - // comparing two constrained types, and that case will be handled here first. - ctx.gadt.isLess(tp1.symbol, tp2.symbol) && GADTusage(tp1.symbol) && GADTusage(tp2.symbol) - case _ => false - || narrowGADTBounds(tp2, tp1, approx, isUpper = false)) - && (isBottom(tp1) || GADTusage(tp2.symbol)) + def compareGADT: Boolean = tp2 match + case tp2: TypeRef => + tp2.onGadtBounds(gbounds2 => + isSubTypeWhenFrozen(tp1, gbounds2.lo) + || tp1.match + case tp1: NamedType if ctx.gadt.contains(tp1.symbol) => + // Note: since we approximate constrained types only with their non-param bounds, + // we need to manually handle the case when we're comparing two constrained types, + // one of which is constrained to be a subtype of another. + // We do not need similar code in fourthTry, since we only need to care about + // comparing two constrained types, and that case will be handled here first. + ctx.gadt.isLess(tp1.symbol, tp2.symbol) && GADTusage(tp1.symbol) && GADTusage(tp2.symbol) + case _ => false + || narrowGADTBounds(tp2, tp1, approx, isUpper = false)) + && (isBottom(tp1) || GADTusage(tp2.symbol)) + case _ => false + end compareGADT isSubApproxHi(tp1, info2.lo) || compareGADT || tryLiftedToThis2 || fourthTry @@ -1881,14 +1900,14 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling * `bound` as an upper or lower bound (which depends on `isUpper`). * Test that the resulting bounds are still satisfiable. */ - private def narrowGADTBounds(tr: NamedType, bound: Type, approx: ApproxState, isUpper: Boolean): Boolean = { + private def narrowGADTBounds(tr: TypeRef, bound: Type, approx: ApproxState, isUpper: Boolean): Boolean = { val boundImprecise = approx.high || approx.low ctx.mode.is(Mode.GadtConstraintInference) && !frozenGadt && !frozenConstraint && !boundImprecise && { val tparam = tr.symbol gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}") if (bound.isRef(tparam)) false - else if (isUpper) gadtAddUpperBound(tparam, bound) - else gadtAddLowerBound(tparam, bound) + else if (isUpper) gadtAddUpperBound(tr, bound) + else gadtAddLowerBound(tr, bound) } } From 77a879ffab0655106ae574578c9d1f39db3fe600 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Tue, 24 Aug 2021 01:44:36 +0800 Subject: [PATCH 05/15] Support path-dependent GADT reasoning for upper bounds --- compiler/src/dotty/tools/dotc/core/TypeComparer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 92a16b872671..307c9bfdec01 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -787,7 +787,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling tp1.info match { case TypeBounds(_, hi1) => def compareGADT = - tp1.symbol.onGadtBounds(gbounds1 => + tp1.onGadtBounds(gbounds1 => isSubTypeWhenFrozen(gbounds1.hi, tp2) || narrowGADTBounds(tp1, tp2, approx, isUpper = true)) && (tp2.isAny || GADTusage(tp1.symbol)) From f060c9b20731ea5f4f68d254c82f3b97f7573c6d Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 25 Aug 2021 01:29:05 +0800 Subject: [PATCH 06/15] Fix type comparison triggered by onGadtBounds --- .../src/dotty/tools/dotc/core/GadtConstraint.scala | 1 - compiler/src/dotty/tools/dotc/core/TypeComparer.scala | 10 ++++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 878c1e1cbd67..e8237573a848 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -122,7 +122,6 @@ final class ProperGadtConstraint private( private def addTypeMembersOf(path: Type, isUnamedPattern: Boolean)(using Context): Option[Map[Symbol, TypeVar]] = import NameKinds.DepParamName - /** Should not place constraints on type members defined in modules. */ if !isUnamedPattern && !isConstrainablePath(path) then return None val pathType = if isUnamedPattern then path else path.widen diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 307c9bfdec01..8c3518479926 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -192,11 +192,13 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling extension (tpr: TypeRef) private inline def onGadtBounds(inline op: TypeBounds => Boolean): Boolean = - def useGadtBounds = - val bounds = gadtBounds(tpr) - bounds != null && op(bounds) + val gbounds = + gadtBounds(tpr) match + case null => + if tryRegisterPDT(tpr) then gadtBounds(tpr) else null + case gbounds => gbounds - useGadtBounds || { tryRegisterPDT(tpr) && useGadtBounds } + gbounds != null && op(gbounds) extension (sym: Symbol) private inline def onGadtBounds(inline op: TypeBounds => Boolean): Boolean = From 0f5206786ca269b43a96128efaeb9745a2e00277 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 1 Sep 2021 00:00:22 +0800 Subject: [PATCH 07/15] Add footprint tracking for TypeRefs --- .../src/dotty/tools/dotc/core/TypeComparer.scala | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 8c3518479926..01a814318a19 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -2853,16 +2853,31 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) { super.gadtBounds(sym) } + override def gadtBounds(tpr: TypeRef)(using Context): TypeBounds = { + if (tpr.exists) footprint += tpr + super.gadtBounds(tpr) + } + override def gadtAddLowerBound(sym: Symbol, b: Type): Boolean = { if (sym.exists) footprint += sym.typeRef super.gadtAddLowerBound(sym, b) } + override def gadtAddLowerBound(tpr: TypeRef, b: Type): Boolean = { + if (tpr.exists) footprint += tpr + super.gadtAddLowerBound(tpr, b) + } + override def gadtAddUpperBound(sym: Symbol, b: Type): Boolean = { if (sym.exists) footprint += sym.typeRef super.gadtAddUpperBound(sym, b) } + override def gadtAddUpperBound(tpr: TypeRef, b: Type): Boolean = { + if (tpr.exists) footprint += tpr + super.gadtAddUpperBound(tpr, b) + } + override def typeVarInstance(tvar: TypeVar)(using Context): Type = { footprint += tvar super.typeVarInstance(tvar) From c1b4b8cdff7d3b54019eef4ebff4903430547806 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 1 Sep 2021 11:31:34 +0800 Subject: [PATCH 08/15] Adapt addLess for TypeRefs --- compiler/src/dotty/tools/dotc/core/GadtConstraint.scala | 7 ++++++- compiler/src/dotty/tools/dotc/core/TypeComparer.scala | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index e8237573a848..6b97ccb16d65 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -29,6 +29,7 @@ sealed abstract class GadtConstraint extends Showable { /** Is `sym1` ordered to be less than `sym2`? */ def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean + def isLess(tpr1: TypeRef, tpr2: TypeRef)(using Context): Boolean /** Add symbols to constraint, correctly handling inter-dependencies. * @@ -277,7 +278,10 @@ final class ProperGadtConstraint private( addBound(sym.typeRef, bound, isUpper) override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = - constraint.isLess(tvarOrError(sym1).origin, tvarOrError(sym2).origin) + isLess(sym1.typeRef, sym2.typeRef) + + override def isLess(tp1: TypeRef, tp2: TypeRef)(using Context): Boolean = + constraint.isLess(tvarOrError(tp1).origin, tvarOrError(tp2).origin) override def fullBounds(tp: TypeRef)(using Context): TypeBounds = mapping(tp) match { @@ -414,6 +418,7 @@ final class ProperGadtConstraint private( override def fullBounds(tp: TypeRef)(using Context): TypeBounds = null override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = unsupported("EmptyGadtConstraint.isLess") + override def isLess(tp1: TypeRef, tp2: TypeRef)(using Context): Boolean = unsupported("EmptyGadtConstraint.isLess") override def isEmpty: Boolean = true diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 01a814318a19..3cb2d80605d5 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -526,13 +526,13 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling tp2.onGadtBounds(gbounds2 => isSubTypeWhenFrozen(tp1, gbounds2.lo) || tp1.match - case tp1: NamedType if ctx.gadt.contains(tp1.symbol) => + case tp1: TypeRef if ctx.gadt.contains(tp1) => // Note: since we approximate constrained types only with their non-param bounds, // we need to manually handle the case when we're comparing two constrained types, // one of which is constrained to be a subtype of another. // We do not need similar code in fourthTry, since we only need to care about // comparing two constrained types, and that case will be handled here first. - ctx.gadt.isLess(tp1.symbol, tp2.symbol) && GADTusage(tp1.symbol) && GADTusage(tp2.symbol) + ctx.gadt.isLess(tp1, tp2) && GADTusage(tp1.symbol) && GADTusage(tp2.symbol) case _ => false || narrowGADTBounds(tp2, tp1, approx, isUpper = false)) && (isBottom(tp1) || GADTusage(tp2.symbol)) From de5c91f25e95a0a40eeab8b9fc21b35d8f77206e Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 1 Sep 2021 11:31:58 +0800 Subject: [PATCH 09/15] Register bound if it is a PDT --- .../src/dotty/tools/dotc/core/TypeComparer.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 3cb2d80605d5..dee87c701b4e 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -1906,8 +1906,16 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling val boundImprecise = approx.high || approx.low ctx.mode.is(Mode.GadtConstraintInference) && !frozenGadt && !frozenConstraint && !boundImprecise && { val tparam = tr.symbol - gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}") - if (bound.isRef(tparam)) false + gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam) && !ctx.gadt.isConstrainablePDT(bound)}") + + def registerPDTBound(): Boolean = bound match + case bound: TypeRef => + ctx.gadt.isConstrainablePDT(bound) && !ctx.gadt.contains(bound) && tryRegisterPDT(bound) + case _ => false + + registerPDTBound() + + if (bound.isRef(tparam) && !ctx.gadt.isConstrainablePDT(bound)) false else if (isUpper) gadtAddUpperBound(tr, bound) else gadtAddLowerBound(tr, bound) } From 08555286dbc654cc1c717aa0980e27eeda92435a Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Thu, 2 Sep 2021 14:43:01 +0800 Subject: [PATCH 10/15] Add basic testcase for path-dependent GADT --- tests/pos/basic-pdgadt.scala | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 tests/pos/basic-pdgadt.scala diff --git a/tests/pos/basic-pdgadt.scala b/tests/pos/basic-pdgadt.scala new file mode 100644 index 000000000000..cd44954533ce --- /dev/null +++ b/tests/pos/basic-pdgadt.scala @@ -0,0 +1,20 @@ +enum SUB[-A, +B]: + case Ev[X]() extends SUB[X, X] + +trait Tag { type T } + +def f(p: Tag, e: SUB[Int, p.T]): p.T = e match + case SUB.Ev() => 0 + +def g(p: Tag, q: Tag, e: SUB[p.T, q.T]) = e match + case SUB.Ev() => + // p.T <: q.T + (??? : p.T) : q.T + +def h1[Q](p: Tag, e: SUB[p.T, Q]) = e match + case SUB.Ev() => + (??? : p.T) : Q + +def h2[P](q: Tag, e: SUB[P, q.T]) = e match + case SUB.Ev() => + (??? : P) : q.T From 5b76853f975a6de87003902004fc88ab5d98acaa Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Thu, 2 Sep 2021 14:46:36 +0800 Subject: [PATCH 11/15] Update subsumption check considering just-in-time registration --- .../tools/dotc/core/GadtConstraint.scala | 96 ++++++++++++++++++- 1 file changed, 94 insertions(+), 2 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 6b97ccb16d65..333a621b154b 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -86,13 +86,101 @@ final class ProperGadtConstraint private( tempMapping = SimpleIdentityMap.empty ) - /** Exposes ConstraintHandling.subsumes */ + /** Whether `left` subsumes `right`? + * + * `left` and `right` both stem from the constraint `pre`, with different type reasoning performed, + * during which new types might be registered in GadtConstraint. This function will take such newly + * registered types into consideration. + */ def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(using Context): Boolean = { + // When new types are registered after pre, for left to subsume right, it should contain all types + // newly registered in right. + def checkSubsumes(c1: Constraint, c2: Constraint, pre: Constraint): Boolean = { + if (c2 eq pre) true + else if (c1 eq pre) false + else { + val saved = constraint + + /** Compute type parameters in c1 added after `pre` + */ + val params1 = c1.domainParams.toSet + val params2 = c2.domainParams.toSet + val preParams = pre.domainParams.toSet + val newParams1 = params1.diff(preParams) + val newParams2 = params2.diff(preParams) + + def checkNewParams: Boolean = (left, right) match { + case (left: ProperGadtConstraint, right: ProperGadtConstraint) => + newParams2 forall { p2 => + val tp2 = right.externalize(p2) + left.tvarOfType(tp2) != null + } + case _ => true + } + + // bridge between the newly-registered types in c2 and c1 + val (bridge1, bridge2) = { + var bridge1: SimpleIdentityMap[TypeParamRef, TypeParamRef] = SimpleIdentityMap.empty + var bridge2: SimpleIdentityMap[TypeParamRef, TypeParamRef] = SimpleIdentityMap.empty + + (left, right) match { + // only meaningful when both constraints are proper + case (left: ProperGadtConstraint, right: ProperGadtConstraint) => + newParams1 foreach { p1 => + val tp1 = left.externalize(p1) + right.tvarOfType(tp1) match { + case null => + case tvar2 => + bridge1 = bridge1.updated(p1, tvar2.origin) + bridge2 = bridge2.updated(tvar2.origin, p1) + } + } + case _ => + } + + (bridge1, bridge2) + } + + def bridgeParam(bridge: SimpleIdentityMap[TypeParamRef, TypeParamRef])(tpr: TypeParamRef): TypeParamRef = bridge(tpr) match { + case null => tpr + case tpr1 => tpr1 + } + val bridgeParam1 = bridgeParam(bridge1) + val bridgeParam2 = bridgeParam(bridge2) + + try { + // checks existing type parameters in `pre` + def existing: Boolean = pre.forallParams { p => + c1.contains(p) && + c2.upper(p).forall { q => + c1.isLess(p, bridgeParam2(q)) + } && isSubTypeWhenFrozen(c1.nonParamBounds(p), c2.nonParamBounds(p)) + } + + // checks new type parameters in `c1` + def added: Boolean = newParams1 forall { p1 => + bridge1(p1) match { + case null => + // p1 is in `left` but not in `right` + true + case p2 => + c2.upper(p2).forall { q => + c1.isLess(p1, bridgeParam2(q)) + } && isSubTypeWhenFrozen(c1.nonParamBounds(p1), c2.nonParamBounds(p2)) + } + } + + existing && checkNewParams && added + } finally constraint = saved + } + } + def extractConstraint(g: GadtConstraint) = g match { case s: ProperGadtConstraint => s.constraint case EmptyGadtConstraint => OrderingConstraint.empty } - subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre)) + + checkSubsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre)) } override def isConstrainablePDT(tp: Type)(using Context): Boolean = tp match @@ -375,6 +463,10 @@ final class ProperGadtConstraint private( case null => param } + private def tvarOfType(tp: Type)(using Context): TypeVar = tp match + case tp: TypeRef => mapping(tp) + case _ => null + private def tvarOrError(tpr: TypeRef)(using Context): TypeVar = mapping(tpr).ensuring(_ ne null, i"not a constrainable type: $tpr") From ee3febea1ce2c06968a850925bb9f250cda53bb2 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Thu, 2 Sep 2021 14:46:55 +0800 Subject: [PATCH 12/15] Add pos and neg tests for necessaryEither --- tests/neg/necessary-pdgadt.scala | 11 +++++++++++ tests/pos/necessary-pdgadt.scala | 14 ++++++++++++++ 2 files changed, 25 insertions(+) create mode 100644 tests/neg/necessary-pdgadt.scala create mode 100644 tests/pos/necessary-pdgadt.scala diff --git a/tests/neg/necessary-pdgadt.scala b/tests/neg/necessary-pdgadt.scala new file mode 100644 index 000000000000..3be1f8337fa9 --- /dev/null +++ b/tests/neg/necessary-pdgadt.scala @@ -0,0 +1,11 @@ +/* N <: M */ +trait M +trait N + +enum SUB[-A, +B]: + case Ev[X]() extends SUB[X, X] +trait P { type T } + +def f(p: P, e: SUB[p.T, N | M]) = e match + case SUB.Ev() => + (??? : p.T) : N // error diff --git a/tests/pos/necessary-pdgadt.scala b/tests/pos/necessary-pdgadt.scala new file mode 100644 index 000000000000..eb58f952ce65 --- /dev/null +++ b/tests/pos/necessary-pdgadt.scala @@ -0,0 +1,14 @@ +/* N <: M */ +trait M +trait N extends M + +enum SUB[-A, +B]: + case Ev[X]() extends SUB[X, X] + +trait P { type T } + +def f(p: P, e: SUB[p.T, N | M]) = e match + case SUB.Ev() => + // p.T <: M + (??? : p.T) : M + From ebfcb0e7a598f5453f3b79cbd28c433e97fe7cf5 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Sat, 4 Sep 2021 16:57:18 +0800 Subject: [PATCH 13/15] Remove useless code handling unbound patterns --- .../tools/dotc/core/GadtConstraint.scala | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 333a621b154b..1eec4bbb942c 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -75,15 +75,13 @@ final class ProperGadtConstraint private( private var myConstraint: Constraint, private var mapping: SimpleIdentityMap[TypeRef, TypeVar], private var reverseMapping: SimpleIdentityMap[TypeParamRef, TypeRef], - private var tempMapping: SimpleIdentityMap[Symbol, TypeVar] ) extends GadtConstraint with ConstraintHandling { import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} def this() = this( myConstraint = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty), mapping = SimpleIdentityMap.empty, - reverseMapping = SimpleIdentityMap.empty, - tempMapping = SimpleIdentityMap.empty + reverseMapping = SimpleIdentityMap.empty ) /** Whether `left` subsumes `right`? @@ -198,7 +196,7 @@ final class ProperGadtConstraint private( override def addPDT(tp: Type)(using Context): Boolean = assert(isConstrainablePDT(tp), i"Type $tp is not a constrainable path-dependent type.") tp match - case TypeRef(prefix: TermRef, _) => addTypeMembersOf(prefix, isUnamedPattern = false).nonEmpty + case TypeRef(prefix: TermRef, _) => addTypeMembersOf(prefix).nonEmpty case _ => false /** Find all constrainable type member symbols of the given type. @@ -208,12 +206,12 @@ final class ProperGadtConstraint private( private def constrainableTypeMemberSymbols(tp: Type)(using Context) = abstractTypeMemberSymbols(tp) filterNot (_.is(Flags.Opaque)) - private def addTypeMembersOf(path: Type, isUnamedPattern: Boolean)(using Context): Option[Map[Symbol, TypeVar]] = + private def addTypeMembersOf(path: Type)(using Context): Option[Map[Symbol, TypeVar]] = import NameKinds.DepParamName - if !isUnamedPattern && !isConstrainablePath(path) then return None + if !isConstrainablePath(path) then return None - val pathType = if isUnamedPattern then path else path.widen + val pathType = path.widen val typeMembers = constrainableTypeMemberSymbols(pathType) if typeMembers.isEmpty then return Some(Map.empty) @@ -256,12 +254,9 @@ final class ProperGadtConstraint private( val tvars = typeMembers lazyZip poly1.paramRefs map { (sym, paramRef) => val tv = TypeVar(paramRef, creatorState = null) - if isUnamedPattern then - tempMapping = tempMapping.updated(sym, tv) - else - val externalType = TypeRef(path, sym) - mapping = mapping.updated(externalType, tv) - reverseMapping = reverseMapping.updated(tv.origin, externalType) + val externalType = TypeRef(path, sym) + mapping = mapping.updated(externalType, tv) + reverseMapping = reverseMapping.updated(tv.origin, externalType) tv } @@ -410,8 +405,7 @@ final class ProperGadtConstraint private( override def fresh: GadtConstraint = new ProperGadtConstraint( myConstraint, mapping, - reverseMapping, - tempMapping + reverseMapping ) def restore(other: GadtConstraint): Unit = other match { @@ -419,7 +413,6 @@ final class ProperGadtConstraint private( this.myConstraint = other.myConstraint this.mapping = other.mapping this.reverseMapping = other.reverseMapping - this.tempMapping = other.tempMapping case _ => ; } From 32416491bc94d2370194038869e923826888aa9e Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Sat, 4 Sep 2021 22:35:01 +0800 Subject: [PATCH 14/15] Add pos test for interdependency between type members --- tests/pos/interdep-pdgadt.scala | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 tests/pos/interdep-pdgadt.scala diff --git a/tests/pos/interdep-pdgadt.scala b/tests/pos/interdep-pdgadt.scala new file mode 100644 index 000000000000..c7fcc3b8e782 --- /dev/null +++ b/tests/pos/interdep-pdgadt.scala @@ -0,0 +1,8 @@ +trait P { type S; type T >: S } + +enum SUB[-A, +B]: + case EQ[X]() extends SUB[X, X] + +def f(p: P, e: SUB[Int, p.S]): p.T = e match + case SUB.EQ() => 42 + From 5120178cc138b6b67472827380f4e5a6583f67f7 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Tue, 28 Sep 2021 11:17:05 +0800 Subject: [PATCH 15/15] Support PDTs from ThisType --- compiler/src/dotty/tools/dotc/core/GadtConstraint.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 1eec4bbb942c..cc7498fa7d34 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -191,12 +191,14 @@ final class ProperGadtConstraint private( */ private def isConstrainablePath(path: Type)(using Context): Boolean = path match case path: TermRef if !path.symbol.is(Flags.Package) && !path.symbol.is(Flags.Module) => true + case path: ThisType if !path.cls.is(Flags.Package) && !path.cls.is(Flags.Module) => true case _ => false override def addPDT(tp: Type)(using Context): Boolean = assert(isConstrainablePDT(tp), i"Type $tp is not a constrainable path-dependent type.") tp match case TypeRef(prefix: TermRef, _) => addTypeMembersOf(prefix).nonEmpty + case TypeRef(prefix: ThisType, _) => addTypeMembersOf(prefix).nonEmpty case _ => false /** Find all constrainable type member symbols of the given type.