Skip to content

Commit 6bfd47e

Browse files
committed
Use Skolems to infer GADT constraints
The rationale for using a Skolem here is: we want to record that there is at least one value that is both of the pattern type and the scrutinee type. All symbols are now considered valid for adding GADT constraints - the rationale is that set of constrainable symbols should be either selected on a per-(sub)pattern basis, or be the same for all matches. Previously, symbols which were only appearing variantly in a scrutinee type could be considered constrainable anyway because of an outer pattern match.
1 parent 5aaa9c8 commit 6bfd47e

File tree

6 files changed

+104
-25
lines changed

6 files changed

+104
-25
lines changed

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1039,7 +1039,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
10391039
* - If a type proxy P is not a reference to a class, P's supertype is in G
10401040
*/
10411041
def isSubTypeOfParent(subtp: Type, tp: Type)(implicit ctx: Context): Boolean =
1042-
if (constrainPatternType(subtp, tp)) true
1042+
if (constrainPatternType(SkolemType(subtp), tp)) true
10431043
else tp match {
10441044
case tp: TypeRef if tp.symbol.isClass => isSubTypeOfParent(subtp, tp.firstParent)
10451045
case tp: TypeProxy => isSubTypeOfParent(subtp, tp.superType)

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ class Typer extends Namer
584584
def handlePattern: Tree = {
585585
val tpt1 = typedTpt
586586
if (!ctx.isAfterTyper && pt != defn.ImplicitScrutineeTypeRef)
587-
constrainPatternType(tpt1.tpe, pt)(ctx.addMode(Mode.GADTflexible))
587+
constrainPatternType(SkolemType(tpt1.tpe), pt)(ctx.addMode(Mode.GADTflexible))
588588
// special case for an abstract type that comes with a class tag
589589
tryWithClassTag(ascription(tpt1, isWildcard = true), pt)
590590
}
@@ -1017,16 +1017,12 @@ class Typer extends Namer
10171017
assignType(cpy.Match(tree)(sel, cases1), sel, cases1)
10181018
}
10191019

1020-
/** gadtSyms = "all type parameters of enclosing methods that appear
1021-
* non-variantly in the selector type" todo: should typevars
1022-
* which appear with variances +1 and -1 (in different
1023-
* places) be considered as well?
1024-
*/
1020+
/** gadtSyms = "all type parameters of enclosing methods appearing in selector type" */
10251021
def gadtSyms(selType: Type)(implicit ctx: Context): Set[Symbol] = trace(i"GADT syms of $selType", gadts) {
10261022
val accu = new TypeAccumulator[Set[Symbol]] {
10271023
def apply(tsyms: Set[Symbol], t: Type): Set[Symbol] = {
10281024
val tsyms1 = t match {
1029-
case tr: TypeRef if (tr.symbol is TypeParam) && tr.symbol.owner.isTerm && variance == 0 =>
1025+
case tr: TypeRef if (tr.symbol is TypeParam) && tr.symbol.owner.isTerm =>
10301026
tsyms + tr.symbol
10311027
case _ =>
10321028
tsyms
@@ -1041,7 +1037,11 @@ class Typer extends Namer
10411037
def gadtContext(gadtSyms: Set[Symbol])(implicit ctx: Context): Context = {
10421038
val gadtCtx = ctx.fresh.setFreshGADTBounds
10431039
for (sym <- gadtSyms)
1044-
if (!gadtCtx.gadt.contains(sym)) gadtCtx.gadt.addEmptyBounds(sym)
1040+
if (!gadtCtx.gadt.contains(sym)) {
1041+
val TypeBounds(lo, hi) = sym.info.bounds
1042+
gadtCtx.gadt.addBound(sym, lo, isUpper = false)
1043+
gadtCtx.gadt.addBound(sym, hi, isUpper = true)
1044+
}
10451045
gadtCtx
10461046
}
10471047

tests/neg/int-extractor.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
object Test {
2+
object EssaInt {
3+
def unapply(i: Int): Some[Int] = Some(i)
4+
}
5+
6+
def foo1[T](t: T): T = t match {
7+
case EssaInt(_) =>
8+
0 // error
9+
}
10+
11+
case class Inv[T](t: T)
12+
13+
def bar1[T](t: T): T = Inv(t) match {
14+
case Inv(EssaInt(_)) =>
15+
0 // error
16+
}
17+
}

tests/neg/invariant-gadt.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
object `invariant-gadt` {
2+
case class Invariant[T](value: T)
3+
4+
def unsound0[T](t: T): T = Invariant(t) match {
5+
case Invariant(_: Int) =>
6+
(0: Any) // error
7+
}
8+
9+
def unsound1[T](t: T): T = Invariant(t) match {
10+
case Invariant(_: Int) =>
11+
0 // error
12+
}
13+
14+
def unsound2[T](t: T): T = Invariant(t) match {
15+
case Invariant(value) => value match {
16+
case _: Int =>
17+
0 // error
18+
}
19+
}
20+
}

tests/pos/precise-pattern-type.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
object `precise-pattern-type` {
2+
class Type {
3+
def isType: Boolean = true
4+
}
5+
6+
class Tree[-T >: Null] {
7+
def tpe: T @annotation.unchecked.uncheckedVariance = ???
8+
}
9+
10+
case class Select[-T >: Null](qual: Tree[T]) extends Tree[T]
11+
12+
def test[T <: Tree[Type]](tree: T) = tree match {
13+
case Select(q) =>
14+
q.tpe.isType
15+
}
16+
}

tests/run/typeclass-derivation2.scala

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
import scala.collection.mutable
22
import scala.annotation.tailrec
33

4+
object _typelevel {
5+
final abstract class Type[-A, +B]
6+
type Subtype[t] = Type[_, t]
7+
type Supertype[t] = Type[t, _]
8+
type Exactly[t] = Type[t, t]
9+
erased def typeOf[T]: Type[T, T] = ???
10+
}
11+
412
trait Deriving {
513
import Deriving._
614

@@ -177,6 +185,7 @@ trait Eq[T] {
177185
object Eq {
178186
import scala.typelevel._
179187
import Deriving._
188+
import _typelevel._
180189

181190
inline def tryEql[T](x: T, y: T) = implicit match {
182191
case eq: Eq[T] => eq.eql(x, y)
@@ -197,14 +206,19 @@ object Eq {
197206
inline def eqlCases[T, Alts <: Tuple](r: Reflected[T], x: T, y: T): Boolean =
198207
inline erasedValue[Alts] match {
199208
case _: (Shape.Case[alt, elems] *: alts1) =>
200-
x match {
201-
case x: `alt` =>
202-
y match {
203-
case y: `alt` => eqlCase[T, elems](r, x, y)
204-
case _ => false
209+
inline typeOf[alt] match {
210+
case _: Subtype[T] =>
211+
x match {
212+
case x: `alt` =>
213+
y match {
214+
case y: `alt` => eqlCase[T, elems](r, x, y)
215+
case _ => false
216+
}
217+
case _ => eqlCases[T, alts1](r, x, y)
205218
}
206-
case _ => eqlCases[T, alts1](r, x, y)
207-
}
219+
case _ =>
220+
error("invalid call to eqlCases: one of Alts is not a subtype of T")
221+
}
208222
case _: Unit =>
209223
false
210224
}
@@ -232,6 +246,7 @@ trait Pickler[T] {
232246
object Pickler {
233247
import scala.typelevel._
234248
import Deriving._
249+
import _typelevel._
235250

236251
def nextInt(buf: mutable.ListBuffer[Int]): Int = try buf.head finally buf.trimStart(1)
237252

@@ -253,12 +268,17 @@ object Pickler {
253268
inline def pickleCases[T, Alts <: Tuple](r: Reflected[T], buf: mutable.ListBuffer[Int], x: T, n: Int): Unit =
254269
inline erasedValue[Alts] match {
255270
case _: (Shape.Case[alt, elems] *: alts1) =>
256-
x match {
257-
case x: `alt` =>
258-
buf += n
259-
pickleCase[T, elems](r, buf, x)
271+
inline typeOf[alt] match {
272+
case _: Subtype[T] =>
273+
x match {
274+
case x: `alt` =>
275+
buf += n
276+
pickleCase[T, elems](r, buf, x)
277+
case _ =>
278+
pickleCases[T, alts1](r, buf, x, n + 1)
279+
}
260280
case _ =>
261-
pickleCases[T, alts1](r, buf, x, n + 1)
281+
error("invalid pickleCases call: one of Alts is not a subtype of T")
262282
}
263283
case _: Unit =>
264284
}
@@ -323,6 +343,7 @@ trait Show[T] {
323343
object Show {
324344
import scala.typelevel._
325345
import Deriving._
346+
import _typelevel._
326347

327348
inline def tryShow[T](x: T): String = implicit match {
328349
case s: Show[T] => s.show(x)
@@ -347,9 +368,14 @@ object Show {
347368
inline def showCases[T, Alts <: Tuple](r: Reflected[T], x: T): String =
348369
inline erasedValue[Alts] match {
349370
case _: (Shape.Case[alt, elems] *: alts1) =>
350-
x match {
351-
case x: `alt` => showCase[T, elems](r, x)
352-
case _ => showCases[T, alts1](r, x)
371+
inline typeOf[alt] match {
372+
case _: Subtype[T] =>
373+
x match {
374+
case x: `alt` => showCase[T, elems](r, x)
375+
case _ => showCases[T, alts1](r, x)
376+
}
377+
case _ =>
378+
error("invalid call to showCases: one of Alts is not a subtype of T")
353379
}
354380
case _: Unit =>
355381
throw new MatchError(x)
@@ -424,4 +450,4 @@ object Test extends App {
424450
println(implicitly[Show[T]].show(x))
425451
showPrintln(xs)
426452
showPrintln(xss)
427-
}
453+
}

0 commit comments

Comments
 (0)