Skip to content

Commit a1eb832

Browse files
committed
Check bounds in match type case bodies
Requires a few things. In the tests, it requires propagating some bounds, as well as tweaking how things are matched. Also, requires a few changes in how type patterns add constraints, with a fix on type constructors and another guard in widening abstract types.
1 parent 09f5e4c commit a1eb832

19 files changed

+95
-57
lines changed

compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
7373
* scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables,
7474
* in which case the subtyping relationship "heals" the type.
7575
*/
76-
def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false): Boolean = trace(i"constrainPatternType($scrut, $pat)", gadts) {
76+
def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false): Boolean = trace(i"constrainPatternType($pat, $scrut)", gadts) {
7777

7878
def classesMayBeCompatible: Boolean = {
7979
import Flags._
@@ -231,41 +231,32 @@ trait PatternTypeConstrainer { self: TypeComparer =>
231231
* (see `RefChecks#checkCaseClassInheritanceInvariant`).
232232
*/
233233
def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type, forceInvariantRefinement: Boolean): Boolean = {
234+
val debug = noPrinter
234235
def refinementIsInvariant(tp: Type): Boolean = tp match {
235236
case tp: SingletonType => true
236237
case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case)
237238
case tp: TypeProxy => refinementIsInvariant(tp.superType)
238239
case _ => false
239240
}
240241

241-
def widenVariantParams(tp: Type) = tp match {
242-
case tp @ AppliedType(tycon, args) =>
243-
val args1 = args.zipWithConserve(tycon.typeParams)((arg, tparam) =>
244-
if (tparam.paramVarianceSign != 0) TypeBounds.empty else arg
245-
)
246-
tp.derivedAppliedType(tycon, args1)
247-
case tp =>
248-
tp
249-
}
250-
251242
val patternCls = patternTp.classSymbol
252243
val scrutineeCls = scrutineeTp.classSymbol
253244

254245
// NOTE: we already know that there is a derives-from relationship in either direction
255246
val upcastPattern =
256247
patternCls.derivesFrom(scrutineeCls)
257248

258-
val pt = if upcastPattern then patternTp.baseType(scrutineeCls) else patternTp
259-
val tp = if !upcastPattern then scrutineeTp.baseType(patternCls) else scrutineeTp
249+
val pat = if upcastPattern then patternTp.baseType(scrutineeCls) else patternTp
250+
val scr = if !upcastPattern then scrutineeTp.baseType(patternCls) else scrutineeTp
260251

261252
val assumeInvariantRefinement =
262253
migrateTo3 || forceInvariantRefinement || refinementIsInvariant(patternTp)
263254

264-
trace(i"constraining simple pattern type $tp >:< $pt", gadts, (res: Boolean) => i"$res gadt = ${ctx.gadt}") {
265-
(tp, pt) match {
266-
case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) =>
255+
trace(i"constraining simple pattern type $pat >:< $scr assume=$assumeInvariantRefinement", gadts, (res: Boolean) => i"$res gadt = ${ctx.gadt}") {
256+
(scr, pat) match {
257+
case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) if tyconP.frozen_=:=(tyconS) =>
267258
val saved = state.nn.constraint
268-
val result =
259+
val success =
269260
ctx.gadtState.rollbackGadtUnless {
270261
tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) =>
271262
val variance = param.paramVarianceSign
@@ -277,15 +268,21 @@ trait PatternTypeConstrainer { self: TypeComparer =>
277268
val TypeBounds(loS, hiS) = argS.bounds
278269
val TypeBounds(loP, hiP) = argP.bounds
279270
var res = true
280-
if variance < 1 then res &&= isSubType(loS, hiP)
281-
if variance > -1 then res &&= isSubType(loP, hiS)
271+
if ctx.mode.is(Mode.Type) then
272+
if variance > -1 then res &&= isSubType(loS, hiP).showing(i"$loS <: $hiP = $result v=$variance argS=$argS argP=$argP", debug)
273+
if variance < 1 then res &&= isSubType(loP, hiS).showing(i"$hiS >: $loP = $result v=$variance argS=$argS argP=$argP", debug)
274+
else
275+
if variance < 1 then res &&= isSubType(loS, hiP).showing(i"$hiP >: $loS = $result v=$variance argP=$argP argS=$argS", debug)
276+
if variance > -1 then res &&= isSubType(loP, hiS).showing(i"$loP <: $hiS = $result v=$variance argP=$argP argS=$argS", debug)
282277
res
283278
else true
284279
}
285280
}
286-
if !result then
281+
if !success then
287282
constraint = saved
288-
result
283+
success
284+
case (scr: TypeRef, _) if ctx.mode.is(Mode.Type) && ctx.gadt.contains(scr.symbol) =>
285+
isSubType(scrutineeTp, patternTp).showing(i"$scrutineeTp <: $patternTp = $result", debug)
289286
case _ =>
290287
// Give up if we don't get AppliedType, e.g. if we upcasted to Any.
291288
// Note that this doesn't mean that patternTp, scrutineeTp cannot possibly

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -912,7 +912,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
912912
case info1 @ TypeBounds(lo1, hi1) =>
913913
def compareGADT =
914914
tp1.symbol.onGadtBounds(gbounds1 =>
915-
isSubTypeWhenFrozen(gbounds1.hi, tp2)
915+
(!caseLambda.exists || widenAbstractOKFor(tp2)) && isSubTypeWhenFrozen(gbounds1.hi, tp2)
916916
|| narrowGADTBounds(tp1, tp2, approx, isUpper = true))
917917
&& (tp2.isAny || GADTusage(tp1.symbol))
918918

@@ -3117,7 +3117,7 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
31173117
super.typeVarInstance(tvar)
31183118
}
31193119

3120-
def matchCases(scrut: Type, cases: List[Type])(using Context): Type = {
3120+
def matchCases(scrut: Type, cases: List[Type])(using Context): Type = trace(i"matchCases($scrut, $cases)") {
31213121
// a reference for the type parameters poisoned during matching
31223122
// for use during the reduction step
31233123
var poisoned: Set[TypeParamRef] = Set.empty
@@ -3169,7 +3169,7 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
31693169

31703170
val defn.MatchCase(pat, body) = cas1: @unchecked
31713171

3172-
def matches(canWidenAbstract: Boolean): Boolean =
3172+
def matches(canWidenAbstract: Boolean): Boolean = trace(i"matches(canWidenAbstract=$canWidenAbstract)") {
31733173
val saved = this.canWidenAbstract
31743174
val savedPoisoned = this.poisoned
31753175
this.canWidenAbstract = canWidenAbstract
@@ -3179,8 +3179,9 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
31793179
poisoned = this.poisoned
31803180
this.poisoned = savedPoisoned
31813181
this.canWidenAbstract = saved
3182+
}
31823183

3183-
def redux(canApprox: Boolean): MatchResult =
3184+
def redux(canApprox: Boolean): MatchResult = trace(i"redux(canApprox=$canApprox)") {
31843185
caseLambda match
31853186
case caseLambda: HKTypeLambda =>
31863187
val instances = paramInstances(canApprox)(Array.fill(caseLambda.paramNames.length)(NoType), pat)
@@ -3195,6 +3196,7 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
31953196
MatchResult.Reduced(redux.simplified)
31963197
case _ =>
31973198
MatchResult.Reduced(body)
3199+
}
31983200

31993201
if caseLambda.exists && matches(canWidenAbstract = false) then
32003202
redux(canApprox = true)

compiler/src/dotty/tools/dotc/core/TypeEval.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ object TypeEval:
2121
case tp: TypeProxy =>
2222
val tp1 = tp.superType
2323
if tp1.isStable then tp1.fixForEvaluation else tp
24+
case tp: AndType =>
25+
// tests/pos/9890.scala
26+
// allow `((0 : Int) & Int) * (3 : Int)` to be folded
27+
val glb = tp.tp1 & tp.tp2
28+
if tp ne glb then glb.fixForEvaluation else tp
2429
case tp => tp
2530

2631
def constValue(tp: Type): Option[Any] = tp.fixForEvaluation match

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -269,16 +269,16 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
269269
if !tree.symbol.is(Package) then tree
270270
else errorTree(tree, em"${tree.symbol} cannot be used as a type")
271271

272+
private def gadtCtx(tree: CaseDef)(using Context): Context =
273+
tree.pat.removeAttachment(typer.Typer.InferredGadtConstraints) match
274+
case Some(gadt) => ctx.fresh.setGadtState(GadtState(gadt))
275+
case None => ctx
276+
272277
override def transform(tree: Tree)(using Context): Tree =
273278
try tree match {
274279
// TODO move CaseDef case lower: keep most probable trees first for performance
275-
case CaseDef(pat, _, _) =>
276-
val gadtCtx =
277-
pat.removeAttachment(typer.Typer.InferredGadtConstraints) match
278-
case Some(gadt) => ctx.fresh.setGadtState(GadtState(gadt))
279-
case None =>
280-
ctx
281-
super.transform(tree)(using gadtCtx)
280+
case tree: CaseDef =>
281+
super.transform(tree)(using gadtCtx(tree))
282282
case tree: Ident =>
283283
if tree.isType then
284284
checkNotPackage(tree)
@@ -477,7 +477,10 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
477477
case m @ MatchTypeTree(bounds, selector, cases) =>
478478
// Analog to the case above for match types
479479
def transformIgnoringBoundsCheck(x: CaseDef): CaseDef =
480-
withMode(Mode.Pattern)(super.transform(x)).asInstanceOf[CaseDef]
480+
inContext(gadtCtx(x)) {
481+
val pat1 = inMode(Mode.Pattern)(transform(x.pat))
482+
cpy.CaseDef(tree)(pat1, transform(x.guard), transform(x.body))
483+
}
481484
cpy.MatchTypeTree(tree)(
482485
super.transform(bounds),
483486
super.transform(selector),

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,8 @@ class Namer { typer: Typer =>
974974

975975
override final def typeSig(sym: Symbol): Type =
976976
val tparamSyms = completerTypeParams(sym)(using ictx)
977-
given ctx: Context = nestedCtx.nn
977+
given ctx: Context = nestedCtx.nn.fresh.setFreshGADTBounds
978+
if tparamSyms.nonEmpty then ctx.gadtState.addToConstraint(tparamSyms)
978979

979980
def abstracted(tp: TypeBounds): TypeBounds =
980981
HKTypeLambda.boundsFromParams(tparamSyms, tp)

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1847,7 +1847,13 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
18471847
else report.error(new DuplicateBind(b, cdef), b.srcPos)
18481848
if (!ctx.isAfterTyper) {
18491849
val bounds = ctx.gadt.fullBounds(sym)
1850-
if (bounds != null) sym.info = bounds
1850+
if (bounds != null)
1851+
val info = if ctx.mode.is(Mode.Type) then bounds match
1852+
case TypeBounds(lo, hi) if !lo.isExactlyNothing && hi.isExactlyAny => TypeBounds(defn.NothingType, lo)
1853+
case TypeAlias(_) => sym.info
1854+
case bounds => bounds
1855+
else bounds
1856+
sym.info = info
18511857
}
18521858
b
18531859
case t: UnApply if t.symbol.is(Inline) => Inlines.inlinedUnapply(t)
@@ -1916,6 +1922,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
19161922
}
19171923
val pat2 = indexPattern(cdef).transform(pat1)
19181924
var body1 = typedType(cdef.body, pt)
1925+
if ctx.gadt.isNarrowing then
1926+
pat1.putAttachment(InferredGadtConstraints, ctx.gadt)
19191927
if !body1.isType then
19201928
assert(ctx.reporter.errorsReported)
19211929
body1 = TypeTree(errorType(em"<error: not a type>", cdef.srcPos))

tests/neg/6570.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ object ThisTypeVariant {
6969
}
7070

7171
object ParametricVariant {
72-
type Trick[a] = { type A <: a }
72+
type Trick[a] = Any { type A <: a }
7373
type M[t] = t match { case Trick[a] => N[a] }
7474

7575
trait Root[B] {

tests/neg/i13741.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
type Init[X <: NonEmptyTuple] <: Tuple = X match
2+
case _ *: EmptyTuple => EmptyTuple
3+
case x *: xs =>
4+
x *: Init[xs] // error
5+
6+
def a: Init[Tuple3[Int, String, Boolean]] = ???

tests/neg/i15272.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
case class Head[+NT, +From <: NT, +To <: NT] (from: From, to: To ) extends EdgeN[NT]
44
case class Cons[+NT, +From <: NT, +ToE <: EdgeN[NT]](from: From, to: ToE) extends EdgeN[NT]
55
final type InNodesTupleOf[NT, E <: EdgeN[NT]] <: Tuple = E match
6-
case Cons[nt,from,toE] => from *: InNodesTupleOf[nt,toE]
6+
case Cons[nt,from,toE] => from *: InNodesTupleOf[nt,toE] // error
77
case Head[nt,from ,to] => from *: EmptyTuple
88
def inNodesTuple[NT,E <: EdgeN[NT]](edge: E): InNodesTupleOf[NT,E] = edge match
99
case e: Cons[nt,from,toE] => e.from *: inNodesTuple[nt,toE](e.to) // error
1010
case e: Head[nt,from,to] => e.from *: EmptyTuple
11-
end EdgeN
11+
end EdgeN

tests/pos/10867.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
object Test {
2+
// e.g inserts[z, (a, b)] =:= ((z, a, b), (a, z, b), (a, b, z))
23
type inserts[a, as <: Tuple] <: Tuple =
34
as match
45
case EmptyTuple => (a *: EmptyTuple) *: EmptyTuple
5-
case y *: ys => (a *: y *: ys) *: Tuple.Map[inserts[a, ys], [t <: Tuple] =>> y *: t]
6+
case y *: ys => (a *: y *: ys) *: Tuple.Map[inserts[a, ys], [t <: Tuple.Union[inserts[a, ys]]] =>> y *: (t & Tuple)]
67

78
type inserts2[a] =
89
[as <: Tuple] =>> inserts[a, as]

tests/pos/13633.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ object Sums extends App:
2121

2222
type Reverse[A] = ReverseLoop[A, EmptyTuple]
2323

24-
type PlusTri[A, B, C] = (A, B, C) match
24+
type PlusTri[A, B, C] <: (Boolean, Boolean) = (A, B, C) match
2525
case (false, false, false) => (false, false)
2626
case (true, false, false) | (false, true, false) | (false, false, true) => (false, true)
2727
case (true, true, false) | (true, false, true) | (false, true, true) => (true, false)
@@ -38,7 +38,7 @@ object Sums extends App:
3838
case false => A
3939
case true => Inc[A]
4040

41-
type PlusLoop[A <: Tuple, B <: Tuple, O] <: Tuple = (A, B) match
41+
type PlusLoop[A <: Tuple, B <: Tuple, O <: Boolean] <: Tuple = (A, B) match
4242
case (EmptyTuple, EmptyTuple) =>
4343
O match
4444
case true => (true *: EmptyTuple)

tests/pos/9239.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,17 @@ object ABug:
1010
type Zero = B0 :: Nil
1111
type One = B1 :: Nil
1212

13-
type --[B <: Bin] =
13+
type --[B <: Bin] <: Bin =
1414
B match
1515
case B1 :: d => B0 :: d
1616
case B0 :: B1 :: Nil => B1 :: Nil
1717
case B0 :: d => B1 :: --[d]
1818

19-
type ×[N <: Bin, M <: Bin] =
19+
type ×[N <: Bin, M <: Bin] <: Bin =
2020
(N, M) match
2121
case (Zero, ?) => Zero
2222

23-
type ![N <: Bin] =
23+
type ![N <: Bin] <: Bin =
2424
N match
2525
case Zero => One
2626
case One => One

tests/pos/9890.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ object Test {
1010

1111
type TupleMap[Tup <: Tuple, Bound, F[_ <: Bound]] <: Tuple = Tup match {
1212
case EmptyTuple => EmptyTuple
13-
case h *: t => F[h] *: TupleMap[t, Bound, F]
13+
case h *: t => h match
14+
case Bound => F[h] *: TupleMap[t, Bound, F]
1415
}
1516
type TupleDedup[Tup <: Tuple, Mask] <: Tuple = Tup match {
1617
case EmptyTuple => EmptyTuple

tests/pos/i15926.contra.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ type MT1[I <: Show[Nothing], N] = I match
55
case Int => a
66

77
val a = summon[MT1[Show[String], Int] =:= String]
8+
def b: MT1[Show[String], Int] = ""

tests/pos/i15926.extract.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,19 @@ final case class Succ[+N <: Nat]() extends Nat
77

88
final case class Neg[+N <: Succ[Nat]]()
99

10-
type Sum[X, Y] = Y match
10+
type Sum[X <: Nat, Y] = Y match
1111
case Zero => X
1212
case Succ[y] => Sum[Succ[X], y]
1313

1414
type IntSum[A, B] = B match
1515
case Neg[b] => IntSumNeg[A, b]
1616

1717
type IntSumNeg[A, B] = A match
18-
case Neg[a] => Neg[Sum[a, B]]
18+
case Neg[a] => Negate[Sum[a, B]]
19+
20+
type Negate[A] = A match
21+
case Zero => Zero
22+
case Succ[_] => Neg[A]
1923

2024
type One = Succ[Zero]
2125
type Two = Succ[One]

tests/pos/i15926.min.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,17 @@ final case class Succ[+N <: Nat]() extends Nat
77

88
final case class Neg[+N <: Succ[Nat]]()
99

10-
type Sum[X, Y] = Y match
10+
type Sum[X <: Nat, Y] <: Nat = Y match
1111
case Zero => X
1212
case Succ[y] => Sum[Succ[X], y]
1313

1414
type IntSum[A, B] = B match
1515
case Neg[b] => A match
16-
case Neg[a] => Neg[Sum[a, b]]
16+
case Neg[a] => Negate[Sum[a, b]]
17+
18+
type Negate[A] = A match
19+
case Zero => Zero
20+
case Succ[_] => Neg[A]
1721

1822
type One = Succ[Zero]
1923
type Two = Succ[One]

tests/pos/i15926.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,12 @@ type NatDif[X <: NatT, Y <: NatT] <: IntT = Y match
2121
type Sum[X <: IntT, Y <: IntT] <: IntT = Y match
2222
case Zero => X
2323
case Minus[y] => X match
24-
case Minus[x] => Minus[NatSum[x, y]]
25-
case _ => NatDif[X, y]
26-
case _ => X match
24+
case Minus[x] => Negate[NatSum[x, y]]
25+
case NatT => NatDif[X, y]
26+
case NatT => X match
2727
case Minus[x] => NatDif[Y, x]
28-
case _ => NatSum[X, Y]
28+
case NatT => NatSum[X, Y]
29+
30+
type Negate[A] = A match
31+
case Zero => Zero
32+
case Succ[_] => Neg[A]

tests/pos/i16706.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ import scala.deriving.Mirror
22
import scala.reflect.ClassTag
33

44
type TupleUnionLub[T <: Tuple, Lub, Acc <: Lub] <: Lub = T match {
5-
case (h & Lub) *: t => TupleUnionLub[t, Lub, Acc | h]
5+
case h *: t => h match
6+
case Lub => TupleUnionLub[t, Lub, Acc | h]
67
case EmptyTuple => Acc
78
}
89

@@ -14,4 +15,4 @@ transparent inline given derived[A](
1415
sealed trait Foo
1516
case class FooA(a: Int) extends Foo
1617

17-
val instance = derived[Foo] // error
18+
val instance = derived[Foo] // error

0 commit comments

Comments
 (0)