Skip to content

Commit 97ef5bc

Browse files
committed
handle structural types/type members
1 parent 2b2af9f commit 97ef5bc

7 files changed

+480
-40
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 92 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
179179
* code would have two extra parameters for each of the many calls that go from
180180
* one sub-part of isSubType to another.
181181
*/
182-
protected def recur(tp1: Type, tp2: Type): Boolean = trace(s"isSubType ${traceInfo(tp1, tp2)} $approx", subtyping) {
182+
protected def recur(tp1: Type, tp2: Type): Boolean =
183+
// trace.force(s"isSubType ${traceInfo(tp1, tp2)} $approx", subtyping)
184+
{
183185

184186
def monitoredIsSubType = {
185187
if (pendingSubTypes == null) {
@@ -2104,40 +2106,31 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
21042106
import config.Printers.debug
21052107
import typer.Inferencing._
21062108

2107-
def incompatibleClasses: Boolean = {
2109+
def compatibleClasses: Boolean = {
21082110
import Flags._
21092111
val tpClassSym = tp.widenSingleton.classSymbol
21102112
val ptClassSym = pt.widenSingleton.classSymbol
21112113
debug.println(i"tpClassSym=$tpClassSym, fin=${tpClassSym.is(Final)}")
21122114
debug.println(i"pt=$pt {${pt.getClass}}, ptClassSym=$ptClassSym, fin=${ptClassSym.is(Final)}")
2113-
tpClassSym.exists && ptClassSym.exists && {
2114-
if (tpClassSym.is(Final)) !tpClassSym.derivesFrom(ptClassSym)
2115-
else if (ptClassSym.is(Final)) !ptClassSym.derivesFrom(tpClassSym)
2115+
!tpClassSym.exists || !ptClassSym.exists || {
2116+
if (tpClassSym.is(Final)) tpClassSym.derivesFrom(ptClassSym)
2117+
else if (ptClassSym.is(Final)) ptClassSym.derivesFrom(tpClassSym)
21162118
else if (!tpClassSym.is(Flags.Trait) && !ptClassSym.is(Flags.Trait))
2117-
!(tpClassSym.derivesFrom(ptClassSym) || ptClassSym.derivesFrom(tpClassSym))
2118-
else false
2119+
tpClassSym.derivesFrom(ptClassSym) || ptClassSym.derivesFrom(tpClassSym)
2120+
else true
21192121
}
21202122
}
21212123

21222124
def loop(tp: Type): Boolean =
21232125
// trace.force(i"loop($tp) // ${tp.toString}")
21242126
{
2125-
if (constrainPatternType(pt, tp)) true
2126-
else if (incompatibleClasses) {
2127-
// println("incompatible classes")
2128-
false
2129-
}
2130-
else tp match {
2131-
case _: ConstantType =>
2132-
// constants cannot possibly intersect with types that aren't their supertypes
2133-
false
2134-
case tp: SingletonType => loop(tp.underlying)
2135-
case tp: TypeRef if tp.symbol.isClass => loop(tp.firstParent)
2127+
val res: Type = tp match {
2128+
case tp: TypeRef if tp.symbol.isClass => tp.firstParent
21362129
case tp @ AppliedType(tycon: TypeRef, _) if tycon.symbol.isClass =>
21372130
val ptClassSym = pt.classSymbol
21382131
def firstParentSharedWithPt(tp: Type, tpClassSym: ClassSymbol): Symbol =
2139-
// trace.force(i"f($tp)")
2140-
{
2132+
// trace.force(i"f($tp)")
2133+
{
21412134
var parents = tpClassSym.info.parents
21422135
// println(i"parents of $tpClassSym = $parents%, %")
21432136
parents match {
@@ -2156,29 +2149,89 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
21562149
}
21572150
val sym = firstParentSharedWithPt(tycon, tycon.symbol.asClass)
21582151
// println(i"sym=$sym ; tyconsym=${tycon.symbol}")
2159-
if (!sym.exists) true
2160-
else !(sym == tycon.symbol) && loop(tp.baseType(sym))
2152+
if (!sym.exists) return true
2153+
// else !(sym == tycon.symbol) &&
2154+
tp.baseType(sym)
21612155
case tp: TypeProxy =>
2162-
loop(tp.superType)
2163-
case _ => false
2156+
tp.superType
2157+
case _ => return true
21642158
}
2159+
constrainPatternType(pt, res) || loop(res)
21652160
}
21662161

2167-
pt match {
2168-
case AndType(pt1, pt2) =>
2169-
notIntersection(tp, pt1) && notIntersection(tp, pt2)
2170-
case OrType(pt1, pt2) =>
2171-
either(notIntersection(tp, pt1), notIntersection(tp, pt2))
2172-
case _ =>
2173-
tp match {
2174-
case OrType(tp1, tp2) =>
2175-
either(notIntersection(tp1, pt), notIntersection(tp2, pt))
2176-
case AndType(tp1, tp2) =>
2177-
notIntersection(tp1, pt) && notIntersection(tp2, pt)
2178-
case _ =>
2179-
loop(tp)
2180-
}
2181-
}
2162+
tp.dealias match {
2163+
case OrType(tp1, tp2) =>
2164+
either(notIntersection(tp1, pt), notIntersection(tp2, pt))
2165+
case AndType(tp1, tp2) =>
2166+
notIntersection(tp1, pt) && notIntersection(tp2, pt)
2167+
case tp: RefinedOrRecType =>
2168+
def keepInvariantRefinements(tp: Type): Type = tp match {
2169+
case tp: RefinedType =>
2170+
if (tp.refinedName.isTermName) keepInvariantRefinements(tp.parent)
2171+
else {
2172+
// def resolve(tp: Type): Type = tp match {
2173+
// case TypeAlias(tp) => resolve(tp.dealias)
2174+
// case tp => tp
2175+
// }
2176+
// val tpInfo = tp.refinedInfo
2177+
// val tpInfoDealiased = resolve(tpInfo)
2178+
// val ptInfo = pt.member(tp.refinedName).info
2179+
// val ptInfoDealiased = resolve(ptInfo)
2180+
// println(
2181+
// i"""tpInfo = ${tpInfo}
2182+
// |tpInfoDealiased = ${tpInfoDealiased}
2183+
// |ptInfo = ${ptInfo}
2184+
// |ptInfoDealiased = ${ptInfoDealiased}""".stripMargin
2185+
// )
2186+
// println(i"visiting refinement: ${tp.refinedName} : ${tp.refinedInfo}")
2187+
tp.refinedInfo match {
2188+
case TypeAlias(tp1) =>
2189+
val pt1 = pt.member(tp.refinedName).info
2190+
if (pt1.exists && pt1.bounds.contains(tp1) || !pt1.exists)
2191+
keepInvariantRefinements(tp.parent)
2192+
else
2193+
NoType
2194+
case tpb: TypeBounds =>
2195+
pt.member(tp.refinedName).info match {
2196+
case TypeAlias(pt1) =>
2197+
if (tpb.contains(pt1))
2198+
keepInvariantRefinements(tp.parent)
2199+
else
2200+
NoType
2201+
case _ =>
2202+
keepInvariantRefinements(tp.parent)
2203+
}
2204+
}
2205+
}
2206+
case tp: RecType =>
2207+
keepInvariantRefinements(tp.parent)
2208+
case _ =>
2209+
tp
2210+
}
2211+
val tp1 = keepInvariantRefinements(tp)
2212+
if (!tp1.exists) {
2213+
// println(i"noType for $tp")
2214+
false
2215+
} else
2216+
notIntersection(tp1, pt)
2217+
case tp =>
2218+
pt.dealias match {
2219+
case AndType(pt1, pt2) =>
2220+
notIntersection(tp, pt1) && notIntersection(tp, pt2)
2221+
case OrType(pt1, pt2) =>
2222+
either(notIntersection(tp, pt1), notIntersection(tp, pt2))
2223+
case pt: RefinedOrRecType =>
2224+
// note: at this point, we have already extracted the information we wanted from the refinement
2225+
// and it would only interfere in the following subtype check in constrainPatternType
2226+
def stripRefinement(tp: Type): Type = tp match {
2227+
case tp: RefinedOrRecType => stripRefinement(tp.parent)
2228+
case tp => tp
2229+
}
2230+
notIntersection(tp, stripRefinement(pt))
2231+
case pt =>
2232+
constrainPatternType(pt, tp) || compatibleClasses && loop(tp)
2233+
}
2234+
}
21822235
}
21832236
}
21842237

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1108,7 +1108,11 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
11081108
fullyDefinedType(unapplyArgType, "pattern selector", tree.span)
11091109
selType
11101110
} else {
1111-
isSubTypeOfParent(unapplyArgType, selType)(ctx.addMode(Mode.GADTflexible))
1111+
val res = isSubTypeOfParent(unapplyArgType, selType)(ctx.addMode(Mode.GADTflexible))
1112+
if (!res) ctx.warning(
1113+
ex"Pattern type $unapplyArgType does not intersect selector type $selType",
1114+
tree.sourcePos
1115+
)
11121116
val patternBound = maximizeType(unapplyArgType, tree.span, fromScala2x)
11131117
if (patternBound.nonEmpty) unapplyFn = addBinders(unapplyFn, patternBound)
11141118
unapp.println(i"case 2 $unapplyArgType ${ctx.typerState.constraint}")

tests/neg/structural-gadt.scala

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
object Test {
2+
trait Expr { type T }
3+
trait IntLit extends Expr { type T <: Int }
4+
trait IntExpr extends Expr { type T = Int }
5+
6+
def foo[A](e: Expr { type T = A }) = e match {
7+
case _: IntLit =>
8+
val a: A = 0 // error
9+
val i: Int = ??? : A
10+
11+
case _: Expr { type T <: Int } =>
12+
val a: A = 0 // error
13+
val i: Int = ??? : A
14+
15+
case _: IntExpr =>
16+
val a: A = 0
17+
val i: Int = ??? : A
18+
19+
case _: Expr { type T = Int } =>
20+
val a: A = 0
21+
val i: Int = ??? : A
22+
}
23+
24+
def bar[A](e: Expr { type T <: A }) = e match {
25+
case _: IntLit =>
26+
val a: A = 0 // error
27+
val i: Int = ??? : A // error
28+
29+
case _: Expr { type T <: Int } =>
30+
val a: A = 0 // error
31+
val i: Int = ??? : A // error
32+
33+
case _: IntExpr =>
34+
val a: A = 0
35+
val i: Int = ??? : A // error
36+
37+
case _: Expr { type T = Int } =>
38+
val a: A = 0
39+
val i: Int = ??? : A // error
40+
}
41+
}
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
object Test {
2+
// Some error comments in this file are preceded by // ?
3+
// This indicates that we should actually accept that line,
4+
// but we don't due to limitation of the implementation
5+
//
6+
//
7+
8+
trait Expr { type T }
9+
trait IntLit extends Expr { type T <: Int }
10+
trait IntExpr extends Expr { type T = Int }
11+
12+
type ExprSub[+A] = Expr { type T <: A }
13+
type ExprExact[A] = Expr { type T = A }
14+
15+
trait IndirectIntLit extends Expr { type S <: Int; type T = S }
16+
trait IndirectIntExpr extends Expr { type S = Int; type T = S }
17+
18+
type IndirectExprSub[+A] = Expr { type S <: A; type T = S }
19+
type IndirectExprSub2[A] = Expr { type S = A; type T <: S }
20+
type IndirectExprExact[A] = Expr { type S = A; type T = S }
21+
22+
trait AltIndirectIntLit extends Expr { type U <: Int; type T = U }
23+
trait AltIndirectIntExpr extends Expr { type U = Int; type T = U }
24+
25+
type AltIndirectExprSub[+A] = Expr { type U <: A; type T = U }
26+
type AltIndirectExprSub2[A] = Expr { type U = A; type T <: U }
27+
type AltIndirectExprExact[A] = Expr { type U = A; type T = U }
28+
29+
def foo[A](e: IndirectExprExact[A]) = e match {
30+
case _: AltIndirectIntLit =>
31+
val a: A = 0 // error
32+
val i: Int = ??? : A
33+
34+
case _: AltIndirectExprSub[Int] =>
35+
val a: A = 0 // error
36+
val i: Int = ??? : A
37+
38+
case _: AltIndirectExprSub2[Int] =>
39+
val a: A = 0 // error
40+
val i: Int = ??? : A
41+
42+
case _: AltIndirectIntExpr =>
43+
val a: A = 0
44+
val i: Int = ??? : A
45+
46+
case _: AltIndirectExprExact[Int] =>
47+
val a: A = 0
48+
val i: Int = ??? : A
49+
}
50+
51+
def bar[A](e: IndirectExprSub[A]) = e match {
52+
case _: AltIndirectIntLit =>
53+
val a: A = 0 // error
54+
val i: Int = ??? : A // error
55+
56+
case _: AltIndirectExprSub[Int] =>
57+
val a: A = 0 // error
58+
val i: Int = ??? : A // error
59+
60+
case _: AltIndirectExprSub2[Int] =>
61+
val a: A = 0 // error
62+
val i: Int = ??? : A // error
63+
64+
case _: AltIndirectIntExpr =>
65+
val a: A = 0 // ? // error
66+
val i: Int = ??? : A // error
67+
68+
case _: AltIndirectExprExact[Int] =>
69+
val a: A = 0 // ? // error
70+
val i: Int = ??? : A // error
71+
}
72+
73+
def baz[A](e: IndirectExprSub2[A]) = e match {
74+
case _: AltIndirectIntLit =>
75+
val a: A = 0 // error
76+
val i: Int = ??? : A // error
77+
78+
case _: AltIndirectExprSub[Int] =>
79+
val a: A = 0 // error
80+
val i: Int = ??? : A // error
81+
82+
case _: AltIndirectExprSub2[Int] =>
83+
val a: A = 0 // error
84+
val i: Int = ??? : A // error
85+
86+
case _: AltIndirectIntExpr =>
87+
val a: A = 0
88+
val i: Int = ??? : A // error
89+
90+
case _: AltIndirectExprExact[Int] =>
91+
val a: A = 0
92+
val i: Int = ??? : A // error
93+
}
94+
}

0 commit comments

Comments
 (0)