Skip to content

Commit 9084be9

Browse files
committed
Freeze GADTs more when comparing type member infos
1 parent 63344e7 commit 9084be9

File tree

12 files changed

+217
-5
lines changed

12 files changed

+217
-5
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1923,7 +1923,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
19231923
|| symInfo.isInstanceOf[MethodType]
19241924
&& symInfo.signature.consistentParams(info2.signature)
19251925

1926-
def tp1IsSingleton: Boolean = tp1.isInstanceOf[SingletonType]
1926+
def allowGadt: Boolean =
1927+
def rec(tp: Type): Boolean = tp match
1928+
case RefinedType(parent, name1, _) => name == name1 || rec(parent)
1929+
case tp: TypeRef => tp.symbol.isClass
1930+
case _ => false
1931+
!approx.low && rec(tp1)
19271932

19281933
// A relaxed version of isSubType, which compares method types
19291934
// under the standard arrow rule which is contravarient in the parameter types,
@@ -1939,8 +1944,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
19391944
matchingMethodParams(info1, info2, precise = false)
19401945
&& isSubInfo(info1.resultType, info2.resultType.subst(info2, info1), symInfo1.resultType)
19411946
&& sigsOK(symInfo1, info2)
1942-
case _ => inFrozenGadtIf(tp1IsSingleton) { isSubType(info1, info2) }
1943-
case _ => inFrozenGadtIf(tp1IsSingleton) { isSubType(info1, info2) }
1947+
case _ => inFrozenGadtIf(!allowGadt) { isSubType(info1, info2) }
1948+
case _ => inFrozenGadtIf(!allowGadt) { isSubType(info1, info2) }
19441949

19451950
def qualifies(m: SingleDenotation): Boolean =
19461951
val info1 = m.info.widenExpr

compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,14 +186,14 @@ object QuoteMatcher {
186186
if patternHole.symbol.eq(defn.QuotedRuntimePatterns_patternHole) &&
187187
tpt2.tpe.derivesFrom(defn.RepeatedParamClass) =>
188188
scrutinee match
189-
case Typed(s, tpt1) if s.tpe <:< tpt.tpe => matched(scrutinee)
189+
case Typed(s, tpt1) if patSub(s.tpe, tpt.tpe) => matched(scrutinee)
190190
case _ => notMatched
191191

192192
/* Term hole */
193193
// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
194194
case TypeApply(patternHole, tpt :: Nil)
195195
if patternHole.symbol.eq(defn.QuotedRuntimePatterns_patternHole) &&
196-
scrutinee.tpe <:< tpt.tpe =>
196+
patSub(scrutinee.tpe, tpt.tpe) =>
197197
scrutinee match
198198
case ClosedPatternTerm(scrutinee) => matched(scrutinee)
199199
case _ => notMatched
@@ -480,4 +480,78 @@ object QuoteMatcher {
480480

481481
}
482482

483+
def patSub(scr: Type, pat: Type)(using Context): Boolean =
484+
val scrCls = scr.classSymbol
485+
val patCls = pat.classSymbol
486+
val upcPat = patCls.derivesFrom(scrCls)
487+
val upcScr = scrCls.derivesFrom(patCls)
488+
val tp = if upcScr then scr.refinedBaseType(patCls) else scr
489+
val pt = if upcPat then pat.refinedBaseType(scrCls) else pat
490+
tp <:< pt
491+
492+
import dotty.tools.dotc.*, core.*, cc.*, reporting.*, Decorators.*, SymDenotations.*
493+
extension (tp: Type) def refinedBaseType(base: Symbol)(using Context): Type = base.denot match
494+
case classd: ClassDenotation => classd.refinedBaseTypeOf(tp)
495+
case _ => NoType
496+
497+
extension (classd: ClassDenotation) def refinedBaseTypeOf(tp: Type)(using Context): Type =
498+
val symbol = classd.symbol
499+
def foldGlb(bt: Type, ps: List[Type]): Type = ps match
500+
case p :: ps1 => foldGlb(bt & recur(p), ps1)
501+
case _ => bt
502+
def recur(tp: Type): Type = trace(i"($tp).rbt($symbol)", show = true) {
503+
val normed = tp.tryNormalize
504+
if normed.exists then recur(normed) else tp match
505+
case tp @ TypeRef(prefix, _) =>
506+
val tpSym = tp.symbol
507+
tpSym.denot match
508+
case clsd: ClassDenotation =>
509+
def isOwnThis = prefix match
510+
case prefix: ThisType => prefix.cls eq clsd.owner
511+
case NoPrefix => true
512+
case _ => false
513+
if tpSym eq symbol then tp
514+
else if isOwnThis then
515+
if clsd.derivesFrom(symbol) then
516+
val base =
517+
if symbol.isStatic && symbol.typeParams.isEmpty then symbol.typeRef
518+
else foldGlb(NoType, clsd.info.parents)
519+
// change 1
520+
if base.exists then
521+
val custom = clsd.info.decls.filter(_.name.isTypeName)
522+
custom.foldRight(base)((sym, base) => RefinedType(base, sym.name, sym.info))
523+
else NoType
524+
else NoType
525+
else recur(clsd.typeRef).asSeenFrom(prefix, clsd.owner)
526+
case _ => recur(tp.superTypeNormalized)
527+
case tp @ AppliedType(tycon, args) =>
528+
if tycon.typeSymbol eq symbol then tp
529+
else (tycon.typeParams: @unchecked) match
530+
case LambdaParam(_, _) :: _ => recur(tp.superTypeNormalized)
531+
case tparams: List[Symbol @unchecked] => recur(tycon).substApprox(tparams, args)
532+
case tp: TypeParamRef => recur(TypeComparer.bounds(tp).hi)
533+
case CapturingType(parent, refs) => tp.derivedCapturingType(recur(parent), refs)
534+
case tp @ RefinedType(parent, name, info) =>
535+
// change 2
536+
val parent1 = recur(parent)
537+
if parent1.exists then tp.derivedRefinedType(parent1, name, info)
538+
else NoType
539+
case tp: TypeProxy => recur(tp.superTypeNormalized)
540+
case tp: AndOrType =>
541+
val tp1 = tp.tp1
542+
val tp2 = tp.tp2
543+
if !tp.isAnd && tp1.isBottomType && (tp1 frozen_<:< tp2) then recur(tp2)
544+
else if !tp.isAnd && tp2.isBottomType && (tp2 frozen_<:< tp1) then recur(tp1)
545+
else
546+
val baseTp1 = recur(tp1)
547+
val baseTp2 = recur(tp2)
548+
val combined = if tp.isAnd then baseTp1 & baseTp2 else baseTp1 | baseTp2
549+
combined match
550+
case combined: AndOrType
551+
if (combined.tp1 eq tp1) && (combined.tp2 eq tp2) && (combined.isAnd == tp.isAnd) => tp
552+
case _ => combined
553+
case JavaArrayType(_) if symbol == defn.ObjectClass => classd.typeRef
554+
case _ => NoType
555+
}
556+
recur(tp)
483557
}

tests/neg/i15485.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
enum SUB[-L, +R]:
2+
case Refl[C]() extends SUB[C, C]
3+
4+
trait Tag { type T }
5+
6+
def foo[A, B, X <: Tag { type T <: A } ](
7+
e: SUB[X, Tag { type T <: B }],
8+
x: A,
9+
): B = e match {
10+
case SUB.Refl() =>
11+
// SUB.Refl.unapply[?C](e)
12+
// ?C >: X => cstr: C = X..Any
13+
// ?C <: Tag { T = Nothing..B } => cstr: C = X..Tag { T = Nothing..B }
14+
// SUB[Tag { T = Nothing..Int }, Tag { T = Nothing..String }]
15+
// A = Int
16+
// B = String
17+
// X = Tag { T = Nothing..Nothing }
18+
// X <: Tag { T = Nothing..A }
19+
// SUB[X, Tag { T = Nothing..B }]
20+
// SUB[Tag { T = Nothing..A }, Tag { T = Nothing..B }], approxLHS
21+
// Tag { T = Nothing..A } <: C <: Tag { T = Nothing..B }]
22+
// Tag { T = Nothing..A } <: Tag { T = Nothing..B }
23+
// A <: B
24+
x // error: Found: (x: A) Required: B
25+
}
26+
27+
def bad(x: Int): String =
28+
foo[Int, String, Tag { type T = Nothing }](SUB.Refl(), x) // cast Int to String
29+
30+
object Test:
31+
def main(args: Array[String]): Unit = bad(1) // was: ClassCastException: class java.lang.Integer cannot be cast to class java.lang.String

tests/neg/i15485b.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
enum SUB[-A, +B]:
2+
case Refl[C]() extends SUB[C, C]
3+
4+
trait Tag { type T }
5+
6+
def foo[L, H, X <: Tag { type T >: L <: H }](
7+
e: SUB[X, Tag { type T = Int }],
8+
x: Int,
9+
): L = e match {
10+
case SUB.Refl() =>
11+
// X <: C and C <: Tag { T = Int }
12+
// X <: Tag { T = Int }
13+
// Tag { T >: L <: H } <: Tag { T = Int }
14+
// Int <: L and H <: Int
15+
x // error
16+
}
17+
18+
def bad(x: Int): String =
19+
foo[Nothing, Any, Tag { type T = Int }](SUB.Refl(), x) // cast Int to String!
20+
21+
object Test:
22+
def main(args: Array[String]): Unit = bad(1) // was: ClassCastException: class java.lang.Integer cannot be cast to class java.lang.String

tests/neg/i15485c.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
enum SUB[-A, +B]:
2+
case Refl[C]() extends SUB[C, C]
3+
4+
trait Tag { type T }
5+
6+
def foo[L](g: Tag { type T >: L <: Int })(
7+
e: SUB[g.type, Tag { type T = Int }],
8+
x: Int,
9+
): L = e match {
10+
case SUB.Refl() =>
11+
// L = Nothing
12+
// C = t
13+
// g := Tag { T = Int..Int }
14+
// g <: Tag { T = Nothing..Int }
15+
// SUB[g, Tag { T = Int..Int }]
16+
// SUB[Tag { T = Nothing..Int }, Tag { T = Int..Int }]
17+
// SUB[Tag { T = L..Int }, Tag { T = Int..Int }] <:< SUB[C, C]
18+
// Tag { T = L..Int } <: C <: Tag { T = Int..Int }]
19+
// Tag { T = L..Int } <: Tag { T = Int..Int }
20+
// Int <: L
21+
x // error
22+
}
23+
24+
def bad(x: Int): String =
25+
val s: Tag { type T = Int } = new Tag { type T = Int }
26+
val t: Tag { type T >: Nothing <: Int } & s.type = s
27+
val e: SUB[t.type, Tag { type T = Int }] = SUB.Refl[t.type]()
28+
foo[Nothing](t)(e, x) // cast Int to String!
29+
30+
object Test:
31+
def main(args: Array[String]): Unit = bad(1) // was: ClassCastException: class java.lang.Integer cannot be cast to class java.lang.String
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Another minimisation (after tests/run-macros/i15485.fallout-monocle)
2+
// of monocle's GenIsoSpec.scala
3+
// which broke when fixing soundness in infering GADT constraints on refined types
4+
class Can[T]
5+
object Can:
6+
import scala.deriving.*, scala.quoted.*
7+
8+
inline given derived[T](using inline m: Mirror.Of[T]): Can[T] = ${ impl('m) }
9+
10+
private def impl[T](m: Expr[Mirror.Of[T]])(using Quotes, Type[T]): Expr[Can[T]] = m match
11+
case '{ $_ : Mirror.Sum { type MirroredElemTypes = met } } => '{ new Can[T] }
12+
case '{ $_ : Mirror.Product { type MirroredElemTypes = met } } => '{ new Can[T] }
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
class Test:
2+
def test =
3+
Can.derived[EmptyTuple]
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import scala.deriving.*, scala.quoted.*
2+
3+
object Iso:
4+
transparent inline def fields[S <: Product](using m: Mirror.ProductOf[S]): Int = ${ Impl.apply[S]('m) }
5+
6+
object Impl:
7+
def apply[S <: Product](m: Expr[Mirror.ProductOf[S]])(using Quotes, Type[S]): Expr[Int] =
8+
import quotes.reflect.*
9+
m match
10+
case '{ type a <: Tuple; $m: Mirror.ProductOf[S] { type MirroredElemTypes = `a` } } => '{ 1 }
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
class Test:
2+
def test =
3+
case object Foo
4+
val iso = Iso.fields[Foo.type]
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import scala.deriving.*, scala.quoted.*
2+
3+
trait Foo[T]:
4+
def foo: Int
5+
6+
// A minimisation of monocle's GenIsoSpec.scala
7+
// which broke when fixing soundness in infering GADT constraints on refined types
8+
object Foo:
9+
inline given derived[T](using inline m: Mirror.Of[T]): Foo[T] = ${ impl('m) }
10+
11+
private def impl[T](m: Expr[Mirror.Of[T]])(using qctx: Quotes, tpe: Type[T]): Expr[Foo[T]] = m match
12+
case '{ $m : Mirror.Product { type MirroredElemTypes = EmptyTuple } } => '{ FooN[T](0) }
13+
case '{ $m : Mirror.Product { type MirroredElemTypes = a *: EmptyTuple } } => '{ FooN[T](1) }
14+
case '{ $m : Mirror.Product { type MirroredElemTypes = mirroredElemTypes } } => '{ FooN[T](9) }
15+
16+
class FooN[T](val foo: Int) extends Foo[T]
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
final case class Box(value: Int) derives Foo
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
@main def Test =
2+
val foo = summon[Foo[Box]].foo
3+
assert(foo == 1, foo)

0 commit comments

Comments
 (0)