From 9104c23366020ac1d5e2197ae8c10cda1507d65c Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Thu, 17 Mar 2022 01:52:10 +0800 Subject: [PATCH 01/56] save path for scrutinee --- .../tools/dotc/core/GadtConstraint.scala | 29 +++++++++++++++++-- .../src/dotty/tools/dotc/typer/Typer.scala | 16 +++++++++- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index d8e1c5276ab6..a1fab093d39c 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -37,6 +37,12 @@ sealed abstract class GadtConstraint extends Showable { /** Further constrain a symbol already present in the constraint. */ def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean + /** Scrutinee path of the current pattern matching. */ + def scrutineePath: TermRef | Null + + /** Set the scrutinee path. */ + def withScrutineePath[T](path: TermRef)(op: => T): T + /** Is the symbol registered in the constraint? * * @note this is true even if the symbol is constrained to be equal to another type, unlike [[Constraint.contains]]. @@ -63,7 +69,8 @@ final class ProperGadtConstraint private( private var myConstraint: Constraint, private var mapping: SimpleIdentityMap[Symbol, TypeVar], private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol], - private var wasConstrained: Boolean + private var wasConstrained: Boolean, + private var myScrutineePath: TermRef ) extends GadtConstraint with ConstraintHandling { import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} @@ -71,7 +78,8 @@ final class ProperGadtConstraint private( myConstraint = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentitySet.empty), mapping = SimpleIdentityMap.empty, reverseMapping = SimpleIdentityMap.empty, - wasConstrained = false + wasConstrained = false, + myScrutineePath = null ) /** Exposes ConstraintHandling.subsumes */ @@ -225,7 +233,8 @@ final class ProperGadtConstraint private( myConstraint, mapping, reverseMapping, - wasConstrained + wasConstrained, + myScrutineePath ) def restore(other: GadtConstraint): Unit = other match { @@ -234,9 +243,19 @@ final class ProperGadtConstraint private( this.mapping = other.mapping this.reverseMapping = other.reverseMapping this.wasConstrained = other.wasConstrained + this.myScrutineePath = other.myScrutineePath case _ => ; } + override def scrutineePath: TermRef | Null = myScrutineePath + + override def withScrutineePath[T](path: TermRef)(op: => T): T = + val saved = this.myScrutineePath + this.myScrutineePath = path + val result = op + this.myScrutineePath = saved + result + // ---- Protected/internal ----------------------------------------------- override protected def constraint = myConstraint @@ -321,6 +340,10 @@ final class ProperGadtConstraint private( override def symbols: List[Symbol] = Nil + override def scrutineePath: TermRef | Null = unsupported("EmptyGadtConstraint.scrutineePath") + + override def withScrutineePath[T](path: TermRef)(op: => T): T = op + override def fresh = new ProperGadtConstraint override def restore(other: GadtConstraint): Unit = assert(!other.isNarrowing, "cannot restore a non-empty GADTMap") diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index a73c73863606..58aa36ebf518 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1750,7 +1750,21 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer assignType(cpy.CaseDef(tree)(pat1, guard1, body1), pat1, body1) } - val pat1 = typedPattern(tree.pat, wideSelType)(using gadtCtx) + val scrutineePath = + sel.tpe match { + case p: TermRef => + tree.pat match { + case _: Trees.Typed[_] => p + case _: Trees.Ident[_] => p + case _: Trees.Apply[_] => p + case _ => null + } + case _ => null + } + + val pat1 = gadtCtx.gadt.withScrutineePath(scrutineePath) { + typedPattern(tree.pat, wideSelType)(using gadtCtx) + } caseRest(pat1)( using Nullables.caseContext(sel, pat1)( using gadtCtx)) From 974e67a8383c5b7665407881edf409c1201cdae7 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Thu, 17 Mar 2022 01:52:19 +0800 Subject: [PATCH 02/56] update --- .../dotty/tools/dotc/core/PatternTypeConstrainer.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index c5f126580df5..efdbc15ee803 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -73,7 +73,7 @@ trait PatternTypeConstrainer { self: TypeComparer => * scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables, * in which case the subtyping relationship "heals" the type. */ - def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false): Boolean = trace(i"constrainPatternType($scrut, $pat)", gadts) { + def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false): Boolean = trace.force(i"constrainPatternType(${scrutRepr(scrut)}, $pat)", gadts) { def classesMayBeCompatible: Boolean = { import Flags._ @@ -196,6 +196,13 @@ trait PatternTypeConstrainer { self: TypeComparer => } } + /** Show the scrutinee. Will show the path if available. */ + private def scrutRepr(scrut: Type): String = + if ctx.gadt.scrutineePath != null then + ctx.gadt.scrutineePath.show + else + scrut.show + /** Constrain "simple" patterns (see `constrainPatternType`). * * This function expects to receive two types (scrutinee and pattern), both From 7e655d07b22b6cc4d1e30cd2a65b1d709389fa43 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Tue, 9 Aug 2022 16:43:26 +0800 Subject: [PATCH 03/56] checkpoint: before actually narrowing GADT bounds --- .../tools/dotc/core/GadtConstraint.scala | 314 +++++++++++++++++- .../dotc/core/PatternTypeConstrainer.scala | 86 +++-- .../dotty/tools/dotc/core/TypeComparer.scala | 94 +++++- 3 files changed, 459 insertions(+), 35 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index a1fab093d39c..0ba31ad542c2 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -11,12 +11,21 @@ import collection.mutable import printing._ import scala.annotation.internal.sharable +import Denotations.Denotation + +/** Types that represent a path. Can either be a TermRef or a SkolemType. */ +type PathType = TermRef | SkolemType /** Represents GADT constraints currently in scope */ sealed abstract class GadtConstraint extends Showable { /** Immediate bounds of `sym`. Does not contain lower/upper symbols (see [[fullBounds]]). */ def bounds(sym: Symbol)(using Context): TypeBounds | Null + /** Immediate bounds of a path-dependent type. + * This variant of bounds will ONLY try to retrieve path-dependent GADT bounds. */ + def bounds(path: PathType, sym: Symbol)(using Context): TypeBounds | Null + def bounds(tp: TypeRef)(using Context): TypeBounds | Null + /** Full bounds of `sym`, including TypeRefs to other lower/upper symbols. * * @note this performs subtype checks between ordered symbols. @@ -24,9 +33,15 @@ sealed abstract class GadtConstraint extends Showable { */ def fullBounds(sym: Symbol)(using Context): TypeBounds | Null + def fullBounds(path: PathType, sym: Symbol)(using Context): TypeBounds | Null + def fullBounds(tp: TypeRef)(using Context): TypeBounds | Null + /** Is `sym1` ordered to be less than `sym2`? */ def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean + /** Is `tp1` ordered to be less than `tp2`? */ + def isLess(tp1: NamedType, tp2: NamedType)(using Context): Boolean + /** Add symbols to constraint, correctly handling inter-dependencies. * * @see [[ConstraintHandling.addToConstraint]] @@ -34,12 +49,21 @@ sealed abstract class GadtConstraint extends Showable { def addToConstraint(syms: List[Symbol])(using Context): Boolean def addToConstraint(sym: Symbol)(using Context): Boolean = addToConstraint(sym :: Nil) + /** Add path to constraint, registering all its abstract type members. */ + def addToConstraint(path: PathType)(using Context): Boolean + /** Further constrain a symbol already present in the constraint. */ def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean + /** Further constrain a path-dependent type already present in the constraint. */ + def addBound(p: PathType, sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean + /** Scrutinee path of the current pattern matching. */ def scrutineePath: TermRef | Null + /** Reset scrutinee path to null. */ + def resetScrutineePath(): Unit + /** Set the scrutinee path. */ def withScrutineePath[T](path: TermRef)(op: => T): T @@ -49,6 +73,14 @@ sealed abstract class GadtConstraint extends Showable { */ def contains(sym: Symbol)(using Context): Boolean + /** Checks whether a path is registered. */ + def contains(path: PathType)(using Context): Boolean + + /** Checks whether a path-dependent type is registered in the handler. */ + def contains(path: PathType, sym: Symbol)(using Context): Boolean + + def registeredTypeMembers(path: PathType): List[Symbol] + /** GADT constraint narrows bounds of at least one variable */ def isNarrowing: Boolean @@ -69,6 +101,8 @@ final class ProperGadtConstraint private( private var myConstraint: Constraint, private var mapping: SimpleIdentityMap[Symbol, TypeVar], private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol], + private var pathDepMapping: SimpleIdentityMap[PathType, SimpleIdentityMap[Symbol, TypeVar]], + private var pathDepReverseMapping: SimpleIdentityMap[TypeParamRef, TypeRef], private var wasConstrained: Boolean, private var myScrutineePath: TermRef ) extends GadtConstraint with ConstraintHandling { @@ -78,6 +112,8 @@ final class ProperGadtConstraint private( myConstraint = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentitySet.empty), mapping = SimpleIdentityMap.empty, reverseMapping = SimpleIdentityMap.empty, + pathDepMapping = SimpleIdentityMap.empty, + pathDepReverseMapping = SimpleIdentityMap.empty, wasConstrained = false, myScrutineePath = null ) @@ -96,6 +132,154 @@ final class ProperGadtConstraint private( // the case where they're valid, so no approximating is needed. rawBound + /** Whether type members of the given path is constrainable? + * + * Package's and module's type members will not be constrained. + */ + private def isConstrainablePath(path: Type)(using Context): Boolean = path match + case path: TermRef + if !path.symbol.is(Flags.Package) + && !path.symbol.is(Flags.Module) + && !path.classSymbol.is(Flags.Package) + && !path.classSymbol.is(Flags.Module) + => true + case _: SkolemType + if !path.classSymbol.is(Flags.Package) + && !path.classSymbol.is(Flags.Module) + => true + case _ => false + + /** Find all constrainable type member denotations of the given type. + * + * All abstract but not opaque type members are returned. + * Note that we return denotation here, since the bounds of the type member + * depend on the context (e.g. applied type parameters). + */ + private def constrainableTypeMembers(tp: Type)(using Context): List[Denotation] = + tp.typeMembers.toList filter { denot => + val denot1 = tp.nonPrivateMember(denot.name) + val tb = denot.info + + def isConstrainableAlias: Boolean = tb match + case TypeAlias(_) => true + case _ => false + + (denot1.symbol.is(Flags.Deferred) || isConstrainableAlias) + && !denot1.symbol.is(Flags.Opaque) + && !denot1.symbol.isClass + } + + private def tvarOf(path: PathType, sym: Symbol)(using Context): TypeVar | Null = + pathDepMapping(path) match + case null => null + case innerMapping => innerMapping(sym) + + /** Try to retrieve type variable for some TypeRef. + * Both type parameters and path-dependent types are considered. + */ + private def tvarOf(tpr: TypeRef)(using Context): TypeVar | Null = + mapping(tpr.symbol) match + case null => + tpr match + case TypeRef(p: PathType, _) => tvarOf(p, tpr.symbol) + case _ => null + case tv => tv + + private def tvarOf(ntp: NamedType)(using Context): TypeVar | Null = + ntp match + case tp: TypeRef => tvarOf(tp) + case _ => null + + override def addToConstraint(path: PathType)(using Context): Boolean = isConstrainablePath(path) && { + import NameKinds.DepParamName + val pathType = path.widen + val typeMembers = constrainableTypeMembers(path) + + gadts.println(i"> trying to add $path into constraint ...") + gadts.println(i" path.widen = $pathType") + gadts.println(i" type members =\n${debugShowTypeMembers(typeMembers)}") + + typeMembers.nonEmpty && { + val typeMemberSymbols: List[Symbol] = typeMembers map { x => x.symbol } + val poly1 = PolyType(typeMembers map { d => DepParamName.fresh(d.name.toTypeName) })( + pt => typeMembers map { typeMember => + def substDependentSyms(tp: Type, isUpper: Boolean)(using Context): Type = { + def loop(tp: Type): Type = tp match + case tp @ AndType(tp1, tp2) if !isUpper => + tp.derivedAndOrType(loop(tp1), loop(tp2)) + case tp @ OrType(tp1, tp2) if isUpper => + tp.derivedOrType(loop(tp1), loop(tp2)) + case tp @ TypeRef(prefix, des) if prefix eq path => + typeMemberSymbols indexOf tp.symbol match + case -1 => tp + case idx => pt.paramRefs(idx) + case tp @ TypeRef(_: RecThis, des) => + typeMemberSymbols indexOf tp.symbol match + case -1 => tp + case idx => pt.paramRefs(idx) + case tp: TypeRef => + tvarOf(tp) match { + case tv: TypeVar => tv.origin + case null => tp + } + case tp => tp + + loop(tp) + } + + val tb = typeMember.info.bounds + + def stripLazyRef(tp: Type): Type = tp match + case tp @ RefinedType(parent, name, tb) => + tp.derivedRefinedType(stripLazyRef(parent), name, stripLazyRef(tb)) + case tp: RecType => + tp.derivedRecType(stripLazyRef(tp.parent)) + case tb: TypeBounds => + tb.derivedTypeBounds(stripLazyRef(tb.lo), stripLazyRef(tb.hi)) + case ref: LazyRef => + ref.stripLazyRef + case _ => tp + + val tb1: TypeBounds = stripLazyRef(tb).asInstanceOf + + tb1.derivedTypeBounds( + lo = substDependentSyms(tb1.lo, isUpper = false), + hi = substDependentSyms(tb1.hi, isUpper = true) + ) + }, + pt => defn.AnyType + ) + + val tvars = typeMemberSymbols lazyZip poly1.paramRefs map { (sym, paramRef) => + val tv = TypeVar(paramRef, creatorState = null) + + val externalType = TypeRef(path, sym) + pathDepMapping = pathDepMapping.updated(path, { + val old: SimpleIdentityMap[Symbol, TypeVar] = pathDepMapping(path) match + case null => SimpleIdentityMap.empty + case m => m + + old.updated(sym, tv) + }) + pathDepReverseMapping = pathDepReverseMapping.updated(tv.origin, externalType) + + tv + } + + addToConstraint(poly1, tvars) + .showing(i"added to constraint: [$poly1] $path\n$debugBoundsDescription", gadts) + } + } + + private def debugShowTypeMembers(typeMembers: List[Denotation])(using Context): String = + val buf = new mutable.StringBuilder + buf ++= "{\n" + typeMembers foreach { denot => + buf ++= i" ${denot.symbol}: ${denot.info.bounds}\n" + } + buf ++= "}" + buf.result + override def addToConstraint(params: List[Symbol])(using Context): Boolean = { import NameKinds.DepParamName @@ -145,6 +329,47 @@ final class ProperGadtConstraint private( .showing(i"added to constraint: [$poly1] $params%, %\n$debugBoundsDescription", gadts) } + override def addBound(path: PathType, sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = { + @annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match { + case tv: TypeVar => + val inst = constraint.instType(tv) + if (inst.exists) stripInternalTypeVar(inst) else tv + case _ => tp + } + + val symTvar: TypeVar = stripInternalTypeVar(tvarOrError(path, sym)) match { + case tv: TypeVar => tv + case inst => + gadts.println(i"instantiated: $path.$sym -> $inst") + return if (isUpper) isSub(inst, bound) else isSub(bound, inst) + } + + val internalizedBound = bound match { + case nt: TypeRef => + val ntTvar = tvarOf(nt) + if (ntTvar != null) stripInternalTypeVar(ntTvar) else bound + case _ => bound + } + + val saved = constraint + val result = internalizedBound match + case boundTvar: TypeVar => + if (boundTvar eq symTvar) true + else if (isUpper) addLess(symTvar.origin, boundTvar.origin) + else addLess(boundTvar.origin, symTvar.origin) + case bound => + addBoundTransitively(symTvar.origin, bound, isUpper) + + gadts.println { + val descr = if (isUpper) "upper" else "lower" + val op = if (isUpper) "<:" else ">:" + i"adding $descr bound $path.$sym $op $bound = $result" + } + + if constraint ne saved then wasConstrained = true + result + } + override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = { @annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match { case tv: TypeVar => @@ -189,6 +414,9 @@ final class ProperGadtConstraint private( override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = constraint.isLess(tvarOrError(sym1).origin, tvarOrError(sym2).origin) + override def isLess(tp1: NamedType, tp2: NamedType)(using Context): Boolean = + constraint.isLess(tvarOrError(tp1).origin, tvarOrError(tp2).origin) + override def fullBounds(sym: Symbol)(using Context): TypeBounds | Null = mapping(sym) match { case null => null @@ -198,6 +426,18 @@ final class ProperGadtConstraint private( // .ensuring(containsNoInternalTypes(_)) } + override def fullBounds(p: PathType, sym: Symbol)(using Context): TypeBounds | Null = + tvarOf(p, sym) match { + case null => null + case tv => fullBounds(tv.origin) + } + + override def fullBounds(tp: TypeRef)(using Context): TypeBounds | Null = + tp match { + case TypeRef(p: PathType, _) => fullBounds(p, tp.symbol) + case _ => null + } + override def bounds(sym: Symbol)(using Context): TypeBounds | Null = mapping(sym) match { case null => null @@ -209,8 +449,37 @@ final class ProperGadtConstraint private( //.ensuring(containsNoInternalTypes(_)) } + override def bounds(path: PathType, sym: Symbol)(using Context): TypeBounds | Null = + tvarOf(path, sym) match { + case null => null + case tv: TypeVar => + def retrieveBounds: TypeBounds = + bounds(tv.origin) match { + case TypeAlias(tpr: TypeParamRef) if reverseMapping.contains(tpr) => + TypeAlias(reverseMapping(tpr).nn.typeRef) + case TypeAlias(tpr: TypeParamRef) if pathDepReverseMapping.contains(tpr) => + TypeAlias(pathDepReverseMapping(tpr)) + case tb => tb + } + retrieveBounds + } + + override def bounds(tp: TypeRef)(using Context): TypeBounds | Null = + tp match { + case TypeRef(p: PathType, _) => bounds(p, tp.symbol) + case _ => null + } + override def contains(sym: Symbol)(using Context): Boolean = mapping(sym) != null + override def contains(path: PathType)(using Context): Boolean = pathDepMapping(path) != null + + override def contains(path: PathType, sym: Symbol)(using Context): Boolean = pathDepMapping(path) match + case null => false + case innerMapping => innerMapping(sym) != null + + override def registeredTypeMembers(path: PathType): List[Symbol] = pathDepMapping(path).nn.keys + def isNarrowing: Boolean = wasConstrained override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = { @@ -233,6 +502,8 @@ final class ProperGadtConstraint private( myConstraint, mapping, reverseMapping, + pathDepMapping, + pathDepReverseMapping, wasConstrained, myScrutineePath ) @@ -242,6 +513,8 @@ final class ProperGadtConstraint private( this.myConstraint = other.myConstraint this.mapping = other.mapping this.reverseMapping = other.reverseMapping + this.pathDepMapping = other.pathDepMapping + this.pathDepReverseMapping = other.pathDepReverseMapping this.wasConstrained = other.wasConstrained this.myScrutineePath = other.myScrutineePath case _ => ; @@ -249,6 +522,8 @@ final class ProperGadtConstraint private( override def scrutineePath: TermRef | Null = myScrutineePath + override def resetScrutineePath(): Unit = myScrutineePath = null + override def withScrutineePath[T](path: TermRef)(op: => T): T = val saved = this.myScrutineePath this.myScrutineePath = path @@ -285,16 +560,29 @@ final class ProperGadtConstraint private( private def externalize(tp: Type, theMap: TypeMap | Null = null)(using Context): Type = tp match case param: TypeParamRef => reverseMapping(param) match case sym: Symbol => sym.typeRef - case null => param + case null => pathDepReverseMapping(param) match + case tp: TypeRef => tp + case null => param case tp: TypeAlias => tp.derivedAlias(externalize(tp.alias, theMap)) case tp => (if theMap == null then ExternalizeMap() else theMap).mapOver(tp) private class ExternalizeMap(using Context) extends TypeMap: def apply(tp: Type): Type = externalize(tp, this)(using mapCtx) + private def tvarOrError(sym: Symbol)(using Context): TypeVar = mapping(sym).ensuring(_ != null, i"not a constrainable symbol: $sym").uncheckedNN + private def tvarOrError(path: PathType, sym: Symbol)(using Context): TypeVar = + tvarOf(path, sym).ensuring(_ != null, i"not a constrainable type: $path.$sym").uncheckedNN + + private def tvarOrError(ntp: NamedType)(using Context): TypeVar = + tvarOf(ntp).ensuring(_ != null, i"not a constrainable type: $ntp").uncheckedNN + + // private def containsNoInternalTypes( + // tp: Type, + // acc: TypeAccumulator[Boolean] | Null = null + // )(using Context): Boolean = tp match { private def containsNoInternalTypes(tp: Type, theAcc: TypeAccumulator[Boolean] | Null = null)(using Context): Boolean = tp match { case tpr: TypeParamRef => !reverseMapping.contains(tpr) case tv: TypeVar => !reverseMapping.contains(tv.origin) @@ -319,6 +607,12 @@ final class ProperGadtConstraint private( mapping.foreachBinding { case (sym, _) => sb ++= i"$sym: ${fullBounds(sym)}\n" } + sb += '\n' + pathDepMapping foreachBinding { case (path, m) => + m foreachBinding { case (sym, _) => + sb ++= i"$path.$sym: ${fullBounds(TypeRef(path, sym))}\n" + } + } sb.result } } @@ -327,20 +621,36 @@ final class ProperGadtConstraint private( override def bounds(sym: Symbol)(using Context): TypeBounds | Null = null override def fullBounds(sym: Symbol)(using Context): TypeBounds | Null = null + override def bounds(p: PathType, sym: Symbol)(using Context): TypeBounds | Null = null + override def fullBounds(p: PathType, sym: Symbol)(using Context): TypeBounds | Null = null + override def bounds(tp: TypeRef)(using Context): TypeBounds | Null = null + override def fullBounds(tp: TypeRef)(using Context): TypeBounds | Null = null + override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = unsupported("EmptyGadtConstraint.isLess") + override def isLess(tp1: NamedType, tp2: NamedType)(using Context): Boolean = unsupported("EmptyGadtConstraint.isLess") override def isNarrowing: Boolean = false override def contains(sym: Symbol)(using Context) = false + override def contains(path: PathType)(using Context) = false + + override def contains(path: PathType, symbol: Symbol)(using Context) = false + + override def registeredTypeMembers(path: PathType): List[Symbol] = Nil + override def addToConstraint(params: List[Symbol])(using Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint") + override def addToConstraint(path: PathType)(using Context): Boolean = false override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound") + override def addBound(path: PathType, sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound") override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = unsupported("EmptyGadtConstraint.approximation") override def symbols: List[Symbol] = Nil - override def scrutineePath: TermRef | Null = unsupported("EmptyGadtConstraint.scrutineePath") + override def scrutineePath: TermRef | Null = null + + override def resetScrutineePath(): Unit = () override def withScrutineePath[T](path: TermRef)(op: => T): T = op diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index efdbc15ee803..f16564334b33 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -73,7 +73,7 @@ trait PatternTypeConstrainer { self: TypeComparer => * scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables, * in which case the subtyping relationship "heals" the type. */ - def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false): Boolean = trace.force(i"constrainPatternType(${scrutRepr(scrut)}, $pat)", gadts) { + def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false, typeMembersTouched: Boolean = false): Boolean = trace(i"constrainPatternType(${scrutRepr(scrut)}, $pat)", gadts) { def classesMayBeCompatible: Boolean = { import Flags._ @@ -119,7 +119,7 @@ trait PatternTypeConstrainer { self: TypeComparer => // consider all parents val parents = scrut.parents val andType = buildAndType(parents) - !andType.exists || constrainPatternType(pat, andType) + !andType.exists || constrainPatternType(pat, andType, typeMembersTouched = true) case scrut @ AppliedType(tycon: TypeRef, _) if tycon.symbol.isClass => val patCls = pat.classSymbol // find all shared parents in the inheritance hierarchy between pat and scrut @@ -136,7 +136,7 @@ trait PatternTypeConstrainer { self: TypeComparer => val allSyms = allParentsSharedWithPat(tycon, tycon.symbol.asClass) val baseClasses = allSyms map scrut.baseType val andType = buildAndType(baseClasses) - !andType.exists || constrainPatternType(pat, andType) + !andType.exists || constrainPatternType(pat, andType, typeMembersTouched = true) case _ => def tryGadtBounds = scrut match { case scrut: TypeRef => @@ -175,25 +175,71 @@ trait PatternTypeConstrainer { self: TypeComparer => case tp => tp } - dealiasDropNonmoduleRefs(scrut) match { - case OrType(scrut1, scrut2) => - either(constrainPatternType(pat, scrut1), constrainPatternType(pat, scrut2)) - case AndType(scrut1, scrut2) => - constrainPatternType(pat, scrut1) && constrainPatternType(pat, scrut2) - case scrut: RefinedOrRecType => - constrainPatternType(pat, stripRefinement(scrut)) - case scrut => dealiasDropNonmoduleRefs(pat) match { - case OrType(pat1, pat2) => - either(constrainPatternType(pat1, scrut), constrainPatternType(pat2, scrut)) - case AndType(pat1, pat2) => - constrainPatternType(pat1, scrut) && constrainPatternType(pat2, scrut) - case pat: RefinedOrRecType => - constrainPatternType(stripRefinement(pat), scrut) - case pat => - tryConstrainSimplePatternType(pat, scrut) - || classesMayBeCompatible && constrainUpcasted(scrut) + def constrainTypeMembers = trace(i"constrainTypeMembers(${scrutRepr(scrut)}, $pat)") { + val realScrutineePath = ctx.gadt.scrutineePath + /* We reset scrutinee path so that the path will only be used at top level. */ + ctx.gadt.resetScrutineePath() + + val scrutineePath: TermRef | SkolemType = realScrutineePath match + case null => SkolemType(scrut) + case _ => realScrutineePath + val patternPath: SkolemType = SkolemType(pat) + + val saved = state.nn.constraint + val savedGadt = ctx.gadt.fresh + + def registerScrutinee = ctx.gadt.contains(scrutineePath) || ctx.gadt.addToConstraint(scrutineePath) + def registerPattern = ctx.gadt.addToConstraint(patternPath) // Pattern path is a freshly-created skolem, + // so it will always be un-registered at this point + + val result = registerScrutinee && registerPattern && { + val scrutineeTypeMembers = Map.from { + ctx.gadt.registeredTypeMembers(scrutineePath) map { x => x.name -> x } + } + val patternTypeMembers = Map.from { + ctx.gadt.registeredTypeMembers(patternPath) map { x => x.name -> x } + } + + (scrutineeTypeMembers.keySet intersect patternTypeMembers.keySet) forall { name => + val scrutineeSymbol = scrutineeTypeMembers(name) + val patternSymbol = patternTypeMembers(name) + + val scrutineeType = TypeRef(scrutineePath, scrutineeSymbol) + val patternType = TypeRef(patternPath, patternSymbol) + + isSubType(scrutineeType, patternType) && isSubType(patternType, scrutineeType) + } } + + if !result then + constraint = saved + ctx.gadt.restore(savedGadt) + + result } + + def constrainTypeParams = + dealiasDropNonmoduleRefs(scrut) match { + case OrType(scrut1, scrut2) => + either(constrainPatternType(pat, scrut1, typeMembersTouched = true), constrainPatternType(pat, scrut2, typeMembersTouched = true)) + case AndType(scrut1, scrut2) => + constrainPatternType(pat, scrut1, typeMembersTouched = true) && constrainPatternType(pat, scrut2, typeMembersTouched = true) + case scrut: RefinedOrRecType => + constrainPatternType(pat, stripRefinement(scrut)) + case scrut => dealiasDropNonmoduleRefs(pat) match { + case OrType(pat1, pat2) => + either(constrainPatternType(pat1, scrut, typeMembersTouched = true), constrainPatternType(pat2, scrut, typeMembersTouched = true)) + case AndType(pat1, pat2) => + constrainPatternType(pat1, scrut, typeMembersTouched = true) && constrainPatternType(pat2, scrut, typeMembersTouched = true) + case pat: RefinedOrRecType => + constrainPatternType(stripRefinement(pat), scrut, typeMembersTouched = true) + case pat => + tryConstrainSimplePatternType(pat, scrut) + || classesMayBeCompatible && constrainUpcasted(scrut) + } + } + + constrainTypeParams && (typeMembersTouched || constrainTypeMembers) } /** Show the scrutinee. Will show the path if available. */ diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index adce363dc3f4..6d05f1d5eb7d 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -118,7 +118,22 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling private def isBottom(tp: Type) = tp.widen.isRef(NothingClass) protected def gadtBounds(sym: Symbol)(using Context) = ctx.gadt.bounds(sym) - protected def gadtAddBound(sym: Symbol, b: Type, isUpper: Boolean): Boolean = ctx.gadt.addBound(sym, b, isUpper) + + /** This variant of gadtBounds works for all named types. + * It queries GADT bounds for both type parameters and path-dependent types. + */ + protected def gadtBounds(tp: NamedType)(using Context) = + ctx.gadt.bounds(tp.symbol) match + case null => + tp match + case TypeRef(p: PathType, _) => ctx.gadt.bounds(p, tp.symbol) + case _ => null + case tb => tb + + protected def gadtAddBound(sym: Symbol, b: Type, isUpper: Boolean): Boolean = ctx.gadt.addBound(sym, b, isUpper = isUpper) + + protected def gadtAddLowerBound(path: PathType, sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(path, sym, b, isUpper = false) + protected def gadtAddUpperBound(path: PathType, sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(path, sym, b, isUpper = true) protected def typeVarInstance(tvar: TypeVar)(using Context): Type = tvar.underlying @@ -193,6 +208,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling val bounds = gadtBounds(sym) bounds != null && op(bounds) + extension (tp: NamedType) + private inline def onGadtBounds(inline op: TypeBounds => Boolean): Boolean = + val bounds = gadtBounds(tp) + bounds != null && op(bounds) + private inline def comparingTypeLambdas(tl1: TypeLambda, tl2: TypeLambda)(op: => Boolean): Boolean = val saved = comparedTypeLambdas comparedTypeLambdas += tl1 @@ -540,17 +560,26 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling def thirdTryNamed(tp2: NamedType): Boolean = tp2.info match { case info2: TypeBounds => + /** Checks whether tp1 is registered. + * Both type parameters and path-dependent types are considered. + */ + def tpRegistered(tp: TypeRef) = ctx.gadt.contains(tp.symbol) || { + tp match + case tp @ TypeRef(p: PathType, _) => ctx.gadt.contains(p, tp.symbol) + case _ => false + } + def compareGADT: Boolean = - tp2.symbol.onGadtBounds(gbounds2 => + tp2.onGadtBounds(gbounds2 => isSubTypeWhenFrozen(tp1, gbounds2.lo) || tp1.match - case tp1: NamedType if ctx.gadt.contains(tp1.symbol) => + case tp1: TypeRef if tpRegistered(tp1) => // Note: since we approximate constrained types only with their non-param bounds, // we need to manually handle the case when we're comparing two constrained types, // one of which is constrained to be a subtype of another. // We do not need similar code in fourthTry, since we only need to care about // comparing two constrained types, and that case will be handled here first. - ctx.gadt.isLess(tp1.symbol, tp2.symbol) && GADTusage(tp1.symbol) && GADTusage(tp2.symbol) + { ctx.gadt.isLess(tp1, tp2) } && GADTusage(tp1.symbol) && GADTusage(tp2.symbol) case _ => false || narrowGADTBounds(tp2, tp1, approx, isUpper = false)) && (isBottom(tp1) || GADTusage(tp2.symbol)) @@ -858,7 +887,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling tp1.info match { case info1 @ TypeBounds(lo1, hi1) => def compareGADT = - tp1.symbol.onGadtBounds(gbounds1 => + tp1.onGadtBounds(gbounds1 => isSubTypeWhenFrozen(gbounds1.hi, tp2) || narrowGADTBounds(tp1, tp2, approx, isUpper = true)) && (tp2.isAny || GADTusage(tp1.symbol)) @@ -2017,6 +2046,13 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case _ => proto.isMatchedBy(tp, keepConstraint = true) } + private def rollbackGadtUnless(op: => Boolean): Boolean = + val savedGadt = ctx.gadt.fresh + val result = op + if !result then ctx.gadt.restore(savedGadt) + result + end rollbackGadtUnless + /** Narrow gadt.bounds for the type parameter referenced by `tr` to include * `bound` as an upper or lower bound (which depends on `isUpper`). * Test that the resulting bounds are still satisfiable. @@ -2024,14 +2060,31 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling private def narrowGADTBounds(tr: NamedType, bound: Type, approx: ApproxState, isUpper: Boolean): Boolean = { val boundImprecise = approx.high || approx.low ctx.mode.is(Mode.GadtConstraintInference) && !frozenGadt && !frozenConstraint && !boundImprecise && { - val tparam = tr.symbol - gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}") - if (bound.isRef(tparam)) false - else - val savedGadt = ctx.gadt.fresh - val success = gadtAddBound(tparam, bound, isUpper) - if !success then ctx.gadt.restore(savedGadt) - success + def narrowTypeParams = ctx.gadt.contains(tr.symbol) && { + val tparam = tr.symbol + gadts.println(i"narrow gadt bound of tparam $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}") + if (bound.isRef(tparam)) false + else + rollbackGadtUnless { + if isUpper then + gadtAddUpperBound(tparam, bound) + else + gadtAddLowerBound(tparam, bound) + } + } + + def narrowPathDepType = tr match + case TypeRef(path: PathType, _) => + ctx.gadt.contains(path, tr.symbol) && { + val sym = tr.symbol + gadts.println(i"narrow gadt bound of pdt $path -> ${sym}: from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(sym)}") + + if (bound.isRef(sym)) false + else false + } + case _ => false + + narrowTypeParams || narrowPathDepType } } @@ -3047,6 +3100,21 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) { if (sym.exists) footprint += sym.typeRef super.gadtAddBound(sym, b, isUpper) + override def gadtBounds(tp: NamedType)(using Context): TypeBounds | Null = { + if (tp.symbol.exists) footprint += tp + super.gadtBounds(tp) + } + + override def gadtAddLowerBound(path: PathType, sym: Symbol, b: Type): Boolean = { + if (sym.exists) footprint += TypeRef(path, sym) + super.gadtAddLowerBound(path, sym, b) + } + + override def gadtAddUpperBound(path: PathType, sym: Symbol, b: Type): Boolean = { + if (sym.exists) footprint += TypeRef(path, sym) + super.gadtAddUpperBound(path, sym, b) + } + override def typeVarInstance(tvar: TypeVar)(using Context): Type = { footprint += tvar super.typeVarInstance(tvar) From 3abf4d688dbddb9ba452f436517ae77d368b9abb Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Fri, 18 Mar 2022 12:17:43 +0800 Subject: [PATCH 04/56] path-dependent GADT reasoning for type members --- .../tools/dotc/core/GadtConstraint.scala | 4 ++-- .../dotc/core/PatternTypeConstrainer.scala | 19 +++++++++++++++++-- .../dotty/tools/dotc/core/TypeComparer.scala | 5 +++-- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 0ba31ad542c2..48ffa79efae6 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -605,12 +605,12 @@ final class ProperGadtConstraint private( sb ++= constraint.show sb += '\n' mapping.foreachBinding { case (sym, _) => - sb ++= i"$sym: ${fullBounds(sym)}\n" + sb ++= i"$sym: ${bounds(sym)}\n" } sb += '\n' pathDepMapping foreachBinding { case (path, m) => m foreachBinding { case (sym, _) => - sb ++= i"$path.$sym: ${fullBounds(TypeRef(path, sym))}\n" + sb ++= i"$path.$sym: ${bounds(TypeRef(path, sym))}\n" } } sb.result diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index f16564334b33..80f7e51e26a4 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -175,7 +175,8 @@ trait PatternTypeConstrainer { self: TypeComparer => case tp => tp } - def constrainTypeMembers = trace(i"constrainTypeMembers(${scrutRepr(scrut)}, $pat)") { + def constrainTypeMembers = trace(i"constrainTypeMembers(${scrutRepr(scrut)}, $pat)", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") { + import NameKinds.DepParamName val realScrutineePath = ctx.gadt.scrutineePath /* We reset scrutinee path so that the path will only be used at top level. */ ctx.gadt.resetScrutineePath() @@ -185,6 +186,9 @@ trait PatternTypeConstrainer { self: TypeComparer => case _ => realScrutineePath val patternPath: SkolemType = SkolemType(pat) + gadts.println(i"scrutinee path: $scrutineePath") + gadts.println(i"pattern path: $patternPath") + val saved = state.nn.constraint val savedGadt = ctx.gadt.fresh @@ -207,7 +211,18 @@ trait PatternTypeConstrainer { self: TypeComparer => val scrutineeType = TypeRef(scrutineePath, scrutineeSymbol) val patternType = TypeRef(patternPath, patternSymbol) - isSubType(scrutineeType, patternType) && isSubType(patternType, scrutineeType) + def constrainSP = + val res = ctx.gadt.addBound(scrutineePath, scrutineeSymbol, patternType, isUpper = true) + gadts.println(i"after $scrutineePath.$scrutineeSymbol <:< $patternType: res = $res, gadt = ${ctx.gadt.debugBoundsDescription}") + res + + def constrainPS = + val res = ctx.gadt.addBound(patternPath, patternSymbol, scrutineeType, isUpper = true) + gadts.println(i"after $patternPath.$patternSymbol <:< $scrutineePath: res = $res, gadt = ${ctx.gadt.debugBoundsDescription}") + res + + constrainPS && constrainSP + // true } } diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 6d05f1d5eb7d..3a119d1ae55e 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -571,7 +571,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling def compareGADT: Boolean = tp2.onGadtBounds(gbounds2 => - isSubTypeWhenFrozen(tp1, gbounds2.lo) + { isSubTypeWhenFrozen(tp1, gbounds2.lo) } || tp1.match case tp1: TypeRef if tpRegistered(tp1) => // Note: since we approximate constrained types only with their non-param bounds, @@ -2080,7 +2080,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling gadts.println(i"narrow gadt bound of pdt $path -> ${sym}: from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(sym)}") if (bound.isRef(sym)) false - else false + else if isUpper then gadtAddUpperBound(path, sym, bound) + else gadtAddLowerBound(path, sym, bound) } case _ => false From ff0ac7021f60ecd329f074a95f05947aa28898a6 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Sun, 20 Mar 2022 18:30:43 +0800 Subject: [PATCH 05/56] add GADT reasoning for path-dependent types --- .../tools/dotc/core/GadtConstraint.scala | 187 +++++++++++++++++- .../dotc/core/PatternTypeConstrainer.scala | 5 +- .../dotty/tools/dotc/core/TypeComparer.scala | 62 ++++-- tests/neg/pdgadt-either.scala | 17 ++ tests/pos/pdgadt-expr.scala | 13 ++ 5 files changed, 253 insertions(+), 31 deletions(-) create mode 100644 tests/neg/pdgadt-either.scala create mode 100644 tests/pos/pdgadt-expr.scala diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 48ffa79efae6..775cc44e8356 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -11,7 +11,7 @@ import collection.mutable import printing._ import scala.annotation.internal.sharable -import Denotations.Denotation +import Denotations.{Denotation, SingleDenotation} /** Types that represent a path. Can either be a TermRef or a SkolemType. */ type PathType = TermRef | SkolemType @@ -58,6 +58,12 @@ sealed abstract class GadtConstraint extends Showable { /** Further constrain a path-dependent type already present in the constraint. */ def addBound(p: PathType, sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean + def addEquality(p: PathType, q: PathType): Unit + + def isEquivalent(p: PathType, q: PathType): Boolean + + def reprOf(p: PathType): PathType | Null + /** Scrutinee path of the current pattern matching. */ def scrutineePath: TermRef | Null @@ -79,6 +85,9 @@ sealed abstract class GadtConstraint extends Showable { /** Checks whether a path-dependent type is registered in the handler. */ def contains(path: PathType, sym: Symbol)(using Context): Boolean + /** Checks whether a given path-dependent type is constrainable. */ + def isConstrainablePDT(path: PathType, sym: Symbol)(using Context): Boolean + def registeredTypeMembers(path: PathType): List[Symbol] /** GADT constraint narrows bounds of at least one variable */ @@ -104,7 +113,8 @@ final class ProperGadtConstraint private( private var pathDepMapping: SimpleIdentityMap[PathType, SimpleIdentityMap[Symbol, TypeVar]], private var pathDepReverseMapping: SimpleIdentityMap[TypeParamRef, TypeRef], private var wasConstrained: Boolean, - private var myScrutineePath: TermRef + private var myScrutineePath: TermRef, + private var myUnionFind: SimpleIdentityMap[PathType, PathType] ) extends GadtConstraint with ConstraintHandling { import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} @@ -115,16 +125,112 @@ final class ProperGadtConstraint private( pathDepMapping = SimpleIdentityMap.empty, pathDepReverseMapping = SimpleIdentityMap.empty, wasConstrained = false, - myScrutineePath = null + myScrutineePath = null, + myUnionFind = SimpleIdentityMap.empty ) - /** Exposes ConstraintHandling.subsumes */ + // /** Exposes ConstraintHandling.subsumes */ + // def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(using Context): Boolean = { + // def extractConstraint(g: GadtConstraint) = g match { + // case s: ProperGadtConstraint => s.constraint + // case EmptyGadtConstraint => OrderingConstraint.empty + // } + // subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre)) + // } + + /** Whether `left` subsumes `right`? + * + * `left` and `right` both stem from the constraint `pre`, with different type reasoning performed, + * during which new types might be registered in GadtConstraint. This function will take such newly + * registered types into consideration. + */ def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(using Context): Boolean = { + def checkSubsumes(c1: Constraint, c2: Constraint, pre: Constraint): Boolean = { + if (c2 eq pre) true + else if (c1 eq pre) false + else { + val saved = constraint + + def computeNewParams = + val params1 = c1.domainParams.toSet + val params2 = c2.domainParams.toSet + val preParams = pre.domainParams.toSet + /** Type parameter registered after branching */ + (params1.diff(preParams), params2.diff(preParams)) + + val (newParams1, newParams2) = computeNewParams + + // When new types are registered after pre, for left to subsume right, it should contain all types + // newly registered in right. + def checkNewParams: Boolean = (left, right) match { + case (left: ProperGadtConstraint, right: ProperGadtConstraint) => + newParams2 forall { p2 => + val tp2 = right.externalize(p2) + left.tvarOf(tp2) != null + } + case _ => true + } + + checkNewParams && { + // compute mappings between the newly-registered type params in the two branches + def createMappings = { + var mapping1: SimpleIdentityMap[TypeParamRef, TypeParamRef] = SimpleIdentityMap.empty + var mapping2: SimpleIdentityMap[TypeParamRef, TypeParamRef] = SimpleIdentityMap.empty + + (left, right) match { + case (left: ProperGadtConstraint, right: ProperGadtConstraint) => + newParams1 foreach { p1 => + val tp1 = left.externalize(p1) + right.tvarOf(tp1) match { + case null => + case tvar2 => + mapping1 = mapping1.updated(p1, tvar2.origin) + mapping2 = mapping2.updated(tvar2.origin, p1) + } + } + case _ => + } + + def mapTypeParam(m: SimpleIdentityMap[TypeParamRef, TypeParamRef])(tpr: TypeParamRef) = + m(tpr) match + case null => tpr + case tpr1 => tpr1 + + (mapTypeParam(mapping1), mapTypeParam(mapping2)) + } + + // bridge between the newly-registered types in c2 and c1 + val (mapping1, mapping2) = createMappings + + try { + // checks existing type parameters in `pre` + def existing: Boolean = pre.forallParams { p => + c1.contains(p) && + c2.upper(p).forall { q => + c1.isLess(p, mapping1(q)) + } && isSubTypeWhenFrozen(c1.nonParamBounds(p), c2.nonParamBounds(p)) + } + + // checks new type parameters in `c1` + def added: Boolean = newParams1 forall { p1 => + val p2 = mapping1(p1) + c2.upper(p2).forall { q => + c1.isLess(p1, mapping2(q)) + } && isSubTypeWhenFrozen(c1.nonParamBounds(p1), c2.nonParamBounds(p2)) + } + + existing && checkNewParams && added + } finally constraint = saved + } + } + } + def extractConstraint(g: GadtConstraint) = g match { case s: ProperGadtConstraint => s.constraint case EmptyGadtConstraint => OrderingConstraint.empty } - subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre)) + + checkSubsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre)) } override protected def legalBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Type = @@ -160,15 +266,34 @@ final class ProperGadtConstraint private( val denot1 = tp.nonPrivateMember(denot.name) val tb = denot.info - def isConstrainableAlias: Boolean = tb match + // We want to constrain type members whose bounds are type alias + // even if they are not deferred. + // + // For example: when constraining { type A } >:< { type A = Int } + // we want to take the bound (:= Int) of RHS into consideration. + def isTypeAlias: Boolean = tb match case TypeAlias(_) => true case _ => false - (denot1.symbol.is(Flags.Deferred) || isConstrainableAlias) + (denot1.symbol.is(Flags.Deferred) || isTypeAlias) && !denot1.symbol.is(Flags.Opaque) && !denot1.symbol.isClass } + private def isConstrainableTypeMember(path: PathType, sym: Symbol)(using Context): Boolean = + val mbr = path.nonPrivateMember(sym.name) + mbr.isInstanceOf[SingleDenotation] && { + val denot1 = path.nonPrivateMember(mbr.name) + val tb = mbr.info + + denot1.symbol.is(Flags.Deferred) + && !denot1.symbol.is(Flags.Opaque) + && !denot1.symbol.isClass + } + + override def isConstrainablePDT(path: PathType, sym: Symbol)(using Context): Boolean = + isConstrainablePath(path) && isConstrainableTypeMember(path, sym) + private def tvarOf(path: PathType, sym: Symbol)(using Context): TypeVar | Null = pathDepMapping(path) match case null => null @@ -185,11 +310,17 @@ final class ProperGadtConstraint private( case _ => null case tv => tv + /** Try to retrieve the internal type variable for a NamedType. */ private def tvarOf(ntp: NamedType)(using Context): TypeVar | Null = ntp match case tp: TypeRef => tvarOf(tp) case _ => null + private def tvarOf(tp: Type)(using Context): TypeVar | Null = + tp match + case tp: TypeRef => tvarOf(tp) + case _ => null + override def addToConstraint(path: PathType)(using Context): Boolean = isConstrainablePath(path) && { import NameKinds.DepParamName val pathType = path.widen @@ -280,6 +411,8 @@ final class ProperGadtConstraint private( buf ++= "}" buf.result + override def reprOf(p: PathType): PathType | Null = lookupPath(p) + override def addToConstraint(params: List[Symbol])(using Context): Boolean = { import NameKinds.DepParamName @@ -411,6 +544,32 @@ final class ProperGadtConstraint private( result } + private def lookupPath(p: PathType): PathType | Null = + def recur(p: PathType, steps: Int = 0): PathType | Null = myUnionFind(p) match + case null => null + case q if p eq q => q + case q => + if steps <= 1024 then + recur(q, steps + 1) + else + assert(false, "lookup step exceeding the threshold, possibly because of a loop in the union find") + recur(p) + + override def addEquality(p: PathType, q: PathType): Unit = + val newRep = lookupPath(p) match + case null => lookupPath(q) match + case null => p + case r => r + case r => r + + myUnionFind = myUnionFind.updated(p, newRep) + myUnionFind = myUnionFind.updated(q, newRep) + + override def isEquivalent(p: PathType, q: PathType): Boolean = + lookupPath(p) match + case null => false + case p0 => p0 eq lookupPath(q) + override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = constraint.isLess(tvarOrError(sym1).origin, tvarOrError(sym2).origin) @@ -505,7 +664,8 @@ final class ProperGadtConstraint private( pathDepMapping, pathDepReverseMapping, wasConstrained, - myScrutineePath + myScrutineePath, + myUnionFind ) def restore(other: GadtConstraint): Unit = other match { @@ -517,6 +677,7 @@ final class ProperGadtConstraint private( this.pathDepReverseMapping = other.pathDepReverseMapping this.wasConstrained = other.wasConstrained this.myScrutineePath = other.myScrutineePath + this.myUnionFind = other.myUnionFind case _ => ; } @@ -626,6 +787,8 @@ final class ProperGadtConstraint private( override def bounds(tp: TypeRef)(using Context): TypeBounds | Null = null override def fullBounds(tp: TypeRef)(using Context): TypeBounds | Null = null + override def reprOf(p: PathType): PathType | Null = null + override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = unsupported("EmptyGadtConstraint.isLess") override def isLess(tp1: NamedType, tp2: NamedType)(using Context): Boolean = unsupported("EmptyGadtConstraint.isLess") @@ -637,6 +800,8 @@ final class ProperGadtConstraint private( override def contains(path: PathType, symbol: Symbol)(using Context) = false + override def isConstrainablePDT(path: PathType, symbol: Symbol)(using Context) = false + override def registeredTypeMembers(path: PathType): List[Symbol] = Nil override def addToConstraint(params: List[Symbol])(using Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint") @@ -644,7 +809,11 @@ final class ProperGadtConstraint private( override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound") override def addBound(path: PathType, sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound") - override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = unsupported("EmptyGadtConstraint.approximation") + override def addEquality(p: PathType, q: PathType) = () + + override def isEquivalent(p: PathType, q: PathType) = false + + override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = unsupported("EmptyGadtConstraint.approximation") override def symbols: List[Symbol] = Nil diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index 80f7e51e26a4..38078ca08b4c 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -73,7 +73,7 @@ trait PatternTypeConstrainer { self: TypeComparer => * scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables, * in which case the subtyping relationship "heals" the type. */ - def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false, typeMembersTouched: Boolean = false): Boolean = trace(i"constrainPatternType(${scrutRepr(scrut)}, $pat)", gadts) { + def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false, typeMembersTouched: Boolean = false): Boolean = trace(i"constrainPatternType(${scrutRepr(scrut)}, $pat)", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") { def classesMayBeCompatible: Boolean = { import Flags._ @@ -192,6 +192,8 @@ trait PatternTypeConstrainer { self: TypeComparer => val saved = state.nn.constraint val savedGadt = ctx.gadt.fresh + ctx.gadt.addEquality(scrutineePath, patternPath) + def registerScrutinee = ctx.gadt.contains(scrutineePath) || ctx.gadt.addToConstraint(scrutineePath) def registerPattern = ctx.gadt.addToConstraint(patternPath) // Pattern path is a freshly-created skolem, // so it will always be un-registered at this point @@ -222,7 +224,6 @@ trait PatternTypeConstrainer { self: TypeComparer => res constrainPS && constrainSP - // true } } diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 3a119d1ae55e..a4c2f0635885 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -570,18 +570,18 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling } def compareGADT: Boolean = - tp2.onGadtBounds(gbounds2 => - { isSubTypeWhenFrozen(tp1, gbounds2.lo) } - || tp1.match - case tp1: TypeRef if tpRegistered(tp1) => - // Note: since we approximate constrained types only with their non-param bounds, - // we need to manually handle the case when we're comparing two constrained types, - // one of which is constrained to be a subtype of another. - // We do not need similar code in fourthTry, since we only need to care about - // comparing two constrained types, and that case will be handled here first. - { ctx.gadt.isLess(tp1, tp2) } && GADTusage(tp1.symbol) && GADTusage(tp2.symbol) - case _ => false - || narrowGADTBounds(tp2, tp1, approx, isUpper = false)) + { tp2.onGadtBounds(gbounds2 => + { isSubTypeWhenFrozen(tp1, gbounds2.lo) } + || tp1.match + case tp1: TypeRef if tpRegistered(tp1) => + // Note: since we approximate constrained types only with their non-param bounds, + // we need to manually handle the case when we're comparing two constrained types, + // one of which is constrained to be a subtype of another. + // We do not need similar code in fourthTry, since we only need to care about + // comparing two constrained types, and that case will be handled here first. + { ctx.gadt.isLess(tp1, tp2) } && GADTusage(tp1.symbol) && GADTusage(tp2.symbol) + case _ => false) + || narrowGADTBounds(tp2, tp1, approx, isUpper = false) } && (isBottom(tp1) || GADTusage(tp2.symbol)) isSubApproxHi(tp1, info2.lo.boxedIfTypeParam(tp2.symbol)) && (trustBounds || isSubApproxHi(tp1, info2.hi)) @@ -887,10 +887,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling tp1.info match { case info1 @ TypeBounds(lo1, hi1) => def compareGADT = - tp1.onGadtBounds(gbounds1 => - isSubTypeWhenFrozen(gbounds1.hi, tp2) - || narrowGADTBounds(tp1, tp2, approx, isUpper = true)) - && (tp2.isAny || GADTusage(tp1.symbol)) + { tp1.onGadtBounds(gbounds1 => + isSubTypeWhenFrozen(gbounds1.hi, tp2)) + || narrowGADTBounds(tp1, tp2, approx, isUpper = true) + } && (tp2.isAny || GADTusage(tp1.symbol)) (!caseLambda.exists || canWidenAbstract) && isSubType(hi1.boxedIfTypeParam(tp1.symbol), tp2, approx.addLow) && (trustBounds || isSubType(lo1, tp2, approx.addLow)) @@ -1872,10 +1872,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling ctx.gadt.restore(preGadt) if op2 then if allSubsumes(op1Gadt, ctx.gadt, op1Constraint, constraint) then - gadts.println(i"GADT CUT - prefer ${ctx.gadt} over $op1Gadt") + gadts.println(i"GADT CUT - prefer op2 ${ctx.gadt} over $op1Gadt") constr.println(i"CUT - prefer $constraint over $op1Constraint") else if allSubsumes(ctx.gadt, op1Gadt, constraint, op1Constraint) then - gadts.println(i"GADT CUT - prefer $op1Gadt over ${ctx.gadt}") + gadts.println(i"GADT CUT - prefer op1 $op1Gadt over ${ctx.gadt}") constr.println(i"CUT - prefer $op1Constraint over $constraint") constraint = op1Constraint ctx.gadt.restore(op1Gadt) @@ -2075,8 +2075,17 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling def narrowPathDepType = tr match case TypeRef(path: PathType, _) => - ctx.gadt.contains(path, tr.symbol) && { - val sym = tr.symbol + val sym = tr.symbol + + def isConstrainable: Boolean = + ctx.gadt.contains(path, sym) || { + ctx.gadt.isConstrainablePDT(path, tr.symbol) && { + gadts.println(i"!!! registering path on the fly path=$path sym=$sym") + ctx.gadt.addToConstraint(path) && ctx.gadt.contains(path, sym) + } + } + + isConstrainable && { gadts.println(i"narrow gadt bound of pdt $path -> ${sym}: from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(sym)}") if (bound.isRef(sym)) false @@ -2302,6 +2311,19 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case Atoms.Range(lo2, hi2) => if hi1.subsetOf(lo2) then return tp2 if hi2.subsetOf(lo1) then return tp1 + + def getReprSet(ps: Set[Type]): Set[Type] = + ps.map { x => + x match + case p: PathType => + val rep = ctx.gadt.reprOf(p) + if rep == null then p else rep + case t => t + } + val (repLo1, repHi1, repLo2, repHi2) = (getReprSet(lo1), getReprSet(hi1), getReprSet(lo2), getReprSet(hi2)) + if repHi2.subsetOf(repLo1) then return tp1 + if repHi1.subsetOf(repLo2) then return tp2 + if (hi1 & hi2).isEmpty then return orType(tp1, tp2) case none => case none => diff --git a/tests/neg/pdgadt-either.scala b/tests/neg/pdgadt-either.scala new file mode 100644 index 000000000000..5d53f0085ae8 --- /dev/null +++ b/tests/neg/pdgadt-either.scala @@ -0,0 +1,17 @@ +trait T1 +trait T2 extends T1 +// T2 <:< T1 + +trait Expr[+X] +case class Tag1() extends Expr[T1] +case class Tag2() extends Expr[T2] + +trait TypeTag { type A } + +def foo(p: TypeTag, e: Expr[p.A]) = e match + case _: (Tag2 | Tag1) => + // Tag1: (T2 <:) T1 <: p.A + // Tag2: T2 <: p.A + // should choose T2 <: p.A + val t1: p.A = ??? : T1 // error + val t2: p.A = ??? : T2 diff --git a/tests/pos/pdgadt-expr.scala b/tests/pos/pdgadt-expr.scala new file mode 100644 index 000000000000..7bf017c8a312 --- /dev/null +++ b/tests/pos/pdgadt-expr.scala @@ -0,0 +1,13 @@ +type typed[E <: Expr, V] = E & { type T = V } + +trait Expr { type T } +case class LitInt(x: Int) extends Expr { type T = Int } +case class Add(e1: Expr typed Int, e2: Expr typed Int) extends Expr { type T = Int } +case class LitBool(x: Boolean) extends Expr { type T = Boolean } +case class Or(e1: Expr typed Boolean, e2: Expr typed Boolean) extends Expr { type T = Boolean } + +def eval(e: Expr): e.T = e match + case LitInt(x) => x + case Add(e1, e2) => eval(e1) + eval(e2) + case LitBool(b) => b + case Or(e1, e2) => eval(e1) || eval(e2) From 0c63ba0958f4317d49087036270e2fafc0c5a2e7 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Sun, 20 Mar 2022 22:21:59 +0800 Subject: [PATCH 06/56] drop typevars in type member bounds These type variables are used in the type inference of wildcard type parameters in patterns, and will be replaced by some fresh symbols soon. Keeping these type variables in the GADT bounds of path-dependent types is meaning less. So we drop them when registering the type members. --- .../tools/dotc/core/GadtConstraint.scala | 35 +++++++++++-------- tests/pos/pdgadt-wildcard.scala | 8 +++++ 2 files changed, 28 insertions(+), 15 deletions(-) create mode 100644 tests/pos/pdgadt-wildcard.scala diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 775cc44e8356..9265f2c85832 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -332,6 +332,7 @@ final class ProperGadtConstraint private( typeMembers.nonEmpty && { val typeMemberSymbols: List[Symbol] = typeMembers map { x => x.symbol } + val poly1 = PolyType(typeMembers map { d => DepParamName.fresh(d.name.toTypeName) })( pt => typeMembers map { typeMember => def substDependentSyms(tp: Type, isUpper: Boolean)(using Context): Type = { @@ -353,6 +354,7 @@ final class ProperGadtConstraint private( case tv: TypeVar => tv.origin case null => tp } + case tv: TypeVar => if isUpper then defn.AnyType else defn.NothingType case tp => tp loop(tp) @@ -381,24 +383,27 @@ final class ProperGadtConstraint private( pt => defn.AnyType ) - val tvars = typeMemberSymbols lazyZip poly1.paramRefs map { (sym, paramRef) => - val tv = TypeVar(paramRef, creatorState = null) + def register: Boolean = + val tvars = typeMemberSymbols lazyZip poly1.paramRefs map { (sym, paramRef) => + val tv = TypeVar(paramRef, creatorState = null) - val externalType = TypeRef(path, sym) - pathDepMapping = pathDepMapping.updated(path, { - val old: SimpleIdentityMap[Symbol, TypeVar] = pathDepMapping(path) match - case null => SimpleIdentityMap.empty - case m => m + val externalType = TypeRef(path, sym) + pathDepMapping = pathDepMapping.updated(path, { + val old: SimpleIdentityMap[Symbol, TypeVar] = pathDepMapping(path) match + case null => SimpleIdentityMap.empty + case m => m - old.updated(sym, tv) - }) - pathDepReverseMapping = pathDepReverseMapping.updated(tv.origin, externalType) + old.updated(sym, tv) + }) + pathDepReverseMapping = pathDepReverseMapping.updated(tv.origin, externalType) - tv - } + tv + } + + addToConstraint(poly1, tvars) + .showing(i"added to constraint: [$poly1] $path\n$debugBoundsDescription", gadts) - addToConstraint(poly1, tvars) - .showing(i"added to constraint: [$poly1] $path\n$debugBoundsDescription", gadts) + register } } @@ -406,7 +411,7 @@ final class ProperGadtConstraint private( val buf = new mutable.StringBuilder buf ++= "{\n" typeMembers foreach { denot => - buf ++= i" ${denot.symbol}: ${denot.info.bounds}\n" + buf ++= i" ${denot.symbol}: ${denot.info.bounds} ${denot.info.bounds.toString}\n" } buf ++= "}" buf.result diff --git a/tests/pos/pdgadt-wildcard.scala b/tests/pos/pdgadt-wildcard.scala new file mode 100644 index 000000000000..908ef4ff4fea --- /dev/null +++ b/tests/pos/pdgadt-wildcard.scala @@ -0,0 +1,8 @@ +trait Expr { type T } +case class Inv[X](x: X) extends Expr { type T = X } +case class Inv2[X](x: X) extends Expr { type T >: X } + +def eval(e: Expr): e.T = e match + case Inv(x) => x + case Inv2(x) => x + From f8d410cf71e1c32707edfd1d7ace0f1d6dceee14 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Mon, 21 Mar 2022 02:59:57 +0800 Subject: [PATCH 07/56] support path-dependent GADT for HKTs --- .../tools/dotc/core/GadtConstraint.scala | 4 +-- .../dotty/tools/dotc/core/TypeComparer.scala | 26 ++++++++++++++----- tests/pos/pdgadt-hkt-bounds.scala | 11 ++++++++ tests/pos/pdgadt-hkt-ordering.scala | 12 +++++++++ tests/pos/pdgadt-hkt-usage.scala | 14 ++++++++++ 5 files changed, 58 insertions(+), 9 deletions(-) create mode 100644 tests/pos/pdgadt-hkt-bounds.scala create mode 100644 tests/pos/pdgadt-hkt-ordering.scala create mode 100644 tests/pos/pdgadt-hkt-usage.scala diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 9265f2c85832..5a3e04eb2263 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -771,12 +771,12 @@ final class ProperGadtConstraint private( sb ++= constraint.show sb += '\n' mapping.foreachBinding { case (sym, _) => - sb ++= i"$sym: ${bounds(sym)}\n" + sb ++= i"$sym: ${fullBounds(sym)}\n" } sb += '\n' pathDepMapping foreachBinding { case (path, m) => m foreachBinding { case (sym, _) => - sb ++= i"$path.$sym: ${bounds(TypeRef(path, sym))}\n" + sb ++= i"$path.$sym: ${fullBounds(TypeRef(path, sym))}\n" } } sb.result diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index a4c2f0635885..da4c9e31fe81 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -1208,10 +1208,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling var touchedGADTs = false var gadtIsInstantiated = false - extension (sym: Symbol) + extension (tp: TypeRef) inline def byGadtBounds(inline op: TypeBounds => Boolean): Boolean = touchedGADTs = true - sym.onGadtBounds( + tp.onGadtBounds( b => op(b) && { gadtIsInstantiated = b.isInstanceOf[TypeAlias]; true }) def byGadtOrdering: Boolean = @@ -1219,11 +1219,19 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling && ctx.gadt.contains(tycon2sym) && ctx.gadt.isLess(tycon1sym, tycon2sym) + def byPathDepGadtOrdering: Boolean = + (tycon1, tycon2) match + case (TypeRef(p1: PathType, _), TypeRef(p2: PathType, _)) => + ctx.gadt.contains(p1, tycon1sym) + && ctx.gadt.contains(p2, tycon2sym) + && ctx.gadt.isLess(tycon1, tycon2) + case _ => false + val res = ( tycon1sym == tycon2sym && isSubPrefix(tycon1.prefix, tycon2.prefix) - || tycon1sym.byGadtBounds(b => isSubTypeWhenFrozen(b.hi, tycon2)) - || tycon2sym.byGadtBounds(b => isSubTypeWhenFrozen(tycon1, b.lo)) - || byGadtOrdering + || tycon1.byGadtBounds(b => isSubTypeWhenFrozen(b.hi, tycon2)) + || tycon2.byGadtBounds(b => isSubTypeWhenFrozen(tycon1, b.lo)) + || byGadtOrdering || byPathDepGadtOrdering ) && { // There are two cases in which we can assume injectivity. // First we check if either sym is a class. @@ -2085,10 +2093,14 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling } } + def isRef: Boolean = bound match + case TypeRef(q: PathType, _) => (path eq q) && bound.isRef(sym) + case _ => false + isConstrainable && { - gadts.println(i"narrow gadt bound of pdt $path -> ${sym}: from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(sym)}") + gadts.println(i"narrow gadt bound of pdt $path -> ${sym}: from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${isRef}") - if (bound.isRef(sym)) false + if isRef then false else if isUpper then gadtAddUpperBound(path, sym, bound) else gadtAddLowerBound(path, sym, bound) } diff --git a/tests/pos/pdgadt-hkt-bounds.scala b/tests/pos/pdgadt-hkt-bounds.scala new file mode 100644 index 000000000000..4a46a617a33f --- /dev/null +++ b/tests/pos/pdgadt-hkt-bounds.scala @@ -0,0 +1,11 @@ +type Const = [X] =>> Int + +trait Expr { type F[_] } +case class ConstExprHi() extends Expr { type F[a] <: Const[a] } +case class ConstExprLo() extends Expr { type F[a] >: Const[a] } + +def foo[A](e: Expr) = e match + case _: ConstExprHi => + val i: Int = (??? : e.F[A]) : Const[A] + case _: ConstExprLo => + val i: Const[A] = ??? : Int diff --git a/tests/pos/pdgadt-hkt-ordering.scala b/tests/pos/pdgadt-hkt-ordering.scala new file mode 100644 index 000000000000..ea3879e4db14 --- /dev/null +++ b/tests/pos/pdgadt-hkt-ordering.scala @@ -0,0 +1,12 @@ +object test { + trait Tag { type F[_] } + + enum SubK[-A[_], +B[_]]: + case Refl[F[_]]() extends SubK[F, F] + + def foo(p: Tag, q: Tag, e: SubK[p.F, q.F]) = p match + case _: Tag => q match + case _: Tag => e match + case SubK.Refl() => + val t: q.F[Int] = ??? : p.F[Int] + } diff --git a/tests/pos/pdgadt-hkt-usage.scala b/tests/pos/pdgadt-hkt-usage.scala new file mode 100644 index 000000000000..97d9c6e07bc7 --- /dev/null +++ b/tests/pos/pdgadt-hkt-usage.scala @@ -0,0 +1,14 @@ +object test { + class Foo[A] + class Inv { type F[_] } + class InvFoo extends Inv { type F[x] = Foo[x] } + + object Test { + def foo(x: Inv) = x match { + case x: InvFoo => + val z1: x.F[Int] = ??? : Foo[Int] + val z2: Foo[Int] = ??? : x.F[Int] + case _ => + } + } +} From 646e2ab621df25ab912b7632c32d6cc7cb488703 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 23 Mar 2022 14:25:58 +0800 Subject: [PATCH 08/56] cleanup tracing --- .../dotc/core/PatternTypeConstrainer.scala | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index 38078ca08b4c..44c51ba5ca12 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -186,9 +186,6 @@ trait PatternTypeConstrainer { self: TypeComparer => case _ => realScrutineePath val patternPath: SkolemType = SkolemType(pat) - gadts.println(i"scrutinee path: $scrutineePath") - gadts.println(i"pattern path: $patternPath") - val saved = state.nn.constraint val savedGadt = ctx.gadt.fresh @@ -214,24 +211,22 @@ trait PatternTypeConstrainer { self: TypeComparer => val patternType = TypeRef(patternPath, patternSymbol) def constrainSP = - val res = ctx.gadt.addBound(scrutineePath, scrutineeSymbol, patternType, isUpper = true) - gadts.println(i"after $scrutineePath.$scrutineeSymbol <:< $patternType: res = $res, gadt = ${ctx.gadt.debugBoundsDescription}") - res + ctx.gadt.addBound(scrutineePath, scrutineeSymbol, patternType, isUpper = true) + .showing(i"after $scrutineePath.$scrutineeSymbol <:< $patternType: result = $result, gadt = ${ctx.gadt.debugBoundsDescription}", gadts) def constrainPS = - val res = ctx.gadt.addBound(patternPath, patternSymbol, scrutineeType, isUpper = true) - gadts.println(i"after $patternPath.$patternSymbol <:< $scrutineePath: res = $res, gadt = ${ctx.gadt.debugBoundsDescription}") - res + ctx.gadt.addBound(patternPath, patternSymbol, scrutineeType, isUpper = true) + .showing(i"after $patternPath.$patternSymbol <:< $scrutineePath: result = $result, gadt = ${ctx.gadt.debugBoundsDescription}", gadts) constrainPS && constrainSP } } - if !result then + if !res then constraint = saved ctx.gadt.restore(savedGadt) - result + res } def constrainTypeParams = From 34661c496faf7d54adc86c82a15bca21efdc5e3f Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 23 Mar 2022 14:26:12 +0800 Subject: [PATCH 09/56] filtering out NoSymbol to avoid unsound bounds --- compiler/src/dotty/tools/dotc/core/GadtConstraint.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 5a3e04eb2263..b87c050eb6bc 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -324,7 +324,7 @@ final class ProperGadtConstraint private( override def addToConstraint(path: PathType)(using Context): Boolean = isConstrainablePath(path) && { import NameKinds.DepParamName val pathType = path.widen - val typeMembers = constrainableTypeMembers(path) + val typeMembers = constrainableTypeMembers(path).filterNot(_.symbol eq NoSymbol) gadts.println(i"> trying to add $path into constraint ...") gadts.println(i" path.widen = $pathType") From 69bea1c83da1825af53db3f304d0361043f327f0 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 23 Mar 2022 14:26:28 +0800 Subject: [PATCH 10/56] return true if the type members can not be constrained --- compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index 44c51ba5ca12..347992b04ce0 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -195,7 +195,7 @@ trait PatternTypeConstrainer { self: TypeComparer => def registerPattern = ctx.gadt.addToConstraint(patternPath) // Pattern path is a freshly-created skolem, // so it will always be un-registered at this point - val result = registerScrutinee && registerPattern && { + val res = !registerScrutinee || !registerPattern || { val scrutineeTypeMembers = Map.from { ctx.gadt.registeredTypeMembers(scrutineePath) map { x => x.name -> x } } From 0fe9181b022faeacf19763b2a4c94a92830c4559 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 23 Mar 2022 14:26:43 +0800 Subject: [PATCH 11/56] fix typo --- compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index 347992b04ce0..d1a087f8a800 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -236,7 +236,7 @@ trait PatternTypeConstrainer { self: TypeComparer => case AndType(scrut1, scrut2) => constrainPatternType(pat, scrut1, typeMembersTouched = true) && constrainPatternType(pat, scrut2, typeMembersTouched = true) case scrut: RefinedOrRecType => - constrainPatternType(pat, stripRefinement(scrut)) + constrainPatternType(pat, stripRefinement(scrut), typeMembersTouched = true) case scrut => dealiasDropNonmoduleRefs(pat) match { case OrType(pat1, pat2) => either(constrainPatternType(pat1, scrut, typeMembersTouched = true), constrainPatternType(pat2, scrut, typeMembersTouched = true)) From d3eade0cb5ed578e0097a11e44c57a531ebacdc5 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 23 Mar 2022 14:26:53 +0800 Subject: [PATCH 12/56] remove limitation errors in structural gadt neg tests These limitation is solved by path-dependent GADT now. --- tests/neg/structural-gadt.scala | 8 ++++---- .../neg/structural-recursive-both1-gadt.scala | 18 +++++++++--------- .../neg/structural-recursive-both2-gadt.scala | 18 +++++++++--------- .../structural-recursive-pattern-gadt.scala | 10 +++++----- .../structural-recursive-scrutinee-gadt.scala | 16 ++++++++-------- 5 files changed, 35 insertions(+), 35 deletions(-) diff --git a/tests/neg/structural-gadt.scala b/tests/neg/structural-gadt.scala index 9a14881b5804..91355cbcc025 100644 --- a/tests/neg/structural-gadt.scala +++ b/tests/neg/structural-gadt.scala @@ -18,11 +18,11 @@ object Test { val i: Int = ??? : A // limitation // error case _: IntExpr => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // limitation // error case _: Expr { type T = Int } => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // limitation // error } @@ -36,11 +36,11 @@ object Test { val i: Int = ??? : A // error case _: IntExpr => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error case _: Expr { type T = Int } => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error } diff --git a/tests/neg/structural-recursive-both1-gadt.scala b/tests/neg/structural-recursive-both1-gadt.scala index 97df59a92bb5..db086c83e13c 100644 --- a/tests/neg/structural-recursive-both1-gadt.scala +++ b/tests/neg/structural-recursive-both1-gadt.scala @@ -28,23 +28,23 @@ object Test { def foo[A](e: IndirectExprExact[A]) = e match { case _: AltIndirectIntLit => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: AltIndirectExprSub[Int] => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: AltIndirectExprSub2[Int] => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: AltIndirectIntExpr => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A case _: AltIndirectExprExact[Int] => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A } def bar[A](e: IndirectExprSub[A]) = e match { @@ -83,11 +83,11 @@ object Test { val i: Int = ??? : A // error case _: AltIndirectIntExpr => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error case _: AltIndirectExprExact[Int] => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error } } diff --git a/tests/neg/structural-recursive-both2-gadt.scala b/tests/neg/structural-recursive-both2-gadt.scala index b58e05f3ed43..883093f52896 100644 --- a/tests/neg/structural-recursive-both2-gadt.scala +++ b/tests/neg/structural-recursive-both2-gadt.scala @@ -28,23 +28,23 @@ object Test { def foo[A](e: AltIndirectExprExact[A]) = e match { case _: IndirectIntLit => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IndirectExprSub[Int] => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IndirectExprSub2[Int] => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IndirectIntExpr => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A case _: IndirectExprExact[Int] => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A } def bar[A](e: AltIndirectExprSub[A]) = e match { @@ -83,11 +83,11 @@ object Test { val i: Int = ??? : A // error case _: IndirectIntExpr => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error case _: IndirectExprExact[Int] => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error } } diff --git a/tests/neg/structural-recursive-pattern-gadt.scala b/tests/neg/structural-recursive-pattern-gadt.scala index ea7394b5b66b..e2d3b42eeb37 100644 --- a/tests/neg/structural-recursive-pattern-gadt.scala +++ b/tests/neg/structural-recursive-pattern-gadt.scala @@ -39,11 +39,11 @@ object Test { val i: Int = ??? : A // limitation // error case _: IndirectIntExpr => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A case _: IndirectExprExact[Int] => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // limitation // error } @@ -61,11 +61,11 @@ object Test { val i: Int = ??? : A // error case _: IndirectIntExpr => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error case _: IndirectExprExact[Int] => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error } } diff --git a/tests/neg/structural-recursive-scrutinee-gadt.scala b/tests/neg/structural-recursive-scrutinee-gadt.scala index cd4e2376f49a..0936de716ede 100644 --- a/tests/neg/structural-recursive-scrutinee-gadt.scala +++ b/tests/neg/structural-recursive-scrutinee-gadt.scala @@ -28,19 +28,19 @@ object Test { def foo[A](e: IndirectExprExact[A]) = e match { case _: IntLit => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: ExprSub[Int] => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IntExpr => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A case _: ExprExact[Int] => - val a: A = 0 // limitation // error - val i: Int = ??? : A // limitation // error + val a: A = 0 + val i: Int = ??? : A } def bar[A](e: IndirectExprSub[A]) = e match { @@ -71,11 +71,11 @@ object Test { val i: Int = ??? : A // error case _: IntExpr => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error case _: ExprExact[Int] => - val a: A = 0 // limitation // error + val a: A = 0 val i: Int = ??? : A // error } } From c3795fe50001393bcf008ad99c4c274a38713c84 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 23 Mar 2022 15:25:16 +0800 Subject: [PATCH 13/56] fix code to pass explicit null check --- .../tools/dotc/core/GadtConstraint.scala | 36 ++++++++++--------- .../dotc/core/PatternTypeConstrainer.scala | 7 ++-- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index b87c050eb6bc..97f309a64dc1 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -71,7 +71,7 @@ sealed abstract class GadtConstraint extends Showable { def resetScrutineePath(): Unit /** Set the scrutinee path. */ - def withScrutineePath[T](path: TermRef)(op: => T): T + def withScrutineePath[T](path: TermRef | Null)(op: => T): T /** Is the symbol registered in the constraint? * @@ -113,7 +113,7 @@ final class ProperGadtConstraint private( private var pathDepMapping: SimpleIdentityMap[PathType, SimpleIdentityMap[Symbol, TypeVar]], private var pathDepReverseMapping: SimpleIdentityMap[TypeParamRef, TypeRef], private var wasConstrained: Boolean, - private var myScrutineePath: TermRef, + private var myScrutineePath: TermRef | Null, private var myUnionFind: SimpleIdentityMap[PathType, PathType] ) extends GadtConstraint with ConstraintHandling { import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} @@ -183,7 +183,7 @@ final class ProperGadtConstraint private( val tp1 = left.externalize(p1) right.tvarOf(tp1) match { case null => - case tvar2 => + case tvar2: TypeVar => mapping1 = mapping1.updated(p1, tvar2.origin) mapping2 = mapping2.updated(tvar2.origin, p1) } @@ -194,7 +194,7 @@ final class ProperGadtConstraint private( def mapTypeParam(m: SimpleIdentityMap[TypeParamRef, TypeParamRef])(tpr: TypeParamRef) = m(tpr) match case null => tpr - case tpr1 => tpr1 + case tpr1: TypeParamRef => tpr1 (mapTypeParam(mapping1), mapTypeParam(mapping2)) } @@ -297,7 +297,7 @@ final class ProperGadtConstraint private( private def tvarOf(path: PathType, sym: Symbol)(using Context): TypeVar | Null = pathDepMapping(path) match case null => null - case innerMapping => innerMapping(sym) + case innerMapping => innerMapping.nn(sym) /** Try to retrieve type variable for some TypeRef. * Both type parameters and path-dependent types are considered. @@ -391,7 +391,7 @@ final class ProperGadtConstraint private( pathDepMapping = pathDepMapping.updated(path, { val old: SimpleIdentityMap[Symbol, TypeVar] = pathDepMapping(path) match case null => SimpleIdentityMap.empty - case m => m + case m => m.nn old.updated(sym, tv) }) @@ -552,8 +552,8 @@ final class ProperGadtConstraint private( private def lookupPath(p: PathType): PathType | Null = def recur(p: PathType, steps: Int = 0): PathType | Null = myUnionFind(p) match case null => null - case q if p eq q => q - case q => + case q: PathType if q eq p => q + case q: PathType => if steps <= 1024 then recur(q, steps + 1) else @@ -561,11 +561,11 @@ final class ProperGadtConstraint private( recur(p) override def addEquality(p: PathType, q: PathType): Unit = - val newRep = lookupPath(p) match + val newRep: PathType = lookupPath(p) match case null => lookupPath(q) match case null => p - case r => r - case r => r + case r: PathType => r + case r: PathType => r myUnionFind = myUnionFind.updated(p, newRep) myUnionFind = myUnionFind.updated(q, newRep) @@ -573,7 +573,9 @@ final class ProperGadtConstraint private( override def isEquivalent(p: PathType, q: PathType): Boolean = lookupPath(p) match case null => false - case p0 => p0 eq lookupPath(q) + case p0: PathType => lookupPath(q) match + case null => false + case q0: PathType => p0 eq q0 override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = constraint.isLess(tvarOrError(sym1).origin, tvarOrError(sym2).origin) @@ -593,7 +595,7 @@ final class ProperGadtConstraint private( override def fullBounds(p: PathType, sym: Symbol)(using Context): TypeBounds | Null = tvarOf(p, sym) match { case null => null - case tv => fullBounds(tv.origin) + case tv => fullBounds(tv.nn.origin) } override def fullBounds(tp: TypeRef)(using Context): TypeBounds | Null = @@ -622,7 +624,7 @@ final class ProperGadtConstraint private( case TypeAlias(tpr: TypeParamRef) if reverseMapping.contains(tpr) => TypeAlias(reverseMapping(tpr).nn.typeRef) case TypeAlias(tpr: TypeParamRef) if pathDepReverseMapping.contains(tpr) => - TypeAlias(pathDepReverseMapping(tpr)) + TypeAlias(pathDepReverseMapping(tpr).nn) case tb => tb } retrieveBounds @@ -640,7 +642,7 @@ final class ProperGadtConstraint private( override def contains(path: PathType, sym: Symbol)(using Context): Boolean = pathDepMapping(path) match case null => false - case innerMapping => innerMapping(sym) != null + case innerMapping => innerMapping.nn(sym) != null override def registeredTypeMembers(path: PathType): List[Symbol] = pathDepMapping(path).nn.keys @@ -690,7 +692,7 @@ final class ProperGadtConstraint private( override def resetScrutineePath(): Unit = myScrutineePath = null - override def withScrutineePath[T](path: TermRef)(op: => T): T = + override def withScrutineePath[T](path: TermRef | Null)(op: => T): T = val saved = this.myScrutineePath this.myScrutineePath = path val result = op @@ -826,7 +828,7 @@ final class ProperGadtConstraint private( override def resetScrutineePath(): Unit = () - override def withScrutineePath[T](path: TermRef)(op: => T): T = op + override def withScrutineePath[T](path: TermRef | Null)(op: => T): T = op override def fresh = new ProperGadtConstraint override def restore(other: GadtConstraint): Unit = diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index d1a087f8a800..7fa6bf1cef03 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -255,10 +255,9 @@ trait PatternTypeConstrainer { self: TypeComparer => /** Show the scrutinee. Will show the path if available. */ private def scrutRepr(scrut: Type): String = - if ctx.gadt.scrutineePath != null then - ctx.gadt.scrutineePath.show - else - scrut.show + ctx.gadt.scrutineePath match + case null => scrut.show + case p: PathType => p.show /** Constrain "simple" patterns (see `constrainPatternType`). * From 284ccf1137517d3d0d05692c367463c9e58c4830 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Fri, 15 Jul 2022 16:57:39 +0800 Subject: [PATCH 14/56] format and cleanup --- .../src/dotty/tools/dotc/core/GadtConstraint.scala | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 97f309a64dc1..259125f50d6d 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -728,9 +728,10 @@ final class ProperGadtConstraint private( private def externalize(tp: Type, theMap: TypeMap | Null = null)(using Context): Type = tp match case param: TypeParamRef => reverseMapping(param) match case sym: Symbol => sym.typeRef - case null => pathDepReverseMapping(param) match - case tp: TypeRef => tp - case null => param + case null => + pathDepReverseMapping(param) match + case tp: TypeRef => tp + case null => param case tp: TypeAlias => tp.derivedAlias(externalize(tp.alias, theMap)) case tp => (if theMap == null then ExternalizeMap() else theMap).mapOver(tp) @@ -747,10 +748,6 @@ final class ProperGadtConstraint private( private def tvarOrError(ntp: NamedType)(using Context): TypeVar = tvarOf(ntp).ensuring(_ != null, i"not a constrainable type: $ntp").uncheckedNN - // private def containsNoInternalTypes( - // tp: Type, - // acc: TypeAccumulator[Boolean] | Null = null - // )(using Context): Boolean = tp match { private def containsNoInternalTypes(tp: Type, theAcc: TypeAccumulator[Boolean] | Null = null)(using Context): Boolean = tp match { case tpr: TypeParamRef => !reverseMapping.contains(tpr) case tv: TypeVar => !reverseMapping.contains(tv.origin) From 991c4ccd65f4bff5ec88fa25a8ad35081b9856a0 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Fri, 15 Jul 2022 17:02:29 +0800 Subject: [PATCH 15/56] remove limitation errors in testcases Thanks to the path-dependent GADT reasoning logic, these limitation errors can be dropped from the testcases. See #14754. --- tests/neg/structural-gadt.scala | 14 ++++++-------- tests/neg/structural-recursive-both1-gadt.scala | 4 ++-- tests/neg/structural-recursive-both2-gadt.scala | 4 ++-- tests/neg/structural-recursive-pattern-gadt.scala | 14 ++++++-------- .../neg/structural-recursive-scrutinee-gadt.scala | 4 ++-- 5 files changed, 18 insertions(+), 22 deletions(-) diff --git a/tests/neg/structural-gadt.scala b/tests/neg/structural-gadt.scala index 91355cbcc025..c970f08709de 100644 --- a/tests/neg/structural-gadt.scala +++ b/tests/neg/structural-gadt.scala @@ -1,7 +1,5 @@ -// This file is part of tests for inferring GADT constraints from type members, -// which needed to be reverted because of soundness issues. -// -// Lines with "// limitation" are the ones that we could soundly allow. +// This file is part of tests for inferring GADT constraints from type members. +// They are now supported by path-dependent GADT reasoning. See #14754. object Test { trait Expr { type T } @@ -11,19 +9,19 @@ object Test { def foo[A](e: Expr { type T = A }) = e match { case _: IntLit => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: Expr { type T <: Int } => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IntExpr => val a: A = 0 - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: Expr { type T = Int } => val a: A = 0 - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A } def bar[A](e: Expr { type T <: A }) = e match { diff --git a/tests/neg/structural-recursive-both1-gadt.scala b/tests/neg/structural-recursive-both1-gadt.scala index db086c83e13c..ae17f01ff12a 100644 --- a/tests/neg/structural-recursive-both1-gadt.scala +++ b/tests/neg/structural-recursive-both1-gadt.scala @@ -1,5 +1,5 @@ -// This file is part of tests for inferring GADT constraints from type members, -// which needed to be reverted because of soundness issues. +// This file is part of tests for inferring GADT constraints from type members. +// They are now supported by path-dependent GADT reasoning. See #14754. // // Lines with "// limitation" are the ones that we could soundly allow. object Test { diff --git a/tests/neg/structural-recursive-both2-gadt.scala b/tests/neg/structural-recursive-both2-gadt.scala index 883093f52896..08f93aa8469c 100644 --- a/tests/neg/structural-recursive-both2-gadt.scala +++ b/tests/neg/structural-recursive-both2-gadt.scala @@ -1,5 +1,5 @@ -// This file is part of tests for inferring GADT constraints from type members, -// which needed to be reverted because of soundness issues. +// This file is part of tests for inferring GADT constraints from type members. +// They are now supported by path-dependent GADT reasoning. See #14754. // // Lines with "// limitation" are the ones that we could soundly allow. object Test { diff --git a/tests/neg/structural-recursive-pattern-gadt.scala b/tests/neg/structural-recursive-pattern-gadt.scala index e2d3b42eeb37..a66f2b6f43f8 100644 --- a/tests/neg/structural-recursive-pattern-gadt.scala +++ b/tests/neg/structural-recursive-pattern-gadt.scala @@ -1,7 +1,5 @@ -// This file is part of tests for inferring GADT constraints from type members, -// which needed to be reverted because of soundness issues. -// -// Lines with "// limitation" are the ones that we could soundly allow. +// This file is part of tests for inferring GADT constraints from type members. +// They are now supported by path-dependent GADT reasoning. See #14754. object Test { trait Expr { type T } @@ -28,15 +26,15 @@ object Test { def foo[A](e: ExprExact[A]) = e match { case _: IndirectIntLit => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IndirectExprSub[Int] => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IndirectExprSub2[Int] => val a: A = 0 // error - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A case _: IndirectIntExpr => val a: A = 0 @@ -44,7 +42,7 @@ object Test { case _: IndirectExprExact[Int] => val a: A = 0 - val i: Int = ??? : A // limitation // error + val i: Int = ??? : A } def bar[A](e: ExprSub[A]) = e match { diff --git a/tests/neg/structural-recursive-scrutinee-gadt.scala b/tests/neg/structural-recursive-scrutinee-gadt.scala index 0936de716ede..f5f8c29cc314 100644 --- a/tests/neg/structural-recursive-scrutinee-gadt.scala +++ b/tests/neg/structural-recursive-scrutinee-gadt.scala @@ -1,5 +1,5 @@ -// This file is part of tests for inferring GADT constraints from type members, -// which needed to be reverted because of soundness issues. +// This file is part of tests for inferring GADT constraints from type members. +// They are now supported by path-dependent GADT reasoning. See #14754. // // Lines with "// limitation" are the ones that we could soundly allow. object Test { From 9862eb2c658ab25a885b5a2a23c281462bd0d803 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Fri, 15 Jul 2022 21:03:28 +0800 Subject: [PATCH 16/56] record pattern path --- .../tools/dotc/core/GadtConstraint.scala | 42 +++++++++++++++++-- .../dotc/core/PatternTypeConstrainer.scala | 2 +- .../src/dotty/tools/dotc/typer/Typer.scala | 4 ++ 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 259125f50d6d..7c433a3a6cc8 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -73,6 +73,12 @@ sealed abstract class GadtConstraint extends Showable { /** Set the scrutinee path. */ def withScrutineePath[T](path: TermRef | Null)(op: => T): T + /** Supply the real pattern path. */ + def supplyPatternPath(path: TermRef)(using Context): Unit + + /** Create a skolem type for pattern. */ + def createPatternSkolem(pat: Type): SkolemType + /** Is the symbol registered in the constraint? * * @note this is true even if the symbol is constrained to be equal to another type, unlike [[Constraint.contains]]. @@ -114,7 +120,8 @@ final class ProperGadtConstraint private( private var pathDepReverseMapping: SimpleIdentityMap[TypeParamRef, TypeRef], private var wasConstrained: Boolean, private var myScrutineePath: TermRef | Null, - private var myUnionFind: SimpleIdentityMap[PathType, PathType] + private var myUnionFind: SimpleIdentityMap[PathType, PathType], + private var myPatternSkolem: SkolemType | Null, ) extends GadtConstraint with ConstraintHandling { import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} @@ -126,7 +133,8 @@ final class ProperGadtConstraint private( pathDepReverseMapping = SimpleIdentityMap.empty, wasConstrained = false, myScrutineePath = null, - myUnionFind = SimpleIdentityMap.empty + myUnionFind = SimpleIdentityMap.empty, + myPatternSkolem = null, ) // /** Exposes ConstraintHandling.subsumes */ @@ -672,7 +680,8 @@ final class ProperGadtConstraint private( pathDepReverseMapping, wasConstrained, myScrutineePath, - myUnionFind + myUnionFind, + myPatternSkolem, ) def restore(other: GadtConstraint): Unit = other match { @@ -685,6 +694,7 @@ final class ProperGadtConstraint private( this.wasConstrained = other.wasConstrained this.myScrutineePath = other.myScrutineePath this.myUnionFind = other.myUnionFind + this.myPatternSkolem = other.myPatternSkolem case _ => ; } @@ -692,10 +702,32 @@ final class ProperGadtConstraint private( override def resetScrutineePath(): Unit = myScrutineePath = null + override def supplyPatternPath(path: TermRef)(using Context): Unit = + if myPatternSkolem eq null then + () + else + pathDepMapping(myPatternSkolem.nn) match { + case null => + case m: SimpleIdentityMap[Symbol, TypeVar] => + pathDepMapping = pathDepMapping.updated(path, m) + m foreachBinding { (sym, tvar) => + val tpr = tvar.origin + pathDepReverseMapping = pathDepReverseMapping.updated(tpr, TypeRef(path, sym)) + } + } + end supplyPatternPath + + override def createPatternSkolem(pat: Type): SkolemType = + myPatternSkolem = SkolemType(pat) + myPatternSkolem.nn + end createPatternSkolem + override def withScrutineePath[T](path: TermRef | Null)(op: => T): T = val saved = this.myScrutineePath this.myScrutineePath = path + val result = op + this.myScrutineePath = saved result @@ -827,6 +859,10 @@ final class ProperGadtConstraint private( override def withScrutineePath[T](path: TermRef | Null)(op: => T): T = op + override def supplyPatternPath(path: TermRef)(using Context): Unit = () + + override def createPatternSkolem(pat: Type): SkolemType = unsupported("EmptyGadtConstraint.createPatternSkolem") + override def fresh = new ProperGadtConstraint override def restore(other: GadtConstraint): Unit = assert(!other.isNarrowing, "cannot restore a non-empty GADTMap") diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index 7fa6bf1cef03..0365d6b2d6fc 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -184,7 +184,7 @@ trait PatternTypeConstrainer { self: TypeComparer => val scrutineePath: TermRef | SkolemType = realScrutineePath match case null => SkolemType(scrut) case _ => realScrutineePath - val patternPath: SkolemType = SkolemType(pat) + val patternPath: SkolemType = ctx.gadt.createPatternSkolem(pat) val saved = state.nn.constraint val savedGadt = ctx.gadt.fresh diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 58aa36ebf518..6175d1b66d80 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1765,6 +1765,10 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val pat1 = gadtCtx.gadt.withScrutineePath(scrutineePath) { typedPattern(tree.pat, wideSelType)(using gadtCtx) } + + if scrutineePath.ne(null) && pat1.symbol.isPatternBound then + gadtCtx.gadt.supplyPatternPath(pat1.symbol.termRef) + caseRest(pat1)( using Nullables.caseContext(sel, pat1)( using gadtCtx)) From 71dbe8c0b8438fbb0d347d46fe6ba08189a4d6dc Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 20 Jul 2022 21:43:40 +0800 Subject: [PATCH 17/56] changing externalize to protected --- compiler/src/dotty/tools/dotc/core/GadtConstraint.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 7c433a3a6cc8..c67fbdc38836 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -757,7 +757,12 @@ final class ProperGadtConstraint private( // ---- Private ---------------------------------------------------------- - private def externalize(tp: Type, theMap: TypeMap | Null = null)(using Context): Type = tp match + /** Externalize the internal TypeParamRefs in the given type. + * + * We declare the method as `protected` instead of `private`, because declaring it as + * private will break a pickling test. This is a temporary workaround. + */ + protected def externalize(tp: Type, theMap: TypeMap | Null = null)(using Context): Type = tp match case param: TypeParamRef => reverseMapping(param) match case sym: Symbol => sym.typeRef case null => From 1177cf63678659714109c24d11b851c15e6b990c Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Fri, 22 Jul 2022 17:01:08 +0800 Subject: [PATCH 18/56] remove workaround for type variable --- compiler/src/dotty/tools/dotc/core/GadtConstraint.scala | 1 - tests/{pos => neg}/pdgadt-wildcard.scala | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) rename tests/{pos => neg}/pdgadt-wildcard.scala (80%) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index c67fbdc38836..0ee33fb36665 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -362,7 +362,6 @@ final class ProperGadtConstraint private( case tv: TypeVar => tv.origin case null => tp } - case tv: TypeVar => if isUpper then defn.AnyType else defn.NothingType case tp => tp loop(tp) diff --git a/tests/pos/pdgadt-wildcard.scala b/tests/neg/pdgadt-wildcard.scala similarity index 80% rename from tests/pos/pdgadt-wildcard.scala rename to tests/neg/pdgadt-wildcard.scala index 908ef4ff4fea..a9ef72023e52 100644 --- a/tests/pos/pdgadt-wildcard.scala +++ b/tests/neg/pdgadt-wildcard.scala @@ -3,6 +3,6 @@ case class Inv[X](x: X) extends Expr { type T = X } case class Inv2[X](x: X) extends Expr { type T >: X } def eval(e: Expr): e.T = e match - case Inv(x) => x + case Inv(x) => x // limitation // error case Inv2(x) => x From 198287e34a142c018a1169c7382e632c10b1484d Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Sat, 23 Jul 2022 00:12:56 +0800 Subject: [PATCH 19/56] only constrain non-private type members --- compiler/src/dotty/tools/dotc/core/GadtConstraint.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 0ee33fb36665..770a5fb962b0 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -12,6 +12,7 @@ import printing._ import scala.annotation.internal.sharable import Denotations.{Denotation, SingleDenotation} +import SymDenotations.NoDenotation /** Types that represent a path. Can either be a TermRef or a SkolemType. */ type PathType = TermRef | SkolemType @@ -283,9 +284,12 @@ final class ProperGadtConstraint private( case TypeAlias(_) => true case _ => false + def nonPrivate: Boolean = !denot1.isInstanceOf[NoDenotation.type] + (denot1.symbol.is(Flags.Deferred) || isTypeAlias) && !denot1.symbol.is(Flags.Opaque) && !denot1.symbol.isClass + && nonPrivate } private def isConstrainableTypeMember(path: PathType, sym: Symbol)(using Context): Boolean = From 4eac9a4e77c3898b421abfca5f65eeba311885ff Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Sat, 23 Jul 2022 16:03:36 +0800 Subject: [PATCH 20/56] avoid stripping type variables in OrderingConstraint.replace --- .../tools/dotc/core/OrderingConstraint.scala | 2 +- .../dotty/tools/dotc/typer/Inferencing.scala | 3 ++- tests/neg/pdgadt-wildcard.scala | 8 -------- tests/pos/pdgadt-wildcard.scala | 20 +++++++++++++++++++ 4 files changed, 23 insertions(+), 10 deletions(-) delete mode 100644 tests/neg/pdgadt-wildcard.scala create mode 100644 tests/pos/pdgadt-wildcard.scala diff --git a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala index 1341fac7d735..b1e52511e6e1 100644 --- a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala @@ -454,7 +454,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds, * of the parameter elsewhere in the constraint by type `tp`. */ def replace(param: TypeParamRef, tp: Type)(using Context): OrderingConstraint = - val replacement = tp.dealiasKeepAnnots.stripTypeVar + val replacement = tp.dealiasKeepAnnots if param == replacement then this.checkNonCyclic() else assert(replacement.isValueTypeOrLambda) diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index 27b83e025cf9..a072e4e5c897 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -437,7 +437,8 @@ object Inferencing { } // We add the created symbols to GADT constraint here. - if (res.nonEmpty) ctx.gadt.addToConstraint(res) + if res.nonEmpty then ctx.gadt.addToConstraint(res) + res } diff --git a/tests/neg/pdgadt-wildcard.scala b/tests/neg/pdgadt-wildcard.scala deleted file mode 100644 index a9ef72023e52..000000000000 --- a/tests/neg/pdgadt-wildcard.scala +++ /dev/null @@ -1,8 +0,0 @@ -trait Expr { type T } -case class Inv[X](x: X) extends Expr { type T = X } -case class Inv2[X](x: X) extends Expr { type T >: X } - -def eval(e: Expr): e.T = e match - case Inv(x) => x // limitation // error - case Inv2(x) => x - diff --git a/tests/pos/pdgadt-wildcard.scala b/tests/pos/pdgadt-wildcard.scala new file mode 100644 index 000000000000..d7fc9511fdb3 --- /dev/null +++ b/tests/pos/pdgadt-wildcard.scala @@ -0,0 +1,20 @@ +trait Expr { type T } +case class Inv[X](x: X) extends Expr { type T = X } +case class Inv2[X](x: X) extends Expr { type T >: X } + +def eval(e: Expr): e.T = e match + case Inv(x) => x + case Inv2(x) => x + +trait Foo[-T] +case class Bar() extends Foo[Int] + +def foo(e1: Expr, e2: Foo[e1.T]) = e1 match { + case Inv2(x) => e2 match { + case Bar() => + val t0: Int = x + val t1: e1.T = x + val t2: Int = t1 + } +} + From 8972fe9dd9d654b4f55d24c7dfda9dd32f6773e4 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Sat, 23 Jul 2022 19:07:58 +0800 Subject: [PATCH 21/56] more path-dependent GADT examples --- tests/pos/pdgadt-asmember.scala | 14 ++++++++++ tests/pos/pdgadt-nat-simpleadd.scala | 40 ++++++++++++++++++++++++++++ tests/pos/pdgadt-tupof.scala | 37 +++++++++++++++++++++++++ 3 files changed, 91 insertions(+) create mode 100644 tests/pos/pdgadt-asmember.scala create mode 100644 tests/pos/pdgadt-nat-simpleadd.scala create mode 100644 tests/pos/pdgadt-tupof.scala diff --git a/tests/pos/pdgadt-asmember.scala b/tests/pos/pdgadt-asmember.scala new file mode 100644 index 000000000000..6a2157a0abdb --- /dev/null +++ b/tests/pos/pdgadt-asmember.scala @@ -0,0 +1,14 @@ +// Taken from https://github.com/lampepfl/dotty/pull/14754#issuecomment-1157427912. +trait T[X] +case object Foo extends T[Unit] + +trait AsMember { + type L + val tl: T[L] +} + +def testMember(am: AsMember): Unit = + am.tl match { + case Foo => println(summon[am.L =:= Unit]) + case _ => () + } diff --git a/tests/pos/pdgadt-nat-simpleadd.scala b/tests/pos/pdgadt-nat-simpleadd.scala new file mode 100644 index 000000000000..f87e5a644767 --- /dev/null +++ b/tests/pos/pdgadt-nat-simpleadd.scala @@ -0,0 +1,40 @@ +trait Z +trait S[N] + +type of[A, B] = A & { type T = B } + +trait Nat { type T } +case class Zero() extends Nat { type T = Z } +case class Succ[P](prec: Nat of P) extends Nat { type T = S[P] } + + +type :=:[T, X] = T & { type R = X } + +trait :+:[A, B] { + type R +} + +object Addition { + given AddZero[N]: :+:[Z, N] with { + type R = N + } + + given AddSucc[M, N, X](using e: (M :+: N) :=: X): :+:[S[M], N] with { + val e0: (M :+: N) :=: X = e + type R = S[X] + } +} + +object Proof { + import Addition.given + + def zeroAddN[N]: (Z :+: N) :=: N = summon + + def nAddZero[N](n: Nat of N): (N :+: Z) :=: N = n match { + case Zero() => zeroAddN[Z] + case p: Succ[pn] => + val e0: (pn :+: Z) :=: pn = nAddZero[pn](p.prec) + AddSucc(using e0) + } +} + diff --git a/tests/pos/pdgadt-tupof.scala b/tests/pos/pdgadt-tupof.scala new file mode 100644 index 000000000000..c4651526448e --- /dev/null +++ b/tests/pos/pdgadt-tupof.scala @@ -0,0 +1,37 @@ +import scala.compiletime.ops.int.S + +type sized[T, N <: Int] = T & { type Size = N } + +abstract class TupOf[T, +A] { + type Size <: Int +} + +object TupOf { + given Empty[A]: TupOf[EmptyTuple, A] with { + type Size = 0 + } + + final given Cons[A, T <: Tuple, N <: Int](using p: T TupOf A sized N): TupOf[A *: T, A] with { + val p0: T TupOf A sized N = p + type Size = S[N] + } +} + +enum Vec[N <: Int, +A]: + case VecNil extends Vec[0, Nothing] + case VecCons[N0 <: Int, A](head: A, tail: Vec[N0, A]) extends Vec[S[N0], A] + +object Vec { + import TupOf._ + def apply[A, T <: Tuple](xs: T)(using p: T TupOf A): Vec[p.Size, A] = p match { + case _: TupOf.Empty[A] => VecNil + case p1: TupOf.Cons[a, t, n] => + VecCons(xs.head, apply(xs.tail)(using p1.p0)) + } + + def main(): Unit = { + val vec1: Vec[3, Int] = VecCons(1, VecCons(2, VecCons(3, VecNil))) + val vec2: Vec[3, Int] = Vec(1, 2, 3) + } +} + From ab7bb6c9c33863ac6409168c219015a121097ff5 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Sun, 24 Jul 2022 17:41:41 +0800 Subject: [PATCH 22/56] remove test since it fails -Ycheck This test triggers a bug in the compiler unrelated to path-dependent GADT reasoning. See #15743. We should add this test back after fixing this issue. --- tests/pos/pdgadt-tupof.scala | 37 ------------------------------------ 1 file changed, 37 deletions(-) delete mode 100644 tests/pos/pdgadt-tupof.scala diff --git a/tests/pos/pdgadt-tupof.scala b/tests/pos/pdgadt-tupof.scala deleted file mode 100644 index c4651526448e..000000000000 --- a/tests/pos/pdgadt-tupof.scala +++ /dev/null @@ -1,37 +0,0 @@ -import scala.compiletime.ops.int.S - -type sized[T, N <: Int] = T & { type Size = N } - -abstract class TupOf[T, +A] { - type Size <: Int -} - -object TupOf { - given Empty[A]: TupOf[EmptyTuple, A] with { - type Size = 0 - } - - final given Cons[A, T <: Tuple, N <: Int](using p: T TupOf A sized N): TupOf[A *: T, A] with { - val p0: T TupOf A sized N = p - type Size = S[N] - } -} - -enum Vec[N <: Int, +A]: - case VecNil extends Vec[0, Nothing] - case VecCons[N0 <: Int, A](head: A, tail: Vec[N0, A]) extends Vec[S[N0], A] - -object Vec { - import TupOf._ - def apply[A, T <: Tuple](xs: T)(using p: T TupOf A): Vec[p.Size, A] = p match { - case _: TupOf.Empty[A] => VecNil - case p1: TupOf.Cons[a, t, n] => - VecCons(xs.head, apply(xs.tail)(using p1.p0)) - } - - def main(): Unit = { - val vec1: Vec[3, Int] = VecCons(1, VecCons(2, VecCons(3, VecNil))) - val vec2: Vec[3, Int] = Vec(1, 2, 3) - } -} - From 5c57baddc26a15012d998e94c7a4aac446a1e1e7 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Tue, 9 Aug 2022 16:46:34 +0800 Subject: [PATCH 23/56] also rollback path-dependent GADT constraints on failure --- .../src/dotty/tools/dotc/core/TypeComparer.scala | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index da4c9e31fe81..6792e6adf130 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -2097,13 +2097,16 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case TypeRef(q: PathType, _) => (path eq q) && bound.isRef(sym) case _ => false - isConstrainable && { - gadts.println(i"narrow gadt bound of pdt $path -> ${sym}: from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${isRef}") + rollbackGadtUnless { + isConstrainable && { + gadts.println(i"narrow gadt bound of pdt $path -> ${sym}: from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${isRef}") - if isRef then false - else if isUpper then gadtAddUpperBound(path, sym, bound) - else gadtAddLowerBound(path, sym, bound) + if isRef then false + else if isUpper then gadtAddUpperBound(path, sym, bound) + else gadtAddLowerBound(path, sym, bound) + } } + case _ => false narrowTypeParams || narrowPathDepType From d092969562e0ca40e78f67fcf1cfb48652c4a068 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Tue, 9 Aug 2022 17:10:19 +0800 Subject: [PATCH 24/56] Revert "avoid stripping type variables in OrderingConstraint.replace" This reverts commit 02a5369e7f0aba916f209994137f77c31eae251e. --- .../tools/dotc/core/OrderingConstraint.scala | 2 +- .../dotty/tools/dotc/typer/Inferencing.scala | 3 +-- tests/neg/pdgadt-wildcard.scala | 8 ++++++++ tests/pos/pdgadt-wildcard.scala | 20 ------------------- 4 files changed, 10 insertions(+), 23 deletions(-) create mode 100644 tests/neg/pdgadt-wildcard.scala delete mode 100644 tests/pos/pdgadt-wildcard.scala diff --git a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala index b1e52511e6e1..1341fac7d735 100644 --- a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala @@ -454,7 +454,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds, * of the parameter elsewhere in the constraint by type `tp`. */ def replace(param: TypeParamRef, tp: Type)(using Context): OrderingConstraint = - val replacement = tp.dealiasKeepAnnots + val replacement = tp.dealiasKeepAnnots.stripTypeVar if param == replacement then this.checkNonCyclic() else assert(replacement.isValueTypeOrLambda) diff --git a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala index a072e4e5c897..27b83e025cf9 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inferencing.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inferencing.scala @@ -437,8 +437,7 @@ object Inferencing { } // We add the created symbols to GADT constraint here. - if res.nonEmpty then ctx.gadt.addToConstraint(res) - + if (res.nonEmpty) ctx.gadt.addToConstraint(res) res } diff --git a/tests/neg/pdgadt-wildcard.scala b/tests/neg/pdgadt-wildcard.scala new file mode 100644 index 000000000000..a9ef72023e52 --- /dev/null +++ b/tests/neg/pdgadt-wildcard.scala @@ -0,0 +1,8 @@ +trait Expr { type T } +case class Inv[X](x: X) extends Expr { type T = X } +case class Inv2[X](x: X) extends Expr { type T >: X } + +def eval(e: Expr): e.T = e match + case Inv(x) => x // limitation // error + case Inv2(x) => x + diff --git a/tests/pos/pdgadt-wildcard.scala b/tests/pos/pdgadt-wildcard.scala deleted file mode 100644 index d7fc9511fdb3..000000000000 --- a/tests/pos/pdgadt-wildcard.scala +++ /dev/null @@ -1,20 +0,0 @@ -trait Expr { type T } -case class Inv[X](x: X) extends Expr { type T = X } -case class Inv2[X](x: X) extends Expr { type T >: X } - -def eval(e: Expr): e.T = e match - case Inv(x) => x - case Inv2(x) => x - -trait Foo[-T] -case class Bar() extends Foo[Int] - -def foo(e1: Expr, e2: Foo[e1.T]) = e1 match { - case Inv2(x) => e2 match { - case Bar() => - val t0: Int = x - val t1: e1.T = x - val t2: Int = t1 - } -} - From c01ff7f8f97e61b9eee77d6ea7aaeaae7ac51c7a Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 10 Aug 2022 10:27:45 +0800 Subject: [PATCH 25/56] not registering aliasing type members --- .../src/dotty/tools/dotc/core/GadtConstraint.scala | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 770a5fb962b0..dc8aa739253e 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -275,18 +275,9 @@ final class ProperGadtConstraint private( val denot1 = tp.nonPrivateMember(denot.name) val tb = denot.info - // We want to constrain type members whose bounds are type alias - // even if they are not deferred. - // - // For example: when constraining { type A } >:< { type A = Int } - // we want to take the bound (:= Int) of RHS into consideration. - def isTypeAlias: Boolean = tb match - case TypeAlias(_) => true - case _ => false - def nonPrivate: Boolean = !denot1.isInstanceOf[NoDenotation.type] - (denot1.symbol.is(Flags.Deferred) || isTypeAlias) + denot1.symbol.is(Flags.Deferred) && !denot1.symbol.is(Flags.Opaque) && !denot1.symbol.isClass && nonPrivate From 05ea7093ba17187c6bcdc17fe05f0a98610b21b9 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 10 Aug 2022 10:28:00 +0800 Subject: [PATCH 26/56] update bounds desc --- compiler/src/dotty/tools/dotc/core/GadtConstraint.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index dc8aa739253e..eef6882382ca 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -413,7 +413,7 @@ final class ProperGadtConstraint private( val buf = new mutable.StringBuilder buf ++= "{\n" typeMembers foreach { denot => - buf ++= i" ${denot.symbol}: ${denot.info.bounds} ${denot.info.bounds.toString}\n" + buf ++= i" ${denot.symbol}: ${denot.info.bounds} [ ${denot.info.bounds.toString} ]\n" } buf ++= "}" buf.result From 7b99da4eb7db7fdb83c8e2be4ad5ad60a7e38b5d Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 10 Aug 2022 10:28:13 +0800 Subject: [PATCH 27/56] use isSubType to do subtype reconstruction --- .../src/dotty/tools/dotc/core/PatternTypeConstrainer.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index 0365d6b2d6fc..148f9e39f823 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -211,11 +211,11 @@ trait PatternTypeConstrainer { self: TypeComparer => val patternType = TypeRef(patternPath, patternSymbol) def constrainSP = - ctx.gadt.addBound(scrutineePath, scrutineeSymbol, patternType, isUpper = true) + isSubType(scrutineeType, patternType) .showing(i"after $scrutineePath.$scrutineeSymbol <:< $patternType: result = $result, gadt = ${ctx.gadt.debugBoundsDescription}", gadts) def constrainPS = - ctx.gadt.addBound(patternPath, patternSymbol, scrutineeType, isUpper = true) + isSubType(patternType, scrutineeType) .showing(i"after $patternPath.$patternSymbol <:< $scrutineePath: result = $result, gadt = ${ctx.gadt.debugBoundsDescription}", gadts) constrainPS && constrainSP From bd50e7963ac6924bf501ec3cf31cf029489ec574 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 10 Aug 2022 10:28:32 +0800 Subject: [PATCH 28/56] avoid false constraints in type inference when GADT mode is on --- compiler/src/dotty/tools/dotc/core/TypeComparer.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 6792e6adf130..7b4679e243ae 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -446,7 +446,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case _ => false } || isSubTypeWhenFrozen(bounds(tp1).hi.boxed, tp2) || { - if (canConstrain(tp1) && !approx.high) + if canConstrain(tp1) && isPreciseBound(fromBelow = false) then addConstraint(tp1, tp2, fromBelow = false) && flagNothingBound else thirdTry } @@ -613,6 +613,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling fourthTry } + def isPreciseBound(fromBelow: Boolean): Boolean = + if ctx.mode.is(Mode.GadtConstraintInference) then + !(approx.low || approx.high) + else + if fromBelow then !approx.low else !approx.high + def compareTypeParamRef(tp2: TypeParamRef): Boolean = assumedTrue(tp2) || { val alwaysTrue = @@ -626,7 +632,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling if (frozenConstraint) recur(tp1, bounds(tp2).lo.boxed) else isSubTypeWhenFrozen(tp1, tp2) alwaysTrue || { - if (canConstrain(tp2) && !approx.low) + if canConstrain(tp2) && isPreciseBound(fromBelow = true) then addConstraint(tp2, tp1.widenExpr, fromBelow = true) else fourthTry } From a406519140aa29bf337a5176fa5e14bed3470a8f Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 10 Aug 2022 15:23:45 +0800 Subject: [PATCH 29/56] fix constrained type lookup in GadtConstraint --- compiler/src/dotty/tools/dotc/core/GadtConstraint.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index eef6882382ca..18a0cdcd057a 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -526,8 +526,8 @@ final class ProperGadtConstraint private( } val internalizedBound = bound match { - case nt: NamedType => - val ntTvar = mapping(nt.symbol) + case nt: TypeRef => + val ntTvar = tvarOf(nt) if (ntTvar != null) stripInternalTypeVar(ntTvar) else bound case _ => bound } From 0bd1a90de3573f70ce6488c835fd04936b36de38 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 10 Aug 2022 15:24:14 +0800 Subject: [PATCH 30/56] add tracing for addConstraint in TypeComparer --- compiler/src/dotty/tools/dotc/core/TypeComparer.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 7b4679e243ae..c18b372475f4 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -447,7 +447,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling } || isSubTypeWhenFrozen(bounds(tp1).hi.boxed, tp2) || { if canConstrain(tp1) && isPreciseBound(fromBelow = false) then - addConstraint(tp1, tp2, fromBelow = false) && flagNothingBound + trace(i"addConstraint($tp1, <: $tp2)") { addConstraint(tp1, tp2, fromBelow = false) && flagNothingBound } else thirdTry } compareTypeParamRef @@ -633,7 +633,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling else isSubTypeWhenFrozen(tp1, tp2) alwaysTrue || { if canConstrain(tp2) && isPreciseBound(fromBelow = true) then - addConstraint(tp2, tp1.widenExpr, fromBelow = true) + trace(i"addConstriant($tp2, >: $tp1)") { addConstraint(tp2, tp1.widenExpr, fromBelow = true) } else fourthTry } } From e8d80b6f90b4fa44a557f0d74405146b6f61b176 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 10 Aug 2022 15:24:29 +0800 Subject: [PATCH 31/56] refactor pathdep GADT constraining logic --- .../dotc/core/PatternTypeConstrainer.scala | 71 ++++++++++++------- 1 file changed, 46 insertions(+), 25 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index 148f9e39f823..b45459f20f3a 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -10,6 +10,7 @@ import Contexts.ctx import dotty.tools.dotc.reporting.trace import config.Feature.migrateTo3 import config.Printers._ +import dotty.tools.dotc.core.SymDenotations.NoDenotation trait PatternTypeConstrainer { self: TypeComparer => @@ -175,6 +176,8 @@ trait PatternTypeConstrainer { self: TypeComparer => case tp => tp } + /** Reconstruct subtype constraints for type members. + */ def constrainTypeMembers = trace(i"constrainTypeMembers(${scrutRepr(scrut)}, $pat)", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") { import NameKinds.DepParamName val realScrutineePath = ctx.gadt.scrutineePath @@ -191,36 +194,54 @@ trait PatternTypeConstrainer { self: TypeComparer => ctx.gadt.addEquality(scrutineePath, patternPath) - def registerScrutinee = ctx.gadt.contains(scrutineePath) || ctx.gadt.addToConstraint(scrutineePath) - def registerPattern = ctx.gadt.addToConstraint(patternPath) // Pattern path is a freshly-created skolem, + val registerScrutinee = ctx.gadt.contains(scrutineePath) || ctx.gadt.addToConstraint(scrutineePath) + val registerPattern = ctx.gadt.addToConstraint(patternPath) // Pattern path is a freshly-created skolem, // so it will always be un-registered at this point - val res = !registerScrutinee || !registerPattern || { - val scrutineeTypeMembers = Map.from { - ctx.gadt.registeredTypeMembers(scrutineePath) map { x => x.name -> x } - } - val patternTypeMembers = Map.from { - ctx.gadt.registeredTypeMembers(patternPath) map { x => x.name -> x } - } - - (scrutineeTypeMembers.keySet intersect patternTypeMembers.keySet) forall { name => - val scrutineeSymbol = scrutineeTypeMembers(name) - val patternSymbol = patternTypeMembers(name) - - val scrutineeType = TypeRef(scrutineePath, scrutineeSymbol) - val patternType = TypeRef(patternPath, patternSymbol) - def constrainSP = - isSubType(scrutineeType, patternType) - .showing(i"after $scrutineePath.$scrutineeSymbol <:< $patternType: result = $result, gadt = ${ctx.gadt.debugBoundsDescription}", gadts) + /** Reconstruct subtype constraints for a type member (with symbol `sym`) + of path `p`, given that `p` and `q` are cohabitated. + + There are three cases when we want to constrain the type member T of + path p and q: + + (1) q does not have number T. In this case we should simply return true. + + (2) q.T is a registered type member. We do SR on p.T <:< q.T, but not + q.T <:< p.T, since if q.T is also registered then `constrainTypeMember(q, p, T)` + will also be called, during which q.T <:< p.T will be handled. + + (3) q.T is unregistered. We will do SR on p.T <:< q.T and q.T <:< p.T. + */ + def constrainTypeMember(p: PathType, q: PathType, sym: Symbol): Boolean = { + def getMemberOrBounds(q: PathType, sym: Symbol): Option[TypeBounds | TypeRef] = + if ctx.gadt.contains(q, sym) then + Some(TypeRef(q, sym)) + else + val denot = q.member(sym.name) + if denot.isInstanceOf[NoDenotation.type] then + None + else + Some(denot.info.bounds) + + val pType = TypeRef(p, sym) + + if ctx.gadt.contains(q, sym) then + val tpr = TypeRef(q, sym) + isSubType(pType, tpr) + else + q.member(sym.name).isInstanceOf[NoDenotation.type] || { + val tpr = TypeRef(q, sym) + isSubType(pType, tpr) && isSubType(tpr, pType) + } + } - def constrainPS = - isSubType(patternType, scrutineeType) - .showing(i"after $patternPath.$patternSymbol <:< $scrutineePath: result = $result, gadt = ${ctx.gadt.debugBoundsDescription}", gadts) + def constrainPath(p: PathType, q: PathType) = + ctx.gadt.registeredTypeMembers(p) forall { sym => constrainTypeMember(p, q, sym) } + def constrainPS = constrainPath(patternPath, scrutineePath) + def constrainSP = constrainPath(scrutineePath, patternPath) - constrainPS && constrainSP - } - } + val res = (!registerPattern || constrainPS) && (!registerScrutinee || constrainSP) if !res then constraint = saved From fed4d4842d189d296590a63ad10561b4ccd3cc44 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 10 Aug 2022 15:25:08 +0800 Subject: [PATCH 32/56] update pdgadt-wildcard test The limitation error in this test is eliminated. --- tests/{neg => pos}/pdgadt-wildcard.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename tests/{neg => pos}/pdgadt-wildcard.scala (80%) diff --git a/tests/neg/pdgadt-wildcard.scala b/tests/pos/pdgadt-wildcard.scala similarity index 80% rename from tests/neg/pdgadt-wildcard.scala rename to tests/pos/pdgadt-wildcard.scala index a9ef72023e52..908ef4ff4fea 100644 --- a/tests/neg/pdgadt-wildcard.scala +++ b/tests/pos/pdgadt-wildcard.scala @@ -3,6 +3,6 @@ case class Inv[X](x: X) extends Expr { type T = X } case class Inv2[X](x: X) extends Expr { type T >: X } def eval(e: Expr): e.T = e match - case Inv(x) => x // limitation // error + case Inv(x) => x case Inv2(x) => x From 4bab5f09da362939d41ac0ba28ace78dea45cf1a Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 10 Aug 2022 15:40:17 +0800 Subject: [PATCH 33/56] Revert "add tracing for addConstraint in TypeComparer" This reverts commit 6decfb3a7435d1b96e87afbc16b0e77e958debb1. --- compiler/src/dotty/tools/dotc/core/TypeComparer.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index c18b372475f4..7b4679e243ae 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -447,7 +447,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling } || isSubTypeWhenFrozen(bounds(tp1).hi.boxed, tp2) || { if canConstrain(tp1) && isPreciseBound(fromBelow = false) then - trace(i"addConstraint($tp1, <: $tp2)") { addConstraint(tp1, tp2, fromBelow = false) && flagNothingBound } + addConstraint(tp1, tp2, fromBelow = false) && flagNothingBound else thirdTry } compareTypeParamRef @@ -633,7 +633,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling else isSubTypeWhenFrozen(tp1, tp2) alwaysTrue || { if canConstrain(tp2) && isPreciseBound(fromBelow = true) then - trace(i"addConstriant($tp2, >: $tp1)") { addConstraint(tp2, tp1.widenExpr, fromBelow = true) } + addConstraint(tp2, tp1.widenExpr, fromBelow = true) else fourthTry } } From 2427004ff314056dc018eec6ca53aca2555193b5 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 10 Aug 2022 15:46:30 +0800 Subject: [PATCH 34/56] cleanup and tweak tracing --- .../dotc/core/PatternTypeConstrainer.scala | 30 +++++++------------ 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index b45459f20f3a..9a1058eb99ff 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -213,28 +213,18 @@ trait PatternTypeConstrainer { self: TypeComparer => (3) q.T is unregistered. We will do SR on p.T <:< q.T and q.T <:< p.T. */ - def constrainTypeMember(p: PathType, q: PathType, sym: Symbol): Boolean = { - def getMemberOrBounds(q: PathType, sym: Symbol): Option[TypeBounds | TypeRef] = - if ctx.gadt.contains(q, sym) then - Some(TypeRef(q, sym)) - else - val denot = q.member(sym.name) - if denot.isInstanceOf[NoDenotation.type] then - None + def constrainTypeMember(p: PathType, q: PathType, sym: Symbol): Boolean = + q.member(sym.name).isInstanceOf[NoDenotation.type] || { + val pType = TypeRef(p, sym) + val qType = TypeRef(q, sym) + + trace(i"constrainTypeMember $pType >:< $qType", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") { + if ctx.gadt.contains(q, sym) then + isSubType(pType, qType) else - Some(denot.info.bounds) - - val pType = TypeRef(p, sym) - - if ctx.gadt.contains(q, sym) then - val tpr = TypeRef(q, sym) - isSubType(pType, tpr) - else - q.member(sym.name).isInstanceOf[NoDenotation.type] || { - val tpr = TypeRef(q, sym) - isSubType(pType, tpr) && isSubType(tpr, pType) + isSubType(pType, qType) && isSubType(qType, pType) } - } + } def constrainPath(p: PathType, q: PathType) = ctx.gadt.registeredTypeMembers(p) forall { sym => constrainTypeMember(p, q, sym) } From d9dd8ce8ed2ed87dc0ce7ff32f2a0e485da33d23 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 10 Aug 2022 18:54:39 +0800 Subject: [PATCH 35/56] make singleton equality constriants actually working --- .../tools/dotc/core/GadtConstraint.scala | 71 ++++++++++++------- .../dotc/core/PatternTypeConstrainer.scala | 6 ++ .../dotty/tools/dotc/core/TypeComparer.scala | 9 +++ 3 files changed, 62 insertions(+), 24 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 18a0cdcd057a..87b8f60b0ef3 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -59,7 +59,7 @@ sealed abstract class GadtConstraint extends Showable { /** Further constrain a path-dependent type already present in the constraint. */ def addBound(p: PathType, sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean - def addEquality(p: PathType, q: PathType): Unit + def addEquality(p: PathType, q: PathType)(using Context): Unit def isEquivalent(p: PathType, q: PathType): Boolean @@ -552,22 +552,25 @@ final class ProperGadtConstraint private( } private def lookupPath(p: PathType): PathType | Null = - def recur(p: PathType, steps: Int = 0): PathType | Null = myUnionFind(p) match + def recur(p: PathType): PathType | Null = myUnionFind(p) match case null => null case q: PathType if q eq p => q case q: PathType => - if steps <= 1024 then - recur(q, steps + 1) - else - assert(false, "lookup step exceeding the threshold, possibly because of a loop in the union find") + recur(q) + recur(p) - override def addEquality(p: PathType, q: PathType): Unit = - val newRep: PathType = lookupPath(p) match - case null => lookupPath(q) match - case null => p - case r: PathType => r - case r: PathType => r + override def addEquality(p: PathType, q: PathType)(using Context): Unit = + val pRep: PathType | Null = lookupPath(p) + val qRep: PathType | Null = lookupPath(q) + + val newRep = (pRep, qRep) match + case (null, null) => p + case (null, r: PathType) => r + case (r: PathType, null) => r + case (r1: PathType, r2: PathType) => + myUnionFind = myUnionFind.updated(r2, r1) + r1 myUnionFind = myUnionFind.updated(p, newRep) myUnionFind = myUnionFind.updated(q, newRep) @@ -700,15 +703,26 @@ final class ProperGadtConstraint private( if myPatternSkolem eq null then () else - pathDepMapping(myPatternSkolem.nn) match { - case null => - case m: SimpleIdentityMap[Symbol, TypeVar] => - pathDepMapping = pathDepMapping.updated(path, m) - m foreachBinding { (sym, tvar) => - val tpr = tvar.origin - pathDepReverseMapping = pathDepReverseMapping.updated(tpr, TypeRef(path, sym)) - } - } + def updateMappings() = + pathDepMapping(myPatternSkolem.nn) match { + case null => + case m: SimpleIdentityMap[Symbol, TypeVar] => + pathDepMapping = pathDepMapping.updated(path, m) + m foreachBinding { (sym, tvar) => + val tpr = tvar.origin + pathDepReverseMapping = pathDepReverseMapping.updated(tpr, TypeRef(path, sym)) + } + } + + def updateUnionFind() = + myUnionFind(myPatternSkolem.nn) match { + case null => + case repr: PathType => + myUnionFind = myUnionFind.updated(path, repr) + } + + updateMappings() + updateUnionFind() end supplyPatternPath override def createPatternSkolem(pat: Type): SkolemType = @@ -799,16 +813,25 @@ final class ProperGadtConstraint private( override def debugBoundsDescription(using Context): String = { val sb = new mutable.StringBuilder sb ++= constraint.show - sb += '\n' + sb ++= "\nType parameter bounds:\n" mapping.foreachBinding { case (sym, _) => sb ++= i"$sym: ${fullBounds(sym)}\n" } - sb += '\n' + sb ++= "\nPath-dependent type bounds:\n" pathDepMapping foreachBinding { case (path, m) => m foreachBinding { case (sym, _) => sb ++= i"$path.$sym: ${fullBounds(TypeRef(path, sym))}\n" } } + sb ++= "\nSingleton equalities:\n" + myUnionFind foreachBinding { case (path, _) => + val repr = lookupPath(path) + repr match + case repr: PathType if repr ne path => + sb ++= i"$path.type: $repr.type\n" + case _ => + } + sb.result } } @@ -844,7 +867,7 @@ final class ProperGadtConstraint private( override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound") override def addBound(path: PathType, sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound") - override def addEquality(p: PathType, q: PathType) = () + override def addEquality(p: PathType, q: PathType)(using Context) = () override def isEquivalent(p: PathType, q: PathType) = false diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index 9a1058eb99ff..fdb446e7332f 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -194,6 +194,12 @@ trait PatternTypeConstrainer { self: TypeComparer => ctx.gadt.addEquality(scrutineePath, patternPath) + pat match { + case ptPath: TermRef => + ctx.gadt.addEquality(scrutineePath, ptPath) + case _ => + } + val registerScrutinee = ctx.gadt.contains(scrutineePath) || ctx.gadt.addToConstraint(scrutineePath) val registerPattern = ctx.gadt.addToConstraint(patternPath) // Pattern path is a freshly-created skolem, // so it will always be un-registered at this point diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 7b4679e243ae..58b9b7107246 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -590,6 +590,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling || fourthTry case _ => + def compareSingletonGADT: Boolean = + (tp1, tp2) match { + case (tp1: TermRef, tp2: TermRef) => ctx.gadt.isEquivalent(tp1, tp2) + case _ => false + } + val cls2 = tp2.symbol if (cls2.isClass) if (cls2.typeParams.isEmpty) { @@ -610,6 +616,9 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling } else if tp1.isLambdaSub && !tp1.isAnyKind then return recur(tp1, EtaExpansion(tp2)) + + if compareSingletonGADT then return true + fourthTry } From b50800ebe39300ba36c966e20e8bbd4a45c728ef Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 10 Aug 2022 20:05:54 +0800 Subject: [PATCH 36/56] add examples for singleton equality constraints --- tests/neg/pdgadt-reuse.scala | 16 ++++++++++++++++ tests/neg/pdgadt-singletons.scala | 30 ++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) create mode 100644 tests/neg/pdgadt-reuse.scala create mode 100644 tests/neg/pdgadt-singletons.scala diff --git a/tests/neg/pdgadt-reuse.scala b/tests/neg/pdgadt-reuse.scala new file mode 100644 index 000000000000..67444f5ab661 --- /dev/null +++ b/tests/neg/pdgadt-reuse.scala @@ -0,0 +1,16 @@ +object test { + trait Result[A] + + def cached[A](f: (x: A) => Result[x.type], toCache: A): (x: A) => Result[x.type] = { + val cachedRes: Result[toCache.type] = f(toCache) + def resFunc(x: A): Result[x.type] = { + x match { + case r: toCache.type => cachedRes + case _ => + val res: Result[x.type] = cachedRes // error + f(x) + } + } + resFunc + } +} diff --git a/tests/neg/pdgadt-singletons.scala b/tests/neg/pdgadt-singletons.scala new file mode 100644 index 000000000000..24650807bd2e --- /dev/null +++ b/tests/neg/pdgadt-singletons.scala @@ -0,0 +1,30 @@ +object test { + trait Tag + + locally { + val p0: Tag = ??? + val p1: Tag = ??? + + p0 match { + case q0: Tag => + val x1: p0.type = q0 + val x2: q0.type = p0 + val x3: p0.type = p1 // error + + p1 match { + case q1: Tag => + val x1: p1.type = q1 + val x2: q1.type = p1 + val x3: p0.type = p1 // error + + q0 match { + case r: q1.type => + val x1: q0.type = q1 + val x2: p0.type = p1 + val x3: q0.type = p1 + val x4: q1.type = p0 + } + } + } + } +} From a26e3e654adfe63d4184895a90e51f952b6d7b5d Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 10 Aug 2022 20:31:37 +0800 Subject: [PATCH 37/56] add GADT usage info when singleton equality constraints are used --- compiler/src/dotty/tools/dotc/core/TypeComparer.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 58b9b7107246..3e7b9f9c2611 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -592,7 +592,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case _ => def compareSingletonGADT: Boolean = (tp1, tp2) match { - case (tp1: TermRef, tp2: TermRef) => ctx.gadt.isEquivalent(tp1, tp2) + case (tp1: TermRef, tp2: TermRef) => + ctx.gadt.isEquivalent(tp1, tp2) && { GADTused = true; true } case _ => false } From 7e5618eededf73d5f84f36ae3f22f3152a65b61b Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 10 Aug 2022 20:32:11 +0800 Subject: [PATCH 38/56] support subtype reconstruction when pattern is an alias to another path This commits additionally reconstruct subtype for p >:< r in: p match { case _: r.type => ... } --- .../dotc/core/PatternTypeConstrainer.scala | 92 ++++++++++++------- 1 file changed, 58 insertions(+), 34 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index fdb446e7332f..7b1ed32b0e98 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -181,63 +181,87 @@ trait PatternTypeConstrainer { self: TypeComparer => def constrainTypeMembers = trace(i"constrainTypeMembers(${scrutRepr(scrut)}, $pat)", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") { import NameKinds.DepParamName val realScrutineePath = ctx.gadt.scrutineePath + /* We reset scrutinee path so that the path will only be used at top level. */ ctx.gadt.resetScrutineePath() + val saved = state.nn.constraint + val savedGadt = ctx.gadt.fresh + val scrutineePath: TermRef | SkolemType = realScrutineePath match case null => SkolemType(scrut) case _ => realScrutineePath val patternPath: SkolemType = ctx.gadt.createPatternSkolem(pat) - val saved = state.nn.constraint - val savedGadt = ctx.gadt.fresh - - ctx.gadt.addEquality(scrutineePath, patternPath) - - pat match { - case ptPath: TermRef => - ctx.gadt.addEquality(scrutineePath, ptPath) - case _ => - } - val registerScrutinee = ctx.gadt.contains(scrutineePath) || ctx.gadt.addToConstraint(scrutineePath) val registerPattern = ctx.gadt.addToConstraint(patternPath) // Pattern path is a freshly-created skolem, // so it will always be un-registered at this point + /** Reconstruct subtype constraints for a path `p`, given that `p` and `q` + are cohabitated. - /** Reconstruct subtype constraints for a type member (with symbol `sym`) - of path `p`, given that `p` and `q` are cohabitated. - - There are three cases when we want to constrain the type member T of - path p and q: + When do SR for each type member (denoted as `T`) of path p, there are the + following three cases: (1) q does not have number T. In this case we should simply return true. (2) q.T is a registered type member. We do SR on p.T <:< q.T, but not - q.T <:< p.T, since if q.T is also registered then `constrainTypeMember(q, p, T)` - will also be called, during which q.T <:< p.T will be handled. + q.T <:< p.T, since if q.T is also registered then + `constrainTypeMember(q, p, T)` will also be called, during which + q.T <:< p.T will be handled. (3) q.T is unregistered. We will do SR on p.T <:< q.T and q.T <:< p.T. */ - def constrainTypeMember(p: PathType, q: PathType, sym: Symbol): Boolean = - q.member(sym.name).isInstanceOf[NoDenotation.type] || { - val pType = TypeRef(p, sym) - val qType = TypeRef(q, sym) - - trace(i"constrainTypeMember $pType >:< $qType", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") { - if ctx.gadt.contains(q, sym) then - isSubType(pType, qType) - else - isSubType(pType, qType) && isSubType(qType, pType) + def reconstructSubTypeFor(p: PathType, q: PathType) = + def processMember(sym: Symbol): Boolean = + q.member(sym.name).isInstanceOf[NoDenotation.type] || { + val pType = TypeRef(p, sym) + val qType = TypeRef(q, sym) + + trace(i"constrainTypeMember $pType >:< $qType", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") { + if ctx.gadt.contains(q, sym) then + isSubType(pType, qType) + else + isSubType(pType, qType) && isSubType(qType, pType) + } } - } - def constrainPath(p: PathType, q: PathType) = - ctx.gadt.registeredTypeMembers(p) forall { sym => constrainTypeMember(p, q, sym) } - def constrainPS = constrainPath(patternPath, scrutineePath) - def constrainSP = constrainPath(scrutineePath, patternPath) + ctx.gadt.registeredTypeMembers(p) forall { sym => processMember(sym) } + + /** Reconstruct subtype from the cohabitation between the scrutinee and the + pattern. */ + def constrainPattern: Boolean = { + ctx.gadt.addEquality(scrutineePath, patternPath) + + (!registerPattern || reconstructSubTypeFor(patternPath, scrutineePath)) + && (!registerScrutinee || reconstructSubTypeFor(scrutineePath, patternPath)) + } + + /** Reconstruct subtype when the pattern is an alias to another path. + + For example, consider the following pattern match: + + p match + case q: r.type => + + Then we can also reconstruct subtype from the cohabitation of p and r. + */ + def maybeConstrainPatternAlias: Boolean = pat match { + case ptPath: TermRef => + val registerPtPath = ctx.gadt.contains(ptPath) || ctx.gadt.addToConstraint(ptPath) + + val result = + (!registerPtPath || reconstructSubTypeFor(ptPath, scrutineePath)) + && (!registerScrutinee || reconstructSubTypeFor(scrutineePath, ptPath)) + + ctx.gadt.addEquality(scrutineePath, ptPath) + + result + case _ => + true + } - val res = (!registerPattern || constrainPS) && (!registerScrutinee || constrainSP) + val res = constrainPattern && maybeConstrainPatternAlias if !res then constraint = saved From cab8dcbf2e9e27ae6765ecc2bd6f48130f084e4f Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 10 Aug 2022 20:34:31 +0800 Subject: [PATCH 39/56] add test for constraining aliasing pattern --- tests/neg/pdgadt-patalias.scala | 38 +++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 tests/neg/pdgadt-patalias.scala diff --git a/tests/neg/pdgadt-patalias.scala b/tests/neg/pdgadt-patalias.scala new file mode 100644 index 000000000000..1a6dbcc44ccb --- /dev/null +++ b/tests/neg/pdgadt-patalias.scala @@ -0,0 +1,38 @@ +trait Expr { type X } +trait IntExpr { type X = Int } + +val iexpr: Expr { type X = Int } = new Expr { type X = Int } + +object Zero { type X = 0 } + +def direct1(x: Expr) = x match { + case _: iexpr.type => + val x1: Int = ??? : x.X + val x2: x.X = ??? : Int + val x3: iexpr.type = x +} + +def direct2(x: Expr) = x match { + case _: Zero.type => + val x1: Int = ??? : x.X // limitation // error + val x2: x.X = 0 // limitation // error +} + +def indirect1(a: Expr, b: IntExpr) = a match { + case r: b.type => + val x1: Int = ??? : a.X + val x2: a.X = ??? : Int + val x3: a.type = b +} + +def indirect2(a: Expr, b: Expr) = a match { + case _: IntExpr => + // sanity check + val x1: Int = ??? : a.X + val x2: a.X = ??? : Int + b match { + case _: a.type => + val x1: Int = ??? : b.X + val x2: b.X = ??? : Int + } +} From 130ee6cc98f6b61d1fd63c82cc4da1ccc395ae37 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 10 Aug 2022 20:41:10 +0800 Subject: [PATCH 40/56] cleanup comments --- compiler/src/dotty/tools/dotc/core/GadtConstraint.scala | 9 --------- 1 file changed, 9 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 87b8f60b0ef3..88a81ef551d4 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -138,15 +138,6 @@ final class ProperGadtConstraint private( myPatternSkolem = null, ) - // /** Exposes ConstraintHandling.subsumes */ - // def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(using Context): Boolean = { - // def extractConstraint(g: GadtConstraint) = g match { - // case s: ProperGadtConstraint => s.constraint - // case EmptyGadtConstraint => OrderingConstraint.empty - // } - // subsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre)) - // } - /** Whether `left` subsumes `right`? * * `left` and `right` both stem from the constraint `pre`, with different type reasoning performed, From 7136b408114b159fb8e0f7301522fa26d13b99ed Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Sun, 14 Aug 2022 11:24:18 +0800 Subject: [PATCH 41/56] improve the documentation of `GadtConstraint.subsumes` --- .../src/dotty/tools/dotc/core/GadtConstraint.scala | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 88a81ef551d4..381b9d3a034e 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -140,9 +140,17 @@ final class ProperGadtConstraint private( /** Whether `left` subsumes `right`? * - * `left` and `right` both stem from the constraint `pre`, with different type reasoning performed, - * during which new types might be registered in GadtConstraint. This function will take such newly - * registered types into consideration. + * `left` and `right` branches from `pre` with different constraint reasoning + * performed. During this, new path-dependent types could be registered in `left` + * and `right`. + * + * This function considers the cases where both `left` and `right` register a + * new path-depepdent type p.T. The problem is that in this case p.T will have + * two different internal representation in `left` and `right`: + * - p.T registered in `left` and has the internal representation T(param)$1 + * - p.T registered in `right` and has the internal representation T(param)$2 + * In `subsumes`, we have to recognize the fact that both T(param)$1 and T(param)$2 + * represents the same path-dependent type, to give the correct result. */ def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(using Context): Boolean = { def checkSubsumes(c1: Constraint, c2: Constraint, pre: Constraint): Boolean = { From e2d5f6ff9183b0baa5393d3f5ea76fac60c89a3f Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Sun, 14 Aug 2022 15:36:31 +0800 Subject: [PATCH 42/56] more documentation for GadtConstraint class --- .../tools/dotc/core/GadtConstraint.scala | 143 ++++++++---------- 1 file changed, 61 insertions(+), 82 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 381b9d3a034e..590bc65ec469 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -59,10 +59,13 @@ sealed abstract class GadtConstraint extends Showable { /** Further constrain a path-dependent type already present in the constraint. */ def addBound(p: PathType, sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean + /** Record the equality between two singleton types. */ def addEquality(p: PathType, q: PathType)(using Context): Unit + /** Check whether two singleton types are equivalent. */ def isEquivalent(p: PathType, q: PathType): Boolean + /** Query the representative member of a singleton type. */ def reprOf(p: PathType): PathType | Null /** Scrutinee path of the current pattern matching. */ @@ -180,7 +183,8 @@ final class ProperGadtConstraint private( } checkNewParams && { - // compute mappings between the newly-registered type params in the two branches + // Computes mappings between the newly-registered type params in + // the two branches. def createMappings = { var mapping1: SimpleIdentityMap[TypeParamRef, TypeParamRef] = SimpleIdentityMap.empty var mapping2: SimpleIdentityMap[TypeParamRef, TypeParamRef] = SimpleIdentityMap.empty @@ -207,7 +211,7 @@ final class ProperGadtConstraint private( (mapTypeParam(mapping1), mapTypeParam(mapping2)) } - // bridge between the newly-registered types in c2 and c1 + // Bridge between the newly-registered types in c2 and c1 val (mapping1, mapping2) = createMappings try { @@ -263,39 +267,47 @@ final class ProperGadtConstraint private( => true case _ => false + /** Check whether a type member is constrainable based on its denotation. + * + * A type member is considered constrainable if it is abstract, is not + * an opaque type, is not a class and is non-private. + */ + private def isConstrainableDenot(denot: Denotation)(using Context): Boolean = + denot.symbol.is(Flags.Deferred) + && !denot.symbol.is(Flags.Opaque) + && !denot.symbol.isClass + && !denot.isInstanceOf[NoDenotation.type] + /** Find all constrainable type member denotations of the given type. * - * All abstract but not opaque type members are returned. * Note that we return denotation here, since the bounds of the type member * depend on the context (e.g. applied type parameters). */ private def constrainableTypeMembers(tp: Type)(using Context): List[Denotation] = tp.typeMembers.toList filter { denot => val denot1 = tp.nonPrivateMember(denot.name) - val tb = denot.info - - def nonPrivate: Boolean = !denot1.isInstanceOf[NoDenotation.type] - - denot1.symbol.is(Flags.Deferred) - && !denot1.symbol.is(Flags.Opaque) - && !denot1.symbol.isClass - && nonPrivate + isConstrainableDenot(denot1) } + /** Check whether a type member of a path is constrainable. */ private def isConstrainableTypeMember(path: PathType, sym: Symbol)(using Context): Boolean = val mbr = path.nonPrivateMember(sym.name) mbr.isInstanceOf[SingleDenotation] && { val denot1 = path.nonPrivateMember(mbr.name) - val tb = mbr.info - - denot1.symbol.is(Flags.Deferred) - && !denot1.symbol.is(Flags.Opaque) - && !denot1.symbol.isClass + isConstrainableDenot(denot1) } + /** Check whether a path-dependent type is constrainable. + * + * A path-dependent type p.A is constrainable if its path p and the type member A is + * constrainable. + */ override def isConstrainablePDT(path: PathType, sym: Symbol)(using Context): Boolean = isConstrainablePath(path) && isConstrainableTypeMember(path, sym) + /** Get the internal type variable of the path-dependent type. Return null if it + * is not registered. + */ private def tvarOf(path: PathType, sym: Symbol)(using Context): TypeVar | Null = pathDepMapping(path) match case null => null @@ -310,19 +322,17 @@ final class ProperGadtConstraint private( tpr match case TypeRef(p: PathType, _) => tvarOf(p, tpr.symbol) case _ => null - case tv => tv - - /** Try to retrieve the internal type variable for a NamedType. */ - private def tvarOf(ntp: NamedType)(using Context): TypeVar | Null = - ntp match - case tp: TypeRef => tvarOf(tp) - case _ => null + case tv: TypeVar => tv private def tvarOf(tp: Type)(using Context): TypeVar | Null = tp match case tp: TypeRef => tvarOf(tp) case _ => null + /** Register all constrainable path-dependent types addressed from the path. + * Returns whether the registration succeeds. It also checks whether the path + * itself is constrainable. + */ override def addToConstraint(path: PathType)(using Context): Boolean = isConstrainablePath(path) && { import NameKinds.DepParamName val pathType = path.widen @@ -384,27 +394,24 @@ final class ProperGadtConstraint private( pt => defn.AnyType ) - def register: Boolean = - val tvars = typeMemberSymbols lazyZip poly1.paramRefs map { (sym, paramRef) => - val tv = TypeVar(paramRef, creatorState = null) + val tvars = typeMemberSymbols lazyZip poly1.paramRefs map { (sym, paramRef) => + val tv = TypeVar(paramRef, creatorState = null) - val externalType = TypeRef(path, sym) - pathDepMapping = pathDepMapping.updated(path, { - val old: SimpleIdentityMap[Symbol, TypeVar] = pathDepMapping(path) match - case null => SimpleIdentityMap.empty - case m => m.nn + val externalType = TypeRef(path, sym) + pathDepMapping = pathDepMapping.updated(path, { + val old: SimpleIdentityMap[Symbol, TypeVar] = pathDepMapping(path) match + case null => SimpleIdentityMap.empty + case m => m.nn - old.updated(sym, tv) - }) - pathDepReverseMapping = pathDepReverseMapping.updated(tv.origin, externalType) + old.updated(sym, tv) + }) + pathDepReverseMapping = pathDepReverseMapping.updated(tv.origin, externalType) - tv - } - - addToConstraint(poly1, tvars) - .showing(i"added to constraint: [$poly1] $path\n$debugBoundsDescription", gadts) + tv + } - register + addToConstraint(poly1, tvars) + .showing(i"added to constraint: [$poly1] $path\n$debugBoundsDescription", gadts) } } @@ -417,6 +424,7 @@ final class ProperGadtConstraint private( buf ++= "}" buf.result + /** Get the representative member of the path in the union find. */ override def reprOf(p: PathType): PathType | Null = lookupPath(p) override def addToConstraint(params: List[Symbol])(using Context): Boolean = { @@ -468,7 +476,7 @@ final class ProperGadtConstraint private( .showing(i"added to constraint: [$poly1] $params%, %\n$debugBoundsDescription", gadts) } - override def addBound(path: PathType, sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = { + private def addBoundForTvar(tvar: TypeVar, bound: Type, isUpper: Boolean, typeRepr: String)(using Context): Boolean = { @annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match { case tv: TypeVar => val inst = constraint.instType(tv) @@ -476,10 +484,10 @@ final class ProperGadtConstraint private( case _ => tp } - val symTvar: TypeVar = stripInternalTypeVar(tvarOrError(path, sym)) match { + val symTvar: TypeVar = stripInternalTypeVar(tvar) match { case tv: TypeVar => tv case inst => - gadts.println(i"instantiated: $path.$sym -> $inst") + gadts.println(i"instantiated: $typeRepr -> $inst") return if (isUpper) isSub(inst, bound) else isSub(bound, inst) } @@ -502,52 +510,23 @@ final class ProperGadtConstraint private( gadts.println { val descr = if (isUpper) "upper" else "lower" val op = if (isUpper) "<:" else ">:" - i"adding $descr bound $path.$sym $op $bound = $result" + i"adding $descr bound $typeRepr $op $bound = $result" } if constraint ne saved then wasConstrained = true result } - override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = { - @annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match { - case tv: TypeVar => - val inst = constraint.instType(tv) - if (inst.exists) stripInternalTypeVar(inst) else tv - case _ => tp - } - - val symTvar: TypeVar = stripInternalTypeVar(tvarOrError(sym)) match { - case tv: TypeVar => tv - case inst => - gadts.println(i"instantiated: $sym -> $inst") - return if (isUpper) isSub(inst, bound) else isSub(bound, inst) - } - - val internalizedBound = bound match { - case nt: TypeRef => - val ntTvar = tvarOf(nt) - if (ntTvar != null) stripInternalTypeVar(ntTvar) else bound - case _ => bound - } - - val saved = constraint - val result = internalizedBound match - case boundTvar: TypeVar => - if (boundTvar eq symTvar) true - else if (isUpper) addLess(symTvar.origin, boundTvar.origin) - else addLess(boundTvar.origin, symTvar.origin) - case bound => - addBoundTransitively(symTvar.origin, bound, isUpper) - - gadts.println { - val descr = if (isUpper) "upper" else "lower" - val op = if (isUpper) "<:" else ">:" - i"adding $descr bound $sym $op $bound = $result" - } + override def addBound(path: PathType, sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = { + val tvar = tvarOrError(path, sym) + val typeRepr = TypeRef(path, sym).show + addBoundForTvar(tvar, bound, isUpper, typeRepr) + } - if constraint ne saved then wasConstrained = true - result + override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = { + val tvar = tvarOrError(sym) + val typeRepr = sym.typeRef.show + addBoundForTvar(tvar, bound, isUpper, typeRepr) } private def lookupPath(p: PathType): PathType | Null = From 8371bad232229b95b4ba1f593a2841d7e8d5efbe Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Sun, 14 Aug 2022 15:42:11 +0800 Subject: [PATCH 43/56] Minor --- .../src/dotty/tools/dotc/core/PatternTypeConstrainer.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index 7b1ed32b0e98..4378b89e438f 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -176,8 +176,7 @@ trait PatternTypeConstrainer { self: TypeComparer => case tp => tp } - /** Reconstruct subtype constraints for type members. - */ + /** Reconstruct subtype constraints for type members of the scrutinee and the pattern. */ def constrainTypeMembers = trace(i"constrainTypeMembers(${scrutRepr(scrut)}, $pat)", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") { import NameKinds.DepParamName val realScrutineePath = ctx.gadt.scrutineePath From a1d0e49fb8d4b5ed4044d821fa72e20a2bab9c6e Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Sun, 14 Aug 2022 17:03:58 +0800 Subject: [PATCH 44/56] also try to register path-dependent types in the bounds --- .../dotty/tools/dotc/core/TypeComparer.scala | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 3e7b9f9c2611..e8532310fc6b 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -2084,6 +2084,24 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling private def narrowGADTBounds(tr: NamedType, bound: Type, approx: ApproxState, isUpper: Boolean): Boolean = { val boundImprecise = approx.high || approx.low ctx.mode.is(Mode.GadtConstraintInference) && !frozenGadt && !frozenConstraint && !boundImprecise && { + def tryRegisterBound: Boolean = bound.match { + case tr @ TypeRef(path: PathType, _) => + val sym = tr.symbol + + def register = + ctx.gadt.contains(path, sym) || ctx.gadt.contains(sym) || { + ctx.gadt.isConstrainablePDT(path, tr.symbol) && { + gadts.println(i"!!! registering path on the fly path=$path sym=$sym") + ctx.gadt.addToConstraint(path) && ctx.gadt.contains(path, sym) + } + } + + val result = register + + true + case _ => true + } + def narrowTypeParams = ctx.gadt.contains(tr.symbol) && { val tparam = tr.symbol gadts.println(i"narrow gadt bound of tparam $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}") @@ -2125,7 +2143,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case _ => false - narrowTypeParams || narrowPathDepType + tryRegisterBound && narrowTypeParams || narrowPathDepType } } From ac4f6e7a80ad86130b8a006efe26cf68dd7db18f Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Sun, 14 Aug 2022 17:04:31 +0800 Subject: [PATCH 45/56] improve tracing in GadtConstraint --- .../src/dotty/tools/dotc/core/GadtConstraint.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 590bc65ec469..89ddc925f06e 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -338,9 +338,13 @@ final class ProperGadtConstraint private( val pathType = path.widen val typeMembers = constrainableTypeMembers(path).filterNot(_.symbol eq NoSymbol) - gadts.println(i"> trying to add $path into constraint ...") - gadts.println(i" path.widen = $pathType") - gadts.println(i" type members =\n${debugShowTypeMembers(typeMembers)}") + gadts.println { + val sb = new mutable.StringBuilder() + sb ++= i"* trying to add $path into constraint ...\n" + sb ++= i"** path.widen = $pathType\n" + sb ++= i"** type members =\n${debugShowTypeMembers(typeMembers)}\n" + sb.result() + } typeMembers.nonEmpty && { val typeMemberSymbols: List[Symbol] = typeMembers map { x => x.symbol } @@ -411,7 +415,7 @@ final class ProperGadtConstraint private( } addToConstraint(poly1, tvars) - .showing(i"added to constraint: [$poly1] $path\n$debugBoundsDescription", gadts) + .showing(i"added to constraint: [$poly1] $path, result = $result\n$debugBoundsDescription", gadts) } } From 6843cb158e6a16bf4d24fa71fedb1101ffa305f6 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Mon, 15 Aug 2022 15:55:41 +0800 Subject: [PATCH 46/56] substitute instantiated dependent params --- .../tools/dotc/core/GadtConstraint.scala | 39 ++++++++++++++----- tests/pos/gadt-dep-param.scala | 12 ++++++ 2 files changed, 41 insertions(+), 10 deletions(-) create mode 100644 tests/pos/gadt-dep-param.scala diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 89ddc925f06e..5165304e3394 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -357,17 +357,22 @@ final class ProperGadtConstraint private( tp.derivedAndOrType(loop(tp1), loop(tp2)) case tp @ OrType(tp1, tp2) if isUpper => tp.derivedOrType(loop(tp1), loop(tp2)) - case tp @ TypeRef(prefix, des) if prefix eq path => + case tp @ TypeRef(prefix, _) if prefix eq path => typeMemberSymbols indexOf tp.symbol match case -1 => tp case idx => pt.paramRefs(idx) - case tp @ TypeRef(_: RecThis, des) => + case tp @ TypeRef(_: RecThis, _) => typeMemberSymbols indexOf tp.symbol match case -1 => tp case idx => pt.paramRefs(idx) case tp: TypeRef => tvarOf(tp) match { - case tv: TypeVar => tv.origin + case tv: TypeVar => + val tp = stripInternalTypeVar(tv) + tp.match { + case tv1: TypeVar => stripTypeVarWhenDependent(tv1) + case tp => tp + } case null => tp } case tp => tp @@ -450,7 +455,12 @@ final class ProperGadtConstraint private( params.indexOf(tp.symbol) match { case -1 => mapping(tp.symbol) match { - case tv: TypeVar => tv.origin + case tv: TypeVar => + val tp = stripInternalTypeVar(tv) + tp.match { + case tv1: TypeVar => stripTypeVarWhenDependent(tv1) + case tp => tp + } case null => tp } case i => pt.paramRefs(i) @@ -480,13 +490,22 @@ final class ProperGadtConstraint private( .showing(i"added to constraint: [$poly1] $params%, %\n$debugBoundsDescription", gadts) } + @annotation.tailrec private def stripInternalTypeVar(tp: Type): Type = tp match { + case tv: TypeVar => + val inst = constraint.instType(tv) + if (inst.exists) stripInternalTypeVar(inst) else tv + case _ => tp + } + + private def stripTypeVarWhenDependent(tv: TypeVar): TypeParamRef | TypeVar = + val tpr = tv.origin + if constraint.contains(tpr) then + tpr + else + tv + end stripTypeVarWhenDependent + private def addBoundForTvar(tvar: TypeVar, bound: Type, isUpper: Boolean, typeRepr: String)(using Context): Boolean = { - @annotation.tailrec def stripInternalTypeVar(tp: Type): Type = tp match { - case tv: TypeVar => - val inst = constraint.instType(tv) - if (inst.exists) stripInternalTypeVar(inst) else tv - case _ => tp - } val symTvar: TypeVar = stripInternalTypeVar(tvar) match { case tv: TypeVar => tv diff --git a/tests/pos/gadt-dep-param.scala b/tests/pos/gadt-dep-param.scala new file mode 100644 index 000000000000..cbdd6db338a5 --- /dev/null +++ b/tests/pos/gadt-dep-param.scala @@ -0,0 +1,12 @@ +enum Tup[A, B]: + case Data[A, B]() extends Tup[A, B] + +def foo1[A, B](e: Tup[A, B]) = e.match { + case _: Tup.Data[a, b] => + def bar[C >: a <: b, D]() = + val t1: b = ??? : a +} + +def foo2[A, B, C >: A <: B]() = + val t1: B = ??? : A + From 3edd4a0bfabb9c1178e7fdc9d61a26e1edc38fee Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Mon, 15 Aug 2022 16:04:54 +0800 Subject: [PATCH 47/56] add pdgadt-sub pos test --- tests/pos/pdgadt-sub.scala | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 tests/pos/pdgadt-sub.scala diff --git a/tests/pos/pdgadt-sub.scala b/tests/pos/pdgadt-sub.scala new file mode 100644 index 000000000000..6b29ac0bb306 --- /dev/null +++ b/tests/pos/pdgadt-sub.scala @@ -0,0 +1,21 @@ +trait SubBase { + type A0 + type B0 + type C >: A0 <: B0 +} +trait Tag { type A } + +type Sub[A, B] = SubBase { type A0 = A; type B0 = B } + +def foo(x: Tag, y: Tag, e: Sub[x.A, y.A]) = e.match { + case _: Object => + val t1: y.A = ??? : x.A +} + +def bar(x: Tag, e: Sub[Int, x.A]): x.A = e match { + case _: Object => 0 +} + +def baz(x: Tag, e: Sub[x.A, Int]): Int = e match { + case _: Object => ??? : x.A +} From 6908bfc62de539618cbe966ec49a0ab492cf9613 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Mon, 15 Aug 2022 16:51:37 +0800 Subject: [PATCH 48/56] refactor subsumes --- .../tools/dotc/core/GadtConstraint.scala | 116 ++++++------------ 1 file changed, 37 insertions(+), 79 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 5165304e3394..c0505b86585d 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -155,95 +155,53 @@ final class ProperGadtConstraint private( * In `subsumes`, we have to recognize the fact that both T(param)$1 and T(param)$2 * represents the same path-dependent type, to give the correct result. */ - def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(using Context): Boolean = { - def checkSubsumes(c1: Constraint, c2: Constraint, pre: Constraint): Boolean = { - if (c2 eq pre) true - else if (c1 eq pre) false - else { - val saved = constraint - - def computeNewParams = - val params1 = c1.domainParams.toSet - val params2 = c2.domainParams.toSet - val preParams = pre.domainParams.toSet - /** Type parameter registered after branching */ - (params1.diff(preParams), params2.diff(preParams)) - - val (newParams1, newParams2) = computeNewParams - - // When new types are registered after pre, for left to subsume right, it should contain all types - // newly registered in right. - def checkNewParams: Boolean = (left, right) match { - case (left: ProperGadtConstraint, right: ProperGadtConstraint) => - newParams2 forall { p2 => - val tp2 = right.externalize(p2) - left.tvarOf(tp2) != null + def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(using Context): Boolean = + def checkSubsumes(left: ProperGadtConstraint, right: ProperGadtConstraint, pre: ProperGadtConstraint): Boolean = { + def rightToLeft: TypeParamRef => TypeParamRef = { + val preParams = pre.constraint.domainParams.toSet + val mapping = { + var res: SimpleIdentityMap[TypeParamRef, TypeParamRef] = SimpleIdentityMap.empty + + right.constraint.domainParams.foreach { p2 => + left.tvarOf(right.externalize(p2)) match { + case null => + case tv: TypeVar => + res = res.updated(p2, tv.origin) } - case _ => true - } - - checkNewParams && { - // Computes mappings between the newly-registered type params in - // the two branches. - def createMappings = { - var mapping1: SimpleIdentityMap[TypeParamRef, TypeParamRef] = SimpleIdentityMap.empty - var mapping2: SimpleIdentityMap[TypeParamRef, TypeParamRef] = SimpleIdentityMap.empty - - (left, right) match { - case (left: ProperGadtConstraint, right: ProperGadtConstraint) => - newParams1 foreach { p1 => - val tp1 = left.externalize(p1) - right.tvarOf(tp1) match { - case null => - case tvar2: TypeVar => - mapping1 = mapping1.updated(p1, tvar2.origin) - mapping2 = mapping2.updated(tvar2.origin, p1) - } - } - case _ => - } - - def mapTypeParam(m: SimpleIdentityMap[TypeParamRef, TypeParamRef])(tpr: TypeParamRef) = - m(tpr) match - case null => tpr - case tpr1: TypeParamRef => tpr1 - - (mapTypeParam(mapping1), mapTypeParam(mapping2)) } - // Bridge between the newly-registered types in c2 and c1 - val (mapping1, mapping2) = createMappings + res + } - try { - // checks existing type parameters in `pre` - def existing: Boolean = pre.forallParams { p => - c1.contains(p) && - c2.upper(p).forall { q => - c1.isLess(p, mapping1(q)) - } && isSubTypeWhenFrozen(c1.nonParamBounds(p), c2.nonParamBounds(p)) - } + def func(p2: TypeParamRef) = + if pre.constraint.domainParams contains p2 then p2 + else mapping(p2) - // checks new type parameters in `c1` - def added: Boolean = newParams1 forall { p1 => - val p2 = mapping1(p1) - c2.upper(p2).forall { q => - c1.isLess(p1, mapping2(q)) - } && isSubTypeWhenFrozen(c1.nonParamBounds(p1), c2.nonParamBounds(p2)) - } + func + } - existing && checkNewParams && added - } finally constraint = saved + def checkParam(p2: TypeParamRef) = + rightToLeft(p2).match { + case null => false + case p1: TypeParamRef => + left.constraint.entry(p1).exists + && right.constraint.upper(p1).map(rightToLeft).forall(left.constraint.isLess(p1, _)) + && isSubTypeWhenFrozen(left.constraint.nonParamBounds(p1), right.constraint.nonParamBounds(p2)) } - } - } - def extractConstraint(g: GadtConstraint) = g match { - case s: ProperGadtConstraint => s.constraint - case EmptyGadtConstraint => OrderingConstraint.empty + def todos: Set[TypeParamRef] = + right.constraint.domainParams.toSet ++ pre.constraint.domainParams + + todos.forall(checkParam) } - checkSubsumes(extractConstraint(left), extractConstraint(right), extractConstraint(pre)) - } + (left, right, pre) match { + case (left: ProperGadtConstraint, right: ProperGadtConstraint, pre: ProperGadtConstraint) => + checkSubsumes(left, right, pre) + case (_, EmptyGadtConstraint, _) => true + case (EmptyGadtConstraint, _, _) => false + case (_, _, EmptyGadtConstraint) => false + } override protected def legalBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Type = // GADT constraints never involve wildcards and are not propagated outside From a6f3a3aa54c23f873634a114d3adf4407a947f07 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Mon, 15 Aug 2022 17:01:34 +0800 Subject: [PATCH 49/56] fix type signature to pass null check --- .../tools/dotc/core/GadtConstraint.scala | 43 ++++++++++--------- .../dotty/tools/dotc/core/TypeComparer.scala | 4 +- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index c0505b86585d..6b3179a1ca21 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -157,42 +157,43 @@ final class ProperGadtConstraint private( */ def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(using Context): Boolean = def checkSubsumes(left: ProperGadtConstraint, right: ProperGadtConstraint, pre: ProperGadtConstraint): Boolean = { - def rightToLeft: TypeParamRef => TypeParamRef = { + def getRightToLeftMapping: Option[TypeParamRef => TypeParamRef] = { val preParams = pre.constraint.domainParams.toSet val mapping = { var res: SimpleIdentityMap[TypeParamRef, TypeParamRef] = SimpleIdentityMap.empty + var hasNull: Boolean = false right.constraint.domainParams.foreach { p2 => left.tvarOf(right.externalize(p2)) match { case null => + hasNull = true case tv: TypeVar => res = res.updated(p2, tv.origin) } } - res + if hasNull then None else Some(res) } - def func(p2: TypeParamRef) = - if pre.constraint.domainParams contains p2 then p2 - else mapping(p2) - - func - } - - def checkParam(p2: TypeParamRef) = - rightToLeft(p2).match { - case null => false - case p1: TypeParamRef => - left.constraint.entry(p1).exists - && right.constraint.upper(p1).map(rightToLeft).forall(left.constraint.isLess(p1, _)) - && isSubTypeWhenFrozen(left.constraint.nonParamBounds(p1), right.constraint.nonParamBounds(p2)) + mapping map { mapping => + def func(p2: TypeParamRef) = + if pre.constraint.domainParams contains p2 then p2 + else mapping(p2).nn + func } + } - def todos: Set[TypeParamRef] = - right.constraint.domainParams.toSet ++ pre.constraint.domainParams - - todos.forall(checkParam) + getRightToLeftMapping map { rightToLeft => + def checkParam(p2: TypeParamRef) = + val p1 = rightToLeft(p2) + left.constraint.entry(p1).exists + && right.constraint.upper(p1).map(rightToLeft).forall(left.constraint.isLess(p1, _)) + && isSubTypeWhenFrozen(left.constraint.nonParamBounds(p1), right.constraint.nonParamBounds(p2)) + def todos: Set[TypeParamRef] = + right.constraint.domainParams.toSet ++ pre.constraint.domainParams + + todos.forall(checkParam) + } getOrElse false } (left, right, pre) match { @@ -830,7 +831,7 @@ final class ProperGadtConstraint private( override def isEquivalent(p: PathType, q: PathType) = false - override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = unsupported("EmptyGadtConstraint.approximation") + override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = unsupported("EmptyGadtConstraint.approximation") override def symbols: List[Symbol] = Nil diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index e8532310fc6b..70f90b1915e6 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -2109,9 +2109,9 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling else rollbackGadtUnless { if isUpper then - gadtAddUpperBound(tparam, bound) + gadtAddBound(tparam, bound, isUpper = true) else - gadtAddLowerBound(tparam, bound) + gadtAddBound(tparam, bound, isUpper = false) } } From 25aa1626433e02a72d5f913fb21b6c70d8235c79 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 14 Sep 2022 16:51:06 +0200 Subject: [PATCH 50/56] refactor path aliasing constraints --- .../tools/dotc/core/GadtConstraint.scala | 46 ++++++++----------- .../dotc/core/PatternTypeConstrainer.scala | 4 +- .../dotty/tools/dotc/core/TypeComparer.scala | 15 +----- 3 files changed, 22 insertions(+), 43 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 6b3179a1ca21..a6d593c51557 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -59,14 +59,11 @@ sealed abstract class GadtConstraint extends Showable { /** Further constrain a path-dependent type already present in the constraint. */ def addBound(p: PathType, sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean - /** Record the equality between two singleton types. */ - def addEquality(p: PathType, q: PathType)(using Context): Unit + /** Record the aliasing relationship between two singleton types. */ + def recordPathAliasing(p: PathType, q: PathType)(using Context): Unit - /** Check whether two singleton types are equivalent. */ - def isEquivalent(p: PathType, q: PathType): Boolean - - /** Query the representative member of a singleton type. */ - def reprOf(p: PathType): PathType | Null + /** Check whether two paths are equivalent via path aliasing. */ + def isAliasingPath(p: PathType, q: PathType): Boolean /** Scrutinee path of the current pattern matching. */ def scrutineePath: TermRef | Null @@ -124,7 +121,7 @@ final class ProperGadtConstraint private( private var pathDepReverseMapping: SimpleIdentityMap[TypeParamRef, TypeRef], private var wasConstrained: Boolean, private var myScrutineePath: TermRef | Null, - private var myUnionFind: SimpleIdentityMap[PathType, PathType], + private var pathAliasingMapping: SimpleIdentityMap[PathType, PathType], private var myPatternSkolem: SkolemType | Null, ) extends GadtConstraint with ConstraintHandling { import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} @@ -137,7 +134,7 @@ final class ProperGadtConstraint private( pathDepReverseMapping = SimpleIdentityMap.empty, wasConstrained = false, myScrutineePath = null, - myUnionFind = SimpleIdentityMap.empty, + pathAliasingMapping = SimpleIdentityMap.empty, myPatternSkolem = null, ) @@ -392,9 +389,6 @@ final class ProperGadtConstraint private( buf ++= "}" buf.result - /** Get the representative member of the path in the union find. */ - override def reprOf(p: PathType): PathType | Null = lookupPath(p) - override def addToConstraint(params: List[Symbol])(using Context): Boolean = { import NameKinds.DepParamName @@ -512,7 +506,7 @@ final class ProperGadtConstraint private( } private def lookupPath(p: PathType): PathType | Null = - def recur(p: PathType): PathType | Null = myUnionFind(p) match + def recur(p: PathType): PathType | Null = pathAliasingMapping(p) match case null => null case q: PathType if q eq p => q case q: PathType => @@ -520,7 +514,7 @@ final class ProperGadtConstraint private( recur(p) - override def addEquality(p: PathType, q: PathType)(using Context): Unit = + override def recordPathAliasing(p: PathType, q: PathType)(using Context): Unit = val pRep: PathType | Null = lookupPath(p) val qRep: PathType | Null = lookupPath(q) @@ -529,13 +523,13 @@ final class ProperGadtConstraint private( case (null, r: PathType) => r case (r: PathType, null) => r case (r1: PathType, r2: PathType) => - myUnionFind = myUnionFind.updated(r2, r1) + pathAliasingMapping = pathAliasingMapping.updated(r2, r1) r1 - myUnionFind = myUnionFind.updated(p, newRep) - myUnionFind = myUnionFind.updated(q, newRep) + pathAliasingMapping = pathAliasingMapping.updated(p, newRep) + pathAliasingMapping = pathAliasingMapping.updated(q, newRep) - override def isEquivalent(p: PathType, q: PathType): Boolean = + override def isAliasingPath(p: PathType, q: PathType): Boolean = lookupPath(p) match case null => false case p0: PathType => lookupPath(q) match @@ -637,7 +631,7 @@ final class ProperGadtConstraint private( pathDepReverseMapping, wasConstrained, myScrutineePath, - myUnionFind, + pathAliasingMapping, myPatternSkolem, ) @@ -650,7 +644,7 @@ final class ProperGadtConstraint private( this.pathDepReverseMapping = other.pathDepReverseMapping this.wasConstrained = other.wasConstrained this.myScrutineePath = other.myScrutineePath - this.myUnionFind = other.myUnionFind + this.pathAliasingMapping = other.pathAliasingMapping this.myPatternSkolem = other.myPatternSkolem case _ => ; } @@ -675,10 +669,10 @@ final class ProperGadtConstraint private( } def updateUnionFind() = - myUnionFind(myPatternSkolem.nn) match { + pathAliasingMapping(myPatternSkolem.nn) match { case null => case repr: PathType => - myUnionFind = myUnionFind.updated(path, repr) + pathAliasingMapping = pathAliasingMapping.updated(path, repr) } updateMappings() @@ -784,7 +778,7 @@ final class ProperGadtConstraint private( } } sb ++= "\nSingleton equalities:\n" - myUnionFind foreachBinding { case (path, _) => + pathAliasingMapping foreachBinding { case (path, _) => val repr = lookupPath(path) repr match case repr: PathType if repr ne path => @@ -805,8 +799,6 @@ final class ProperGadtConstraint private( override def bounds(tp: TypeRef)(using Context): TypeBounds | Null = null override def fullBounds(tp: TypeRef)(using Context): TypeBounds | Null = null - override def reprOf(p: PathType): PathType | Null = null - override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = unsupported("EmptyGadtConstraint.isLess") override def isLess(tp1: NamedType, tp2: NamedType)(using Context): Boolean = unsupported("EmptyGadtConstraint.isLess") @@ -827,9 +819,9 @@ final class ProperGadtConstraint private( override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound") override def addBound(path: PathType, sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound") - override def addEquality(p: PathType, q: PathType)(using Context) = () + override def recordPathAliasing(p: PathType, q: PathType)(using Context) = () - override def isEquivalent(p: PathType, q: PathType) = false + override def isAliasingPath(p: PathType, q: PathType) = false override def approximation(sym: Symbol, fromBelow: Boolean, maxLevel: Int)(using Context): Type = unsupported("EmptyGadtConstraint.approximation") diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index 4378b89e438f..9fbb7d6de5d8 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -230,7 +230,7 @@ trait PatternTypeConstrainer { self: TypeComparer => /** Reconstruct subtype from the cohabitation between the scrutinee and the pattern. */ def constrainPattern: Boolean = { - ctx.gadt.addEquality(scrutineePath, patternPath) + ctx.gadt.recordPathAliasing(scrutineePath, patternPath) (!registerPattern || reconstructSubTypeFor(patternPath, scrutineePath)) && (!registerScrutinee || reconstructSubTypeFor(scrutineePath, patternPath)) @@ -253,7 +253,7 @@ trait PatternTypeConstrainer { self: TypeComparer => (!registerPtPath || reconstructSubTypeFor(ptPath, scrutineePath)) && (!registerScrutinee || reconstructSubTypeFor(scrutineePath, ptPath)) - ctx.gadt.addEquality(scrutineePath, ptPath) + ctx.gadt.recordPathAliasing(scrutineePath, ptPath) result case _ => diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 70f90b1915e6..3d66057eefa7 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -593,7 +593,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling def compareSingletonGADT: Boolean = (tp1, tp2) match { case (tp1: TermRef, tp2: TermRef) => - ctx.gadt.isEquivalent(tp1, tp2) && { GADTused = true; true } + ctx.gadt.isAliasingPath(tp1, tp2) && { GADTused = true; true } case _ => false } @@ -2360,19 +2360,6 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case Atoms.Range(lo2, hi2) => if hi1.subsetOf(lo2) then return tp2 if hi2.subsetOf(lo1) then return tp1 - - def getReprSet(ps: Set[Type]): Set[Type] = - ps.map { x => - x match - case p: PathType => - val rep = ctx.gadt.reprOf(p) - if rep == null then p else rep - case t => t - } - val (repLo1, repHi1, repLo2, repHi2) = (getReprSet(lo1), getReprSet(hi1), getReprSet(lo2), getReprSet(hi2)) - if repHi2.subsetOf(repLo1) then return tp1 - if repHi1.subsetOf(repLo2) then return tp2 - if (hi1 & hi2).isEmpty then return orType(tp1, tp2) case none => case none => From ed403593de853d3ad0b148ab448feebe530d175c Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 21 Sep 2022 14:57:50 +0200 Subject: [PATCH 51/56] fix typeMemberTouched and add a testcase --- .../dotty/tools/dotc/core/PatternTypeConstrainer.scala | 2 +- tests/pos/pdgadt-path.scala | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 tests/pos/pdgadt-path.scala diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index 9fbb7d6de5d8..03d34cfc8839 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -144,7 +144,7 @@ trait PatternTypeConstrainer { self: TypeComparer => ctx.gadt.bounds(scrut.symbol) match { case tb: TypeBounds => val hi = tb.hi - constrainPatternType(pat, hi) + constrainPatternType(pat, hi, typeMembersTouched = true) case null => true } case _ => true diff --git a/tests/pos/pdgadt-path.scala b/tests/pos/pdgadt-path.scala new file mode 100644 index 000000000000..1b9b4e8fc9c9 --- /dev/null +++ b/tests/pos/pdgadt-path.scala @@ -0,0 +1,8 @@ +trait Expr +case class IntLit(b: Int) extends Expr + +def foo[M <: Expr](e: M) = e match { + case e1: IntLit => + val t0: e1.type = e + val t1: e.type = e1 +} From 5118628c783a0ec71b766297c9dfedfdb50cfd36 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 21 Sep 2022 15:05:36 +0200 Subject: [PATCH 52/56] refactor constrainPatternType --- .../dotc/core/PatternTypeConstrainer.scala | 220 +++++++++--------- 1 file changed, 110 insertions(+), 110 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index 03d34cfc8839..a301dc7a3ba7 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -74,106 +74,127 @@ trait PatternTypeConstrainer { self: TypeComparer => * scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables, * in which case the subtyping relationship "heals" the type. */ - def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false, typeMembersTouched: Boolean = false): Boolean = trace(i"constrainPatternType(${scrutRepr(scrut)}, $pat)", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") { - - def classesMayBeCompatible: Boolean = { - import Flags._ - val patCls = pat.classSymbol - val scrCls = scrut.classSymbol - !patCls.exists || !scrCls.exists || { - if (patCls.is(Final)) patCls.derivesFrom(scrCls) - else if (scrCls.is(Final)) scrCls.derivesFrom(patCls) - else if (!patCls.is(Flags.Trait) && !scrCls.is(Flags.Trait)) - patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls) - else true + def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false): Boolean = trace(i"constrainPatternType(${scrutRepr(scrut)}, $pat)", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") { + def recur(pat: Type, scrut: Type): Boolean = { + def classesMayBeCompatible: Boolean = { + import Flags._ + val patCls = pat.classSymbol + val scrCls = scrut.classSymbol + !patCls.exists || !scrCls.exists || { + if (patCls.is(Final)) patCls.derivesFrom(scrCls) + else if (scrCls.is(Final)) scrCls.derivesFrom(patCls) + else if (!patCls.is(Flags.Trait) && !scrCls.is(Flags.Trait)) + patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls) + else true + } } - } - - def stripRefinement(tp: Type): Type = tp match { - case tp: RefinedOrRecType => stripRefinement(tp.parent) - case tp => tp - } - def tryConstrainSimplePatternType(pat: Type, scrut: Type) = { - val patCls = pat.classSymbol - val scrCls = scrut.classSymbol - patCls.exists && scrCls.exists - && (patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls)) - && constrainSimplePatternType(pat, scrut, forceInvariantRefinement) - } + def stripRefinement(tp: Type): Type = tp match { + case tp: RefinedOrRecType => stripRefinement(tp.parent) + case tp => tp + } - def constrainUpcasted(scrut: Type): Boolean = trace(i"constrainUpcasted($scrut)", gadts) { - // Fold a list of types into an AndType - def buildAndType(xs: List[Type]): Type = { - @annotation.tailrec def recur(acc: Type, rem: List[Type]): Type = rem match { - case Nil => acc - case x :: rem => recur(AndType(acc, x), rem) - } - xs match { - case Nil => NoType - case x :: xs => recur(x, xs) - } + def tryConstrainSimplePatternType(pat: Type, scrut: Type) = { + val patCls = pat.classSymbol + val scrCls = scrut.classSymbol + patCls.exists && scrCls.exists + && (patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls)) + && constrainSimplePatternType(pat, scrut, forceInvariantRefinement) } - scrut match { - case scrut: TypeRef if scrut.symbol.isClass => - // consider all parents - val parents = scrut.parents - val andType = buildAndType(parents) - !andType.exists || constrainPatternType(pat, andType, typeMembersTouched = true) - case scrut @ AppliedType(tycon: TypeRef, _) if tycon.symbol.isClass => - val patCls = pat.classSymbol - // find all shared parents in the inheritance hierarchy between pat and scrut - def allParentsSharedWithPat(tp: Type, tpClassSym: ClassSymbol): List[Symbol] = { - var parents = tpClassSym.info.parents - if parents.nonEmpty && parents.head.classSymbol == defn.ObjectClass then - parents = parents.tail - parents flatMap { tp => - val sym = tp.classSymbol.asClass - if patCls.derivesFrom(sym) then List(sym) - else allParentsSharedWithPat(tp, sym) - } + def constrainUpcasted(scrut: Type): Boolean = trace(i"constrainUpcasted($scrut)", gadts) { + // Fold a list of types into an AndType + def buildAndType(xs: List[Type]): Type = { + @annotation.tailrec def recur(acc: Type, rem: List[Type]): Type = rem match { + case Nil => acc + case x :: rem => recur(AndType(acc, x), rem) } - val allSyms = allParentsSharedWithPat(tycon, tycon.symbol.asClass) - val baseClasses = allSyms map scrut.baseType - val andType = buildAndType(baseClasses) - !andType.exists || constrainPatternType(pat, andType, typeMembersTouched = true) - case _ => - def tryGadtBounds = scrut match { - case scrut: TypeRef => - ctx.gadt.bounds(scrut.symbol) match { - case tb: TypeBounds => - val hi = tb.hi - constrainPatternType(pat, hi, typeMembersTouched = true) - case null => true - } - case _ => true + xs match { + case Nil => NoType + case x :: xs => recur(x, xs) } + } - def trySuperType = - val upcasted: Type = scrut match { - case scrut: TypeProxy => - scrut.superType - case _ => NoType + scrut match { + case scrut: TypeRef if scrut.symbol.isClass => + // consider all parents + val parents = scrut.parents + val andType = buildAndType(parents) + !andType.exists || recur(pat, andType) + case scrut @ AppliedType(tycon: TypeRef, _) if tycon.symbol.isClass => + val patCls = pat.classSymbol + // find all shared parents in the inheritance hierarchy between pat and scrut + def allParentsSharedWithPat(tp: Type, tpClassSym: ClassSymbol): List[Symbol] = { + var parents = tpClassSym.info.parents + if parents.nonEmpty && parents.head.classSymbol == defn.ObjectClass then + parents = parents.tail + parents flatMap { tp => + val sym = tp.classSymbol.asClass + if patCls.derivesFrom(sym) then List(sym) + else allParentsSharedWithPat(tp, sym) + } + } + val allSyms = allParentsSharedWithPat(tycon, tycon.symbol.asClass) + val baseClasses = allSyms map scrut.baseType + val andType = buildAndType(baseClasses) + !andType.exists || recur(pat, andType) + case _ => + def tryGadtBounds = scrut match { + case scrut: TypeRef => + ctx.gadt.bounds(scrut.symbol) match { + case tb: TypeBounds => + val hi = tb.hi + recur(pat, hi) + case null => true + } + case _ => true } - if (upcasted.exists) - tryConstrainSimplePatternType(pat, upcasted) || constrainUpcasted(upcasted) - else true - tryGadtBounds && trySuperType + def trySuperType = + val upcasted: Type = scrut match { + case scrut: TypeProxy => + scrut.superType + case _ => NoType + } + if (upcasted.exists) + tryConstrainSimplePatternType(pat, upcasted) || constrainUpcasted(upcasted) + else true + + tryGadtBounds && trySuperType + } + } + + def dealiasDropNonmoduleRefs(tp: Type) = tp.dealias match { + case tp: TermRef => + // we drop TermRefs that don't have a class symbol, as they can't + // meaningfully participate in GADT reasoning and just get in the way. + // Their info could, for an example, be an AndType. One example where + // this is important is an enum case that extends its parent and an + // additional trait - argument-less enum cases desugar to vals. + // See run/enum-Tree.scala. + if tp.classSymbol.exists then tp else tp.info + case tp => tp } - } - def dealiasDropNonmoduleRefs(tp: Type) = tp.dealias match { - case tp: TermRef => - // we drop TermRefs that don't have a class symbol, as they can't - // meaningfully participate in GADT reasoning and just get in the way. - // Their info could, for an example, be an AndType. One example where - // this is important is an enum case that extends its parent and an - // additional trait - argument-less enum cases desugar to vals. - // See run/enum-Tree.scala. - if tp.classSymbol.exists then tp else tp.info - case tp => tp + dealiasDropNonmoduleRefs(scrut) match { + case OrType(scrut1, scrut2) => + either(recur(pat, scrut1), recur(pat, scrut2)) + case AndType(scrut1, scrut2) => + recur(pat, scrut1) && recur(pat, scrut2) + case scrut: RefinedOrRecType => + recur(pat, stripRefinement(scrut)) + case scrut => dealiasDropNonmoduleRefs(pat) match { + case OrType(pat1, pat2) => + either(recur(pat1, scrut), recur(pat2, scrut)) + case AndType(pat1, pat2) => + recur(pat1, scrut) && recur(pat2, scrut) + case pat: RefinedOrRecType => + recur(stripRefinement(pat), scrut) + case pat => + tryConstrainSimplePatternType(pat, scrut) + || classesMayBeCompatible && constrainUpcasted(scrut) + } + } } /** Reconstruct subtype constraints for type members of the scrutinee and the pattern. */ @@ -269,28 +290,7 @@ trait PatternTypeConstrainer { self: TypeComparer => res } - def constrainTypeParams = - dealiasDropNonmoduleRefs(scrut) match { - case OrType(scrut1, scrut2) => - either(constrainPatternType(pat, scrut1, typeMembersTouched = true), constrainPatternType(pat, scrut2, typeMembersTouched = true)) - case AndType(scrut1, scrut2) => - constrainPatternType(pat, scrut1, typeMembersTouched = true) && constrainPatternType(pat, scrut2, typeMembersTouched = true) - case scrut: RefinedOrRecType => - constrainPatternType(pat, stripRefinement(scrut), typeMembersTouched = true) - case scrut => dealiasDropNonmoduleRefs(pat) match { - case OrType(pat1, pat2) => - either(constrainPatternType(pat1, scrut, typeMembersTouched = true), constrainPatternType(pat2, scrut, typeMembersTouched = true)) - case AndType(pat1, pat2) => - constrainPatternType(pat1, scrut, typeMembersTouched = true) && constrainPatternType(pat2, scrut, typeMembersTouched = true) - case pat: RefinedOrRecType => - constrainPatternType(stripRefinement(pat), scrut, typeMembersTouched = true) - case pat => - tryConstrainSimplePatternType(pat, scrut) - || classesMayBeCompatible && constrainUpcasted(scrut) - } - } - - constrainTypeParams && (typeMembersTouched || constrainTypeMembers) + recur(pat, scrut) && constrainTypeMembers } /** Show the scrutinee. Will show the path if available. */ From e043fdd14979b0503794a378e94b45c8cf0c3e9c Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 21 Sep 2022 16:18:54 +0200 Subject: [PATCH 53/56] documenting SR for path-dependent types --- .../dotc/core/PatternTypeConstrainer.scala | 30 ++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index a301dc7a3ba7..c9ea9b1e075b 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -197,7 +197,35 @@ trait PatternTypeConstrainer { self: TypeComparer => } } - /** Reconstruct subtype constraints for type members of the scrutinee and the pattern. */ + /** Reconstruct subtype constraints for type members of the scrutinee and the pattern. + * + * To inference SR constraints for the type members from the scrutinee `p` and the pattern `q`, + * we first find all the abstract type members of `p`: A₁, A₂, ⋯, Aₖ. + * + * Then, for each Aᵢ, if `q` also has a type member labaled `Aᵢ`, we inference SR constraints by calling + * TypeComparer on the relation `p.Aᵢ <:< q.Aᵢ`. + * We derive SR constraints for type members of the pattern path `q` similarly. + * + * Specially, if for some `Aᵢ`, `p.Aᵢ` is abstract while `q.Aᵢ` is not, we will extract constraints + * for both directions of the subtype relations (i.e. both `p.Aᵢ <:< q.Aᵢ` and `q.Aᵢ <:< p.Aᵢ`). + * + * How we find out and handle the path (`TermRef`) of the scrutinee and pattern. + * + * - The path of scrutinee is not directly available in `constrainPatternType`, since the scrutinee type + * passed to this function is widened. + * To make the path available during GADT reasoning, we save the scrutinee path in `Typer.typedCase`. The scrutinee path will be saved in `ctx.gadt.scrutineePath`. + * Note that we have to clear the saved scrutinee path after using by calling `ctx.gadt.resetScrutineePath()`. + * This is because `constrainPatternType` may be called multiple times for one nested pattern. For example: + * + * e match + * case A(B(a), C(b)) => // ... + * + * We have to reset the scrutinee path after constraining `e` against the top level pattern `A(...)`. + * + * - The path of pattern is not available when calling the function, and the symbol of the pattern will only be created after GADT reasoning. + * Therefore, we will create `SkolemType` acting as a placeholder for the pattern path, and substitute it with the real pattern path + * when it is available later in the typer. We call `ctx.gadt.supplyPatternPath` to do the substitution. + */ def constrainTypeMembers = trace(i"constrainTypeMembers(${scrutRepr(scrut)}, $pat)", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") { import NameKinds.DepParamName val realScrutineePath = ctx.gadt.scrutineePath From 898c47dfaf8fa7f3c72058e466f659106f7560d0 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Wed, 21 Sep 2022 17:00:56 +0200 Subject: [PATCH 54/56] support GADT reasoning on aliases for nested pattern --- .../src/dotty/tools/dotc/core/GadtConstraint.scala | 8 ++++++-- compiler/src/dotty/tools/dotc/typer/Typer.scala | 8 ++++---- tests/neg/pdgadt-nested-pat-alias.scala | 12 ++++++++++++ 3 files changed, 22 insertions(+), 6 deletions(-) create mode 100644 tests/neg/pdgadt-nested-pat-alias.scala diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index a6d593c51557..34920dfb1861 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -677,11 +677,15 @@ final class ProperGadtConstraint private( updateMappings() updateUnionFind() + myPatternSkolem = null end supplyPatternPath override def createPatternSkolem(pat: Type): SkolemType = - myPatternSkolem = SkolemType(pat) - myPatternSkolem.nn + if myPatternSkolem ne null then + SkolemType(pat) + else + myPatternSkolem = SkolemType(pat) + myPatternSkolem.nn end createPatternSkolem override def withScrutineePath[T](path: TermRef | Null)(op: => T): T = diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 6175d1b66d80..65e7a370446e 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1754,10 +1754,10 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer sel.tpe match { case p: TermRef => tree.pat match { - case _: Trees.Typed[_] => p - case _: Trees.Ident[_] => p - case _: Trees.Apply[_] => p - case _ => null + case _: (Trees.Typed[_] | Trees.Ident[_] | Trees.Apply[_] | Trees.Bind[_]) => + p + case _ => + null } case _ => null } diff --git a/tests/neg/pdgadt-nested-pat-alias.scala b/tests/neg/pdgadt-nested-pat-alias.scala new file mode 100644 index 000000000000..3c33a367f40b --- /dev/null +++ b/tests/neg/pdgadt-nested-pat-alias.scala @@ -0,0 +1,12 @@ +trait Expr { type T } +case class B(b: Int) extends Expr { type T = Int } +case class C(c: Boolean) extends Expr { type T = Boolean } +case class A(a: Expr, b: Expr) extends Expr { type T = (a.T, b.T) } + +def foo(e: Expr) = e match + case e1 @ A(b1 @ B(b), c1 @ C(c)) => + val t1: e1.type = e + val t2: e.type = e1 + + val t3: b1.type = e // error + val t4: c1.type = e // error From 7d5046ff97d807e7d5e1d28e44250c2b6fbc5fb7 Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Thu, 22 Sep 2022 00:47:31 +0200 Subject: [PATCH 55/56] add tons of documentation --- .../tools/dotc/core/GadtConstraint.scala | 21 ++++++++++++-- .../dotc/core/PatternTypeConstrainer.scala | 28 ++++++++++--------- .../src/dotty/tools/dotc/typer/Typer.scala | 18 ++++++++++++ 3 files changed, 51 insertions(+), 16 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 34920dfb1861..f6fdd5eff28b 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -25,6 +25,8 @@ sealed abstract class GadtConstraint extends Showable { /** Immediate bounds of a path-dependent type. * This variant of bounds will ONLY try to retrieve path-dependent GADT bounds. */ def bounds(path: PathType, sym: Symbol)(using Context): TypeBounds | Null + + /** Immediate bounds of path-dependent type tp. */ def bounds(tp: TypeRef)(using Context): TypeBounds | Null /** Full bounds of `sym`, including TypeRefs to other lower/upper symbols. @@ -34,7 +36,10 @@ sealed abstract class GadtConstraint extends Showable { */ def fullBounds(sym: Symbol)(using Context): TypeBounds | Null + /** Full bounds of path dependent type path.sym. */ def fullBounds(path: PathType, sym: Symbol)(using Context): TypeBounds | Null + + /** Full bounds of path-dependent type tp. */ def fullBounds(tp: TypeRef)(using Context): TypeBounds | Null /** Is `sym1` ordered to be less than `sym2`? */ @@ -65,7 +70,10 @@ sealed abstract class GadtConstraint extends Showable { /** Check whether two paths are equivalent via path aliasing. */ def isAliasingPath(p: PathType, q: PathType): Boolean - /** Scrutinee path of the current pattern matching. */ + /** Scrutinee path of the current pattern matching that is being typed. + * + * See `constrainTypeMembers` in `PatternTypeConstrainer`. + */ def scrutineePath: TermRef | Null /** Reset scrutinee path to null. */ @@ -74,10 +82,16 @@ sealed abstract class GadtConstraint extends Showable { /** Set the scrutinee path. */ def withScrutineePath[T](path: TermRef | Null)(op: => T): T - /** Supply the real pattern path. */ + /** Supply the real pattern path. + * + * See `constrainTypeMembers` in `PatternTypeConstrainer`. + */ def supplyPatternPath(path: TermRef)(using Context): Unit - /** Create a skolem type for pattern. */ + /** Create a skolem type for pattern and save it in the constraint handler. + * + * See `constrainTypeMembers` in `PatternTypeConstrainer`. + */ def createPatternSkolem(pat: Type): SkolemType /** Is the symbol registered in the constraint? @@ -95,6 +109,7 @@ sealed abstract class GadtConstraint extends Showable { /** Checks whether a given path-dependent type is constrainable. */ def isConstrainablePDT(path: PathType, sym: Symbol)(using Context): Boolean + /** Get all type members registered in the constraint handler for this path. */ def registeredTypeMembers(path: PathType): List[Symbol] /** GADT constraint narrows bounds of at least one variable */ diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index c9ea9b1e075b..9dba2c81fac0 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -201,6 +201,7 @@ trait PatternTypeConstrainer { self: TypeComparer => * * To inference SR constraints for the type members from the scrutinee `p` and the pattern `q`, * we first find all the abstract type members of `p`: A₁, A₂, ⋯, Aₖ. + * If these path-dependent types are not registered in the handler, we will register them. * * Then, for each Aᵢ, if `q` also has a type member labaled `Aᵢ`, we inference SR constraints by calling * TypeComparer on the relation `p.Aᵢ <:< q.Aᵢ`. @@ -209,12 +210,12 @@ trait PatternTypeConstrainer { self: TypeComparer => * Specially, if for some `Aᵢ`, `p.Aᵢ` is abstract while `q.Aᵢ` is not, we will extract constraints * for both directions of the subtype relations (i.e. both `p.Aᵢ <:< q.Aᵢ` and `q.Aᵢ <:< p.Aᵢ`). * - * How we find out and handle the path (`TermRef`) of the scrutinee and pattern. + * How we find out and handle the path (`TermRef`) of the scrutinee and pattern: * - * - The path of scrutinee is not directly available in `constrainPatternType`, since the scrutinee type - * passed to this function is widened. - * To make the path available during GADT reasoning, we save the scrutinee path in `Typer.typedCase`. The scrutinee path will be saved in `ctx.gadt.scrutineePath`. - * Note that we have to clear the saved scrutinee path after using by calling `ctx.gadt.resetScrutineePath()`. + * - The path of scrutinee is not directly available in `constrainPatternType`, since the scrutinee type passed to this function is widened. + * To have access to the scrutinee path here, we save the scrutinee path in `Typer.typedCase` with `GadtConstraint.withScrutineePath`, + * and the scrutinee path will be accessible as `ctx.gadt.scrutineePath`. + * Note that we have to reset the saved scrutinee path to `null` after using by calling `ctx.gadt.resetScrutineePath()`. * This is because `constrainPatternType` may be called multiple times for one nested pattern. For example: * * e match @@ -230,7 +231,7 @@ trait PatternTypeConstrainer { self: TypeComparer => import NameKinds.DepParamName val realScrutineePath = ctx.gadt.scrutineePath - /* We reset scrutinee path so that the path will only be used at top level. */ + // We reset scrutinee path so that the path will only be used at top level. ctx.gadt.resetScrutineePath() val saved = state.nn.constraint @@ -242,8 +243,9 @@ trait PatternTypeConstrainer { self: TypeComparer => val patternPath: SkolemType = ctx.gadt.createPatternSkolem(pat) val registerScrutinee = ctx.gadt.contains(scrutineePath) || ctx.gadt.addToConstraint(scrutineePath) - val registerPattern = ctx.gadt.addToConstraint(patternPath) // Pattern path is a freshly-created skolem, - // so it will always be un-registered at this point + // Pattern path is a freshly-created skolem, + // so it will always be un-registered at this point + val registerPattern = ctx.gadt.addToConstraint(patternPath) /** Reconstruct subtype constraints for a path `p`, given that `p` and `q` are cohabitated. @@ -260,7 +262,7 @@ trait PatternTypeConstrainer { self: TypeComparer => (3) q.T is unregistered. We will do SR on p.T <:< q.T and q.T <:< p.T. */ - def reconstructSubTypeFor(p: PathType, q: PathType) = + def reconstructSubType(p: PathType, q: PathType) = def processMember(sym: Symbol): Boolean = q.member(sym.name).isInstanceOf[NoDenotation.type] || { val pType = TypeRef(p, sym) @@ -281,8 +283,8 @@ trait PatternTypeConstrainer { self: TypeComparer => def constrainPattern: Boolean = { ctx.gadt.recordPathAliasing(scrutineePath, patternPath) - (!registerPattern || reconstructSubTypeFor(patternPath, scrutineePath)) - && (!registerScrutinee || reconstructSubTypeFor(scrutineePath, patternPath)) + (!registerPattern || reconstructSubType(patternPath, scrutineePath)) + && (!registerScrutinee || reconstructSubType(scrutineePath, patternPath)) } /** Reconstruct subtype when the pattern is an alias to another path. @@ -299,8 +301,8 @@ trait PatternTypeConstrainer { self: TypeComparer => val registerPtPath = ctx.gadt.contains(ptPath) || ctx.gadt.addToConstraint(ptPath) val result = - (!registerPtPath || reconstructSubTypeFor(ptPath, scrutineePath)) - && (!registerScrutinee || reconstructSubTypeFor(scrutineePath, ptPath)) + (!registerPtPath || reconstructSubType(ptPath, scrutineePath)) + && (!registerScrutinee || reconstructSubType(scrutineePath, ptPath)) ctx.gadt.recordPathAliasing(scrutineePath, ptPath) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 65e7a370446e..65bc798865ec 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1755,6 +1755,19 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case p: TermRef => tree.pat match { case _: (Trees.Typed[_] | Trees.Ident[_] | Trees.Apply[_] | Trees.Bind[_]) => + // We only record scrutinee path in the above cases, b/c recording + // it in all cases may lead to unsoundness. + // + // For example: + // + // def foo(e: (Expr, Expr)) = e match + // case (e1: Expr, e2: Expr) => + // + // Here the pattern is a tuple. `constrainPatternType` will be called + // on the two elements of the tuple directly, without constraining + // `e` and the whole tuple first. + // Therefore, recording the scrutinee path in this case can give + // us constraints like `e1.type == e.type`, which is not true. p case _ => null @@ -1762,11 +1775,16 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case _ => null } + // Save the scrutinee path and then type the pattern. + // The scrutinee path will be used in SR reasoning for path-dependent types. + // See `constrainTypeMembers` in `PatternTypeConstrainer`. val pat1 = gadtCtx.gadt.withScrutineePath(scrutineePath) { typedPattern(tree.pat, wideSelType)(using gadtCtx) } if scrutineePath.ne(null) && pat1.symbol.isPatternBound then + // Subtitute the place holder with real pattern path in GADT constraints. + // See `constrainTypeMembers` in `PatternTypeConstrainer`. gadtCtx.gadt.supplyPatternPath(pat1.symbol.termRef) caseRest(pat1)( From 7adfd73c474d76ea28271243cc317a00a1b78a4e Mon Sep 17 00:00:00 2001 From: Yichen Xu Date: Thu, 22 Sep 2022 01:26:34 +0200 Subject: [PATCH 56/56] add one testcase adapted from #15958 --- tests/neg/i15958.scala | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 tests/neg/i15958.scala diff --git a/tests/neg/i15958.scala b/tests/neg/i15958.scala new file mode 100644 index 000000000000..b5a52b01e618 --- /dev/null +++ b/tests/neg/i15958.scala @@ -0,0 +1,29 @@ +sealed trait NatT { type This <: NatT } +case class Zero() extends NatT { + type This = Zero +} +case class Succ[N <: NatT](n: N) extends NatT { + type This = Succ[n.This] +} + +trait IsLessThan[+M <: NatT, N <: NatT] +object IsLessThan: + given base[M <: NatT]: IsLessThan[M, Succ[M]]() + given weakening[N <: NatT, M <: NatT] (using IsLessThan[N, M]): IsLessThan[N, Succ[M]]() + given reduction[N <: NatT, M <: NatT] (using IsLessThan[Succ[N], Succ[M]]): IsLessThan[N, M]() + +sealed trait UniformTuple[Length <: NatT, T]: + def apply[M <: NatT](m: M)(using IsLessThan[m.This, Length]): T + +case class Empty[T]() extends UniformTuple[Zero, T]: + def apply[M <: NatT](m: M)(using IsLessThan[m.This, Zero]): T = throw new AssertionError("Uncallable") + +case class Cons[N <: NatT, T](head: T, tail: UniformTuple[N, T]) extends UniformTuple[Succ[N], T]: + def apply[M <: NatT](m: M)(using proof: IsLessThan[m.This, Succ[N]]): T = m match + case Zero() => head + case m1: Succ[predM] => + val proof1: IsLessThan[m1.This, Succ[N]] = proof + + val res0 = tail(m1.n)(using IsLessThan.reduction(using proof)) // error // limitation + val res1 = tail(m1.n)(using IsLessThan.reduction(using proof1)) + res1