Skip to content

Adding support of path-dependent GADT reasoning #14754

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 56 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
9104c23
save path for scrutinee
Linyxus Mar 16, 2022
974e67a
update
Linyxus Mar 16, 2022
7e655d0
checkpoint: before actually narrowing GADT bounds
Linyxus Aug 9, 2022
3abf4d6
path-dependent GADT reasoning for type members
Linyxus Mar 18, 2022
ff0ac70
add GADT reasoning for path-dependent types
Linyxus Mar 20, 2022
0c63ba0
drop typevars in type member bounds
Linyxus Mar 20, 2022
f8d410c
support path-dependent GADT for HKTs
Linyxus Mar 20, 2022
646e2ab
cleanup tracing
Linyxus Mar 23, 2022
34661c4
filtering out NoSymbol to avoid unsound bounds
Linyxus Mar 23, 2022
69bea1c
return true if the type members can not be constrained
Linyxus Mar 23, 2022
0fe9181
fix typo
Linyxus Mar 23, 2022
d3eade0
remove limitation errors in structural gadt neg tests
Linyxus Mar 23, 2022
c3795fe
fix code to pass explicit null check
Linyxus Mar 23, 2022
284ccf1
format and cleanup
Linyxus Jul 15, 2022
991c4cc
remove limitation errors in testcases
Linyxus Jul 15, 2022
9862eb2
record pattern path
Linyxus Jul 15, 2022
71dbe8c
changing externalize to protected
Linyxus Jul 20, 2022
1177cf6
remove workaround for type variable
Linyxus Jul 22, 2022
198287e
only constrain non-private type members
Linyxus Jul 22, 2022
4eac9a4
avoid stripping type variables in OrderingConstraint.replace
Linyxus Jul 23, 2022
8972fe9
more path-dependent GADT examples
Linyxus Jul 23, 2022
ab7bb6c
remove test since it fails -Ycheck
Linyxus Jul 24, 2022
5c57bad
also rollback path-dependent GADT constraints on failure
Linyxus Aug 9, 2022
d092969
Revert "avoid stripping type variables in OrderingConstraint.replace"
Linyxus Aug 9, 2022
c01ff7f
not registering aliasing type members
Linyxus Aug 10, 2022
05ea709
update bounds desc
Linyxus Aug 10, 2022
7b99da4
use isSubType to do subtype reconstruction
Linyxus Aug 10, 2022
bd50e79
avoid false constraints in type inference when GADT mode is on
Linyxus Aug 10, 2022
a406519
fix constrained type lookup in GadtConstraint
Linyxus Aug 10, 2022
0bd1a90
add tracing for addConstraint in TypeComparer
Linyxus Aug 10, 2022
e8d80b6
refactor pathdep GADT constraining logic
Linyxus Aug 10, 2022
fed4d48
update pdgadt-wildcard test
Linyxus Aug 10, 2022
4bab5f0
Revert "add tracing for addConstraint in TypeComparer"
Linyxus Aug 10, 2022
2427004
cleanup and tweak tracing
Linyxus Aug 10, 2022
d9dd8ce
make singleton equality constriants actually working
Linyxus Aug 10, 2022
b50800e
add examples for singleton equality constraints
Linyxus Aug 10, 2022
a26e3e6
add GADT usage info when singleton equality constraints are used
Linyxus Aug 10, 2022
7e5618e
support subtype reconstruction when pattern is an alias to another path
Linyxus Aug 10, 2022
cab8dcb
add test for constraining aliasing pattern
Linyxus Aug 10, 2022
130ee6c
cleanup comments
Linyxus Aug 10, 2022
7136b40
improve the documentation of `GadtConstraint.subsumes`
Linyxus Aug 14, 2022
e2d5f6f
more documentation for GadtConstraint class
Linyxus Aug 14, 2022
8371bad
Minor
Linyxus Aug 14, 2022
a1d0e49
also try to register path-dependent types in the bounds
Linyxus Aug 14, 2022
ac4f6e7
improve tracing in GadtConstraint
Linyxus Aug 14, 2022
6843cb1
substitute instantiated dependent params
Linyxus Aug 15, 2022
3edd4a0
add pdgadt-sub pos test
Linyxus Aug 15, 2022
6908bfc
refactor subsumes
Linyxus Aug 15, 2022
a6f3a3a
fix type signature to pass null check
Linyxus Aug 15, 2022
25aa162
refactor path aliasing constraints
Linyxus Sep 14, 2022
ed40359
fix typeMemberTouched and add a testcase
Linyxus Sep 21, 2022
5118628
refactor constrainPatternType
Linyxus Sep 21, 2022
e043fdd
documenting SR for path-dependent types
Linyxus Sep 21, 2022
898c47d
support GADT reasoning on aliases for nested pattern
Linyxus Sep 21, 2022
7d5046f
add tons of documentation
Linyxus Sep 21, 2022
7adfd73
add one testcase adapted from #15958
Linyxus Sep 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
587 changes: 561 additions & 26 deletions compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Large diffs are not rendered by default.

341 changes: 237 additions & 104 deletions compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala

Large diffs are not rendered by default.

195 changes: 161 additions & 34 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,22 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
private def isBottom(tp: Type) = tp.widen.isRef(NothingClass)

protected def gadtBounds(sym: Symbol)(using Context) = ctx.gadt.bounds(sym)
protected def gadtAddBound(sym: Symbol, b: Type, isUpper: Boolean): Boolean = ctx.gadt.addBound(sym, b, isUpper)

/** This variant of gadtBounds works for all named types.
* It queries GADT bounds for both type parameters and path-dependent types.
*/
protected def gadtBounds(tp: NamedType)(using Context) =
ctx.gadt.bounds(tp.symbol) match
case null =>
tp match
case TypeRef(p: PathType, _) => ctx.gadt.bounds(p, tp.symbol)
case _ => null
case tb => tb

protected def gadtAddBound(sym: Symbol, b: Type, isUpper: Boolean): Boolean = ctx.gadt.addBound(sym, b, isUpper = isUpper)

protected def gadtAddLowerBound(path: PathType, sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(path, sym, b, isUpper = false)
protected def gadtAddUpperBound(path: PathType, sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(path, sym, b, isUpper = true)

protected def typeVarInstance(tvar: TypeVar)(using Context): Type = tvar.underlying

Expand Down Expand Up @@ -193,6 +208,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
val bounds = gadtBounds(sym)
bounds != null && op(bounds)

extension (tp: NamedType)
private inline def onGadtBounds(inline op: TypeBounds => Boolean): Boolean =
val bounds = gadtBounds(tp)
bounds != null && op(bounds)

private inline def comparingTypeLambdas(tl1: TypeLambda, tl2: TypeLambda)(op: => Boolean): Boolean =
val saved = comparedTypeLambdas
comparedTypeLambdas += tl1
Expand Down Expand Up @@ -426,7 +446,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
case _ => false
} ||
isSubTypeWhenFrozen(bounds(tp1).hi.boxed, tp2) || {
if (canConstrain(tp1) && !approx.high)
if canConstrain(tp1) && isPreciseBound(fromBelow = false) then
addConstraint(tp1, tp2, fromBelow = false) && flagNothingBound
else thirdTry
}
Expand Down Expand Up @@ -540,19 +560,28 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling

def thirdTryNamed(tp2: NamedType): Boolean = tp2.info match {
case info2: TypeBounds =>
/** Checks whether tp1 is registered.
* Both type parameters and path-dependent types are considered.
*/
def tpRegistered(tp: TypeRef) = ctx.gadt.contains(tp.symbol) || {
tp match
case tp @ TypeRef(p: PathType, _) => ctx.gadt.contains(p, tp.symbol)
case _ => false
}

def compareGADT: Boolean =
tp2.symbol.onGadtBounds(gbounds2 =>
isSubTypeWhenFrozen(tp1, gbounds2.lo)
|| tp1.match
case tp1: NamedType if ctx.gadt.contains(tp1.symbol) =>
// Note: since we approximate constrained types only with their non-param bounds,
// we need to manually handle the case when we're comparing two constrained types,
// one of which is constrained to be a subtype of another.
// We do not need similar code in fourthTry, since we only need to care about
// comparing two constrained types, and that case will be handled here first.
ctx.gadt.isLess(tp1.symbol, tp2.symbol) && GADTusage(tp1.symbol) && GADTusage(tp2.symbol)
case _ => false
|| narrowGADTBounds(tp2, tp1, approx, isUpper = false))
{ tp2.onGadtBounds(gbounds2 =>
{ isSubTypeWhenFrozen(tp1, gbounds2.lo) }
|| tp1.match
case tp1: TypeRef if tpRegistered(tp1) =>
// Note: since we approximate constrained types only with their non-param bounds,
// we need to manually handle the case when we're comparing two constrained types,
// one of which is constrained to be a subtype of another.
// We do not need similar code in fourthTry, since we only need to care about
// comparing two constrained types, and that case will be handled here first.
{ ctx.gadt.isLess(tp1, tp2) } && GADTusage(tp1.symbol) && GADTusage(tp2.symbol)
case _ => false)
|| narrowGADTBounds(tp2, tp1, approx, isUpper = false) }
&& (isBottom(tp1) || GADTusage(tp2.symbol))

isSubApproxHi(tp1, info2.lo.boxedIfTypeParam(tp2.symbol)) && (trustBounds || isSubApproxHi(tp1, info2.hi))
Expand All @@ -561,6 +590,13 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
|| fourthTry

case _ =>
def compareSingletonGADT: Boolean =
(tp1, tp2) match {
case (tp1: TermRef, tp2: TermRef) =>
ctx.gadt.isAliasingPath(tp1, tp2) && { GADTused = true; true }
case _ => false
}

val cls2 = tp2.symbol
if (cls2.isClass)
if (cls2.typeParams.isEmpty) {
Expand All @@ -581,9 +617,18 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
}
else if tp1.isLambdaSub && !tp1.isAnyKind then
return recur(tp1, EtaExpansion(tp2))

if compareSingletonGADT then return true

fourthTry
}

def isPreciseBound(fromBelow: Boolean): Boolean =
if ctx.mode.is(Mode.GadtConstraintInference) then
!(approx.low || approx.high)
else
if fromBelow then !approx.low else !approx.high

def compareTypeParamRef(tp2: TypeParamRef): Boolean =
assumedTrue(tp2) || {
val alwaysTrue =
Expand All @@ -597,7 +642,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
if (frozenConstraint) recur(tp1, bounds(tp2).lo.boxed)
else isSubTypeWhenFrozen(tp1, tp2)
alwaysTrue || {
if (canConstrain(tp2) && !approx.low)
if canConstrain(tp2) && isPreciseBound(fromBelow = true) then
addConstraint(tp2, tp1.widenExpr, fromBelow = true)
else fourthTry
}
Expand Down Expand Up @@ -858,10 +903,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
tp1.info match {
case info1 @ TypeBounds(lo1, hi1) =>
def compareGADT =
tp1.symbol.onGadtBounds(gbounds1 =>
isSubTypeWhenFrozen(gbounds1.hi, tp2)
|| narrowGADTBounds(tp1, tp2, approx, isUpper = true))
&& (tp2.isAny || GADTusage(tp1.symbol))
{ tp1.onGadtBounds(gbounds1 =>
isSubTypeWhenFrozen(gbounds1.hi, tp2))
|| narrowGADTBounds(tp1, tp2, approx, isUpper = true)
} && (tp2.isAny || GADTusage(tp1.symbol))

(!caseLambda.exists || canWidenAbstract)
&& isSubType(hi1.boxedIfTypeParam(tp1.symbol), tp2, approx.addLow) && (trustBounds || isSubType(lo1, tp2, approx.addLow))
Expand Down Expand Up @@ -1179,22 +1224,30 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
var touchedGADTs = false
var gadtIsInstantiated = false

extension (sym: Symbol)
extension (tp: TypeRef)
inline def byGadtBounds(inline op: TypeBounds => Boolean): Boolean =
touchedGADTs = true
sym.onGadtBounds(
tp.onGadtBounds(
b => op(b) && { gadtIsInstantiated = b.isInstanceOf[TypeAlias]; true })

def byGadtOrdering: Boolean =
ctx.gadt.contains(tycon1sym)
&& ctx.gadt.contains(tycon2sym)
&& ctx.gadt.isLess(tycon1sym, tycon2sym)

def byPathDepGadtOrdering: Boolean =
(tycon1, tycon2) match
case (TypeRef(p1: PathType, _), TypeRef(p2: PathType, _)) =>
ctx.gadt.contains(p1, tycon1sym)
&& ctx.gadt.contains(p2, tycon2sym)
&& ctx.gadt.isLess(tycon1, tycon2)
case _ => false

val res = (
tycon1sym == tycon2sym && isSubPrefix(tycon1.prefix, tycon2.prefix)
|| tycon1sym.byGadtBounds(b => isSubTypeWhenFrozen(b.hi, tycon2))
|| tycon2sym.byGadtBounds(b => isSubTypeWhenFrozen(tycon1, b.lo))
|| byGadtOrdering
|| tycon1.byGadtBounds(b => isSubTypeWhenFrozen(b.hi, tycon2))
|| tycon2.byGadtBounds(b => isSubTypeWhenFrozen(tycon1, b.lo))
|| byGadtOrdering || byPathDepGadtOrdering
) && {
// There are two cases in which we can assume injectivity.
// First we check if either sym is a class.
Expand Down Expand Up @@ -1843,10 +1896,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
ctx.gadt.restore(preGadt)
if op2 then
if allSubsumes(op1Gadt, ctx.gadt, op1Constraint, constraint) then
gadts.println(i"GADT CUT - prefer ${ctx.gadt} over $op1Gadt")
gadts.println(i"GADT CUT - prefer op2 ${ctx.gadt} over $op1Gadt")
constr.println(i"CUT - prefer $constraint over $op1Constraint")
else if allSubsumes(ctx.gadt, op1Gadt, constraint, op1Constraint) then
gadts.println(i"GADT CUT - prefer $op1Gadt over ${ctx.gadt}")
gadts.println(i"GADT CUT - prefer op1 $op1Gadt over ${ctx.gadt}")
constr.println(i"CUT - prefer $op1Constraint over $constraint")
constraint = op1Constraint
ctx.gadt.restore(op1Gadt)
Expand Down Expand Up @@ -2017,21 +2070,80 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
case _ => proto.isMatchedBy(tp, keepConstraint = true)
}

private def rollbackGadtUnless(op: => Boolean): Boolean =
val savedGadt = ctx.gadt.fresh
val result = op
if !result then ctx.gadt.restore(savedGadt)
result
end rollbackGadtUnless

/** Narrow gadt.bounds for the type parameter referenced by `tr` to include
* `bound` as an upper or lower bound (which depends on `isUpper`).
* Test that the resulting bounds are still satisfiable.
*/
private def narrowGADTBounds(tr: NamedType, bound: Type, approx: ApproxState, isUpper: Boolean): Boolean = {
val boundImprecise = approx.high || approx.low
ctx.mode.is(Mode.GadtConstraintInference) && !frozenGadt && !frozenConstraint && !boundImprecise && {
val tparam = tr.symbol
gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}")
if (bound.isRef(tparam)) false
else
val savedGadt = ctx.gadt.fresh
val success = gadtAddBound(tparam, bound, isUpper)
if !success then ctx.gadt.restore(savedGadt)
success
def tryRegisterBound: Boolean = bound.match {
case tr @ TypeRef(path: PathType, _) =>
val sym = tr.symbol

def register =
ctx.gadt.contains(path, sym) || ctx.gadt.contains(sym) || {
ctx.gadt.isConstrainablePDT(path, tr.symbol) && {
gadts.println(i"!!! registering path on the fly path=$path sym=$sym")
ctx.gadt.addToConstraint(path) && ctx.gadt.contains(path, sym)
}
}

val result = register

true
case _ => true
}

def narrowTypeParams = ctx.gadt.contains(tr.symbol) && {
val tparam = tr.symbol
gadts.println(i"narrow gadt bound of tparam $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}")
if (bound.isRef(tparam)) false
else
rollbackGadtUnless {
if isUpper then
gadtAddBound(tparam, bound, isUpper = true)
else
gadtAddBound(tparam, bound, isUpper = false)
}
}

def narrowPathDepType = tr match
case TypeRef(path: PathType, _) =>
val sym = tr.symbol

def isConstrainable: Boolean =
ctx.gadt.contains(path, sym) || {
ctx.gadt.isConstrainablePDT(path, tr.symbol) && {
gadts.println(i"!!! registering path on the fly path=$path sym=$sym")
ctx.gadt.addToConstraint(path) && ctx.gadt.contains(path, sym)
}
}

def isRef: Boolean = bound match
case TypeRef(q: PathType, _) => (path eq q) && bound.isRef(sym)
case _ => false

rollbackGadtUnless {
isConstrainable && {
gadts.println(i"narrow gadt bound of pdt $path -> ${sym}: from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${isRef}")

if isRef then false
else if isUpper then gadtAddUpperBound(path, sym, bound)
else gadtAddLowerBound(path, sym, bound)
}
}

case _ => false

tryRegisterBound && narrowTypeParams || narrowPathDepType
}
}

Expand Down Expand Up @@ -3047,6 +3159,21 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
if (sym.exists) footprint += sym.typeRef
super.gadtAddBound(sym, b, isUpper)

override def gadtBounds(tp: NamedType)(using Context): TypeBounds | Null = {
if (tp.symbol.exists) footprint += tp
super.gadtBounds(tp)
}

override def gadtAddLowerBound(path: PathType, sym: Symbol, b: Type): Boolean = {
if (sym.exists) footprint += TypeRef(path, sym)
super.gadtAddLowerBound(path, sym, b)
}

override def gadtAddUpperBound(path: PathType, sym: Symbol, b: Type): Boolean = {
if (sym.exists) footprint += TypeRef(path, sym)
super.gadtAddUpperBound(path, sym, b)
}

override def typeVarInstance(tvar: TypeVar)(using Context): Type = {
footprint += tvar
super.typeVarInstance(tvar)
Expand Down
38 changes: 37 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1750,7 +1750,43 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
assignType(cpy.CaseDef(tree)(pat1, guard1, body1), pat1, body1)
}

val pat1 = typedPattern(tree.pat, wideSelType)(using gadtCtx)
val scrutineePath =
sel.tpe match {
case p: TermRef =>
tree.pat match {
case _: (Trees.Typed[_] | Trees.Ident[_] | Trees.Apply[_] | Trees.Bind[_]) =>
// We only record scrutinee path in the above cases, b/c recording
// it in all cases may lead to unsoundness.
//
// For example:
//
// def foo(e: (Expr, Expr)) = e match
// case (e1: Expr, e2: Expr) =>
//
// Here the pattern is a tuple. `constrainPatternType` will be called
// on the two elements of the tuple directly, without constraining
// `e` and the whole tuple first.
// Therefore, recording the scrutinee path in this case can give
// us constraints like `e1.type == e.type`, which is not true.
p
case _ =>
null
}
case _ => null
}

// Save the scrutinee path and then type the pattern.
// The scrutinee path will be used in SR reasoning for path-dependent types.
// See `constrainTypeMembers` in `PatternTypeConstrainer`.
val pat1 = gadtCtx.gadt.withScrutineePath(scrutineePath) {
typedPattern(tree.pat, wideSelType)(using gadtCtx)
}

if scrutineePath.ne(null) && pat1.symbol.isPatternBound then
// Subtitute the place holder with real pattern path in GADT constraints.
// See `constrainTypeMembers` in `PatternTypeConstrainer`.
gadtCtx.gadt.supplyPatternPath(pat1.symbol.termRef)

caseRest(pat1)(
using Nullables.caseContext(sel, pat1)(
using gadtCtx))
Expand Down
29 changes: 29 additions & 0 deletions tests/neg/i15958.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
sealed trait NatT { type This <: NatT }
case class Zero() extends NatT {
type This = Zero
}
case class Succ[N <: NatT](n: N) extends NatT {
type This = Succ[n.This]
}

trait IsLessThan[+M <: NatT, N <: NatT]
object IsLessThan:
given base[M <: NatT]: IsLessThan[M, Succ[M]]()
given weakening[N <: NatT, M <: NatT] (using IsLessThan[N, M]): IsLessThan[N, Succ[M]]()
given reduction[N <: NatT, M <: NatT] (using IsLessThan[Succ[N], Succ[M]]): IsLessThan[N, M]()

sealed trait UniformTuple[Length <: NatT, T]:
def apply[M <: NatT](m: M)(using IsLessThan[m.This, Length]): T

case class Empty[T]() extends UniformTuple[Zero, T]:
def apply[M <: NatT](m: M)(using IsLessThan[m.This, Zero]): T = throw new AssertionError("Uncallable")

case class Cons[N <: NatT, T](head: T, tail: UniformTuple[N, T]) extends UniformTuple[Succ[N], T]:
def apply[M <: NatT](m: M)(using proof: IsLessThan[m.This, Succ[N]]): T = m match
case Zero() => head
case m1: Succ[predM] =>
val proof1: IsLessThan[m1.This, Succ[N]] = proof

val res0 = tail(m1.n)(using IsLessThan.reduction(using proof)) // error // limitation
val res1 = tail(m1.n)(using IsLessThan.reduction(using proof1))
res1
17 changes: 17 additions & 0 deletions tests/neg/pdgadt-either.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
trait T1
trait T2 extends T1
// T2 <:< T1

trait Expr[+X]
case class Tag1() extends Expr[T1]
case class Tag2() extends Expr[T2]

trait TypeTag { type A }

def foo(p: TypeTag, e: Expr[p.A]) = e match
case _: (Tag2 | Tag1) =>
// Tag1: (T2 <:) T1 <: p.A
// Tag2: T2 <: p.A
// should choose T2 <: p.A
val t1: p.A = ??? : T1 // error
val t2: p.A = ??? : T2
Loading