Skip to content

Commit 0ab760b

Browse files
committed
Implement GadtExpr
1 parent f0be00d commit 0ab760b

39 files changed

+238
-44
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,8 @@ object desugar {
336336
// Propagate down the expected type to the leafs of the expression
337337
case Block(stats, expr) =>
338338
cpy.Block(tree)(stats, adaptToExpectedTpt(expr))
339+
case GadtExpr(gadt, expr) =>
340+
cpy.GadtExpr(tree)(gadt, adaptToExpectedTpt(expr))
339341
case If(cond, thenp, elsep) =>
340342
cpy.If(tree)(cond, adaptToExpectedTpt(thenp), adaptToExpectedTpt(elsep))
341343
case untpd.Parens(expr) =>
@@ -1631,6 +1633,7 @@ object desugar {
16311633
case Tuple(trees) => (pats corresponds trees)(isIrrefutable)
16321634
case Parens(rhs1) => matchesTuple(pats, rhs1)
16331635
case Block(_, rhs1) => matchesTuple(pats, rhs1)
1636+
case GadtExpr(_, rhs1) => matchesTuple(pats, rhs1)
16341637
case If(_, thenp, elsep) => matchesTuple(pats, thenp) && matchesTuple(pats, elsep)
16351638
case Match(_, cases) => cases forall (matchesTuple(pats, _))
16361639
case CaseDef(_, _, rhs1) => matchesTuple(pats, rhs1)

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ trait TreeInfo[T >: Untyped <: Type] { self: Trees.Instance[T] =>
298298
case If(_, thenp, elsep) => forallResults(thenp, p) && forallResults(elsep, p)
299299
case Match(_, cases) => cases forall (c => forallResults(c.body, p))
300300
case Block(_, expr) => forallResults(expr, p)
301+
case GadtExpr(_, expr) => forallResults(expr, p)
301302
case _ => p(tree)
302303
}
303304
}
@@ -1012,6 +1013,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
10121013
case Typed(expr, _) => unapply(expr)
10131014
case Inlined(_, Nil, expr) => unapply(expr)
10141015
case Block(Nil, expr) => unapply(expr)
1016+
case GadtExpr(_, expr) => unapply(expr)
10151017
case _ =>
10161018
tree.tpe.widenTermRefExpr.normalized match
10171019
case ConstantType(Constant(x)) => Some(x)

compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,16 @@ class TreeTypeMap(
121121
val (tmap1, stats1) = transformDefs(stats)
122122
val expr1 = tmap1.transform(expr)
123123
cpy.Block(blk)(stats1, expr1)
124+
case GadtExpr(gadt, expr) =>
125+
val tmap = withMappedSyms(gadt.symbols.diff(substFrom)) // CaseDef handles the patVars
126+
val gadt1 = EmptyGadtConstraint.fresh
127+
for sym <- gadt.symbols do
128+
val TypeBounds(lo, hi) = gadt.fullBounds(sym).nn
129+
val sym1 = mapOwner(sym)
130+
gadt1.addToConstraint(sym1)
131+
gadt1.addBound(sym1, lo, false)
132+
gadt1.addBound(sym1, hi, true)
133+
inContext(ctx.fresh.setGadt(gadt1))(cpy.GadtExpr(tree)(gadt1, tmap.transform(expr)))
124134
case inlined: Inlined =>
125135
transformInlined(inlined)
126136
case cdef @ CaseDef(pat, guard, rhs) =>

compiler/src/dotty/tools/dotc/ast/Trees.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,12 @@ object Trees {
575575
override def isTerm: Boolean = !isType // this will classify empty trees as terms, which is necessary
576576
}
577577

578+
case class GadtExpr[-T >: Untyped] private[ast] (gadt: GadtConstraint, expr: Tree[T])(implicit @constructorOnly src: SourceFile)
579+
extends ProxyTree[T] {
580+
type ThisTree[-T >: Untyped] <: GadtExpr[T]
581+
def forwardTo: Tree[T] = expr
582+
}
583+
578584
/** if cond then thenp else elsep */
579585
case class If[-T >: Untyped] private[ast] (cond: Tree[T], thenp: Tree[T], elsep: Tree[T])(implicit @constructorOnly src: SourceFile)
580586
extends TermTree[T] {
@@ -1077,6 +1083,7 @@ object Trees {
10771083
type NamedArg = Trees.NamedArg[T]
10781084
type Assign = Trees.Assign[T]
10791085
type Block = Trees.Block[T]
1086+
type GadtExpr = Trees.GadtExpr[T]
10801087
type If = Trees.If[T]
10811088
type InlineIf = Trees.InlineIf[T]
10821089
type Closure = Trees.Closure[T]
@@ -1215,6 +1222,9 @@ object Trees {
12151222
case tree: Block if (stats eq tree.stats) && (expr eq tree.expr) => tree
12161223
case _ => finalize(tree, untpd.Block(stats, expr)(sourceFile(tree)))
12171224
}
1225+
def GadtExpr(tree: Tree)(gadt: GadtConstraint, expr: Tree)(using Context): GadtExpr = tree match
1226+
case tree: GadtExpr if (gadt eq tree.gadt) && (expr eq tree.expr) => tree
1227+
case _ => finalize(tree, untpd.GadtExpr(gadt, expr)(sourceFile(tree)))
12181228
def If(tree: Tree)(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If = tree match {
12191229
case tree: If if (cond eq tree.cond) && (thenp eq tree.thenp) && (elsep eq tree.elsep) => tree
12201230
case tree: InlineIf => finalize(tree, untpd.InlineIf(cond, thenp, elsep)(sourceFile(tree)))
@@ -1430,6 +1440,8 @@ object Trees {
14301440
cpy.Assign(tree)(transform(lhs), transform(rhs))
14311441
case blk: Block =>
14321442
transformBlock(blk)
1443+
case GadtExpr(gadt, expr) =>
1444+
inContext(ctx.fresh.setGadt(gadt))(cpy.GadtExpr(tree)(gadt, transform(expr)))
14331445
case If(cond, thenp, elsep) =>
14341446
cpy.If(tree)(transform(cond), transform(thenp), transform(elsep))
14351447
case Closure(env, meth, tpt) =>
@@ -1566,6 +1578,8 @@ object Trees {
15661578
this(this(x, lhs), rhs)
15671579
case Block(stats, expr) =>
15681580
this(this(x, stats), expr)
1581+
case GadtExpr(gadt, expr) =>
1582+
inContext(ctx.fresh.setGadt(gadt))(this(x, expr))
15691583
case If(cond, thenp, elsep) =>
15701584
this(this(this(x, cond), thenp), elsep)
15711585
case Closure(env, meth, tpt) =>

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
9292
Block(stats, expr)
9393
}
9494

95+
def GadtExpr(gadt: GadtConstraint, expr: Tree)(using Context): GadtExpr =
96+
ta.assignType(untpd.GadtExpr(gadt, expr), gadt, expr)
97+
9598
def If(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If =
9699
ta.assignType(untpd.If(cond, thenp, elsep), thenp, elsep)
97100

@@ -673,6 +676,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
673676
}
674677
}
675678

679+
override def GadtExpr(tree: Tree)(gadt: GadtConstraint, expr: Tree)(using Context): GadtExpr =
680+
val tree1 = untpdCpy.GadtExpr(tree)(gadt, expr)
681+
tree match
682+
case tree: GadtExpr if expr.tpe eq tree.expr.tpe => tree1.withTypeUnchecked(tree.tpe)
683+
case _ => ta.assignType(tree1, gadt, expr)
684+
676685
override def If(tree: Tree)(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If = {
677686
val tree1 = untpdCpy.If(tree)(cond, thenp, elsep)
678687
tree match {

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
376376
def NamedArg(name: Name, arg: Tree)(implicit src: SourceFile): NamedArg = new NamedArg(name, arg)
377377
def Assign(lhs: Tree, rhs: Tree)(implicit src: SourceFile): Assign = new Assign(lhs, rhs)
378378
def Block(stats: List[Tree], expr: Tree)(implicit src: SourceFile): Block = new Block(stats, expr)
379+
def GadtExpr(gadt: GadtConstraint, expr: Tree)(implicit src: SourceFile): GadtExpr = new GadtExpr(gadt, expr)
379380
def If(cond: Tree, thenp: Tree, elsep: Tree)(implicit src: SourceFile): If = new If(cond, thenp, elsep)
380381
def InlineIf(cond: Tree, thenp: Tree, elsep: Tree)(implicit src: SourceFile): If = new InlineIf(cond, thenp, elsep)
381382
def Closure(env: List[Tree], meth: Tree, tpt: Tree)(implicit src: SourceFile): Closure = new Closure(env, meth, tpt)

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,8 +363,8 @@ trait ConstraintHandling {
363363

364364
val level1 = nestingLevel(p1)
365365
val level2 = nestingLevel(p2)
366-
val pKept = if level1 <= level2 then p1 else p2
367-
val pRemoved = if level1 <= level2 then p2 else p1
366+
val pKept = if level1 < level2 then p1 else p2
367+
val pRemoved = if level1 < level2 then p2 else p1
368368

369369
val down = constraint.exclusiveLower(p2, p1)
370370
val up = constraint.exclusiveUpper(p1, p2)

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ sealed abstract class GadtConstraint extends Showable {
3131
*
3232
* @see [[ConstraintHandling.addToConstraint]]
3333
*/
34-
def addToConstraint(syms: List[Symbol])(using Context): Boolean
34+
def addToConstraint(syms: List[Symbol], nestingLevel: Int)(using Context): Boolean
35+
def addToConstraint(syms: List[Symbol])(using Context): Boolean = addToConstraint(syms, ctx.nestingLevel)
3536
def addToConstraint(sym: Symbol)(using Context): Boolean = addToConstraint(sym :: Nil)
3637

3738
/** Further constrain a symbol already present in the constraint. */
@@ -50,13 +51,14 @@ sealed abstract class GadtConstraint extends Showable {
5051
def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type
5152

5253
def symbols: List[Symbol]
54+
def inputs: List[(List[Symbol], Int)]
5355

5456
def fresh: GadtConstraint
5557

5658
/** Restore the state from other [[GadtConstraint]], probably copied using [[fresh]] */
5759
def restore(other: GadtConstraint): Unit
5860

59-
def debugBoundsDescription(using Context): String
61+
def eql(that: GadtConstraint): Boolean
6062
}
6163

6264
final class ProperGadtConstraint private(
@@ -88,7 +90,7 @@ final class ProperGadtConstraint private(
8890
// the case where they're valid, so no approximating is needed.
8991
rawBound
9092

91-
override def addToConstraint(params: List[Symbol])(using Context): Boolean = {
93+
override def addToConstraint(params: List[Symbol], nestingLevel: Int)(using Context): Boolean = {
9294
import NameKinds.DepParamName
9395

9496
val poly1 = PolyType(params.map { sym => DepParamName.fresh(sym.name.toTypeName) })(
@@ -126,15 +128,15 @@ final class ProperGadtConstraint private(
126128
)
127129

128130
val tvars = params.lazyZip(poly1.paramRefs).map { (sym, paramRef) =>
129-
val tv = TypeVar(paramRef, creatorState = null)
131+
val tv = TypeVar(paramRef, creatorState = null, nestingLevel)
130132
mapping = mapping.updated(sym, tv)
131133
reverseMapping = reverseMapping.updated(tv.origin, sym)
132134
tv
133135
}
134136

135137
// The replaced symbols are picked up here.
136138
addToConstraint(poly1, tvars)
137-
.showing(i"added to constraint: [$poly1] $params%, %\n$debugBoundsDescription", gadts)
139+
.showing(i"added to constraint: [$poly1] $params%, % gadt = $this", gadts)
138140
}
139141

140142
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = {
@@ -221,6 +223,11 @@ final class ProperGadtConstraint private(
221223

222224
override def symbols: List[Symbol] = mapping.keys
223225

226+
override def inputs: List[(List[Symbol], Int)] =
227+
for tl <- constraint.domainLambdas yield
228+
val syms = tl.paramRefs.map(reverseMapping(_).nn)
229+
(syms, mapping(syms.head).nn.nestingLevel)
230+
224231
override def fresh: GadtConstraint = new ProperGadtConstraint(
225232
myConstraint,
226233
mapping,
@@ -291,17 +298,15 @@ final class ProperGadtConstraint private(
291298

292299
override def constr = gadtsConstr
293300

294-
override def toText(printer: Printer): Texts.Text = constraint.toText(printer)
301+
override def eql(that: GadtConstraint): Boolean = (this eq that) || that.match
302+
case that: ProperGadtConstraint =>
303+
myConstraint == that.myConstraint
304+
&& mapping == that.mapping
305+
&& reverseMapping == that.reverseMapping
306+
&& wasConstrained == that.wasConstrained
307+
case _ => false
295308

296-
override def debugBoundsDescription(using Context): String = {
297-
val sb = new mutable.StringBuilder
298-
sb ++= constraint.show
299-
sb += '\n'
300-
mapping.foreachBinding { case (sym, _) =>
301-
sb ++= i"$sym: ${fullBounds(sym)}\n"
302-
}
303-
sb.result
304-
}
309+
override def toText(printer: Printer): Texts.Text = printer.toText(this)
305310
}
306311

307312
@sharable object EmptyGadtConstraint extends GadtConstraint {
@@ -314,18 +319,19 @@ final class ProperGadtConstraint private(
314319

315320
override def contains(sym: Symbol)(using Context) = false
316321

317-
override def addToConstraint(params: List[Symbol])(using Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint")
322+
override def addToConstraint(params: List[Symbol], nestingLevel: Int)(using Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint")
318323
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(using Context): Boolean = unsupported("EmptyGadtConstraint.addBound")
319324

320325
override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = unsupported("EmptyGadtConstraint.approximation")
321326

322327
override def symbols: List[Symbol] = Nil
328+
override def inputs: List[(List[Symbol], Int)] = Nil
323329

324330
override def fresh = new ProperGadtConstraint
325331
override def restore(other: GadtConstraint): Unit =
326332
assert(!other.isNarrowing, "cannot restore a non-empty GADTMap")
327333

328-
override def debugBoundsDescription(using Context): String = "EmptyGadtConstraint"
334+
override def eql(that: GadtConstraint): Boolean = (this eq that) || that == EmptyGadtConstraint
329335

330-
override def toText(printer: Printer): Texts.Text = "EmptyGadtConstraint"
336+
override def toText(printer: Printer): Texts.Text = printer.toText(this)
331337
}

compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
261261
val assumeInvariantRefinement =
262262
migrateTo3 || forceInvariantRefinement || refinementIsInvariant(patternTp)
263263

264-
trace(i"constraining simple pattern type $tp >:< $pt", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") {
264+
trace(i"constraining simple pattern type $tp >:< $pt", gadts, res => s"$res gadt = ${ctx.gadt}") {
265265
(tp, pt) match {
266266
case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) =>
267267
val saved = state.nn.constraint

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3120,7 +3120,7 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
31203120
}
31213121

31223122
override def gadtAddBound(sym: Symbol, b: Type, isUpper: Boolean): Boolean =
3123-
traceIndented(s"add GADT constraint ${show(sym)} ${if isUpper then "<:" else ">:"} ${show(b)} $frozenNotice, GADT constraint = ${show(ctx.gadt.debugBoundsDescription)}") {
3123+
traceIndented(s"add GADT constraint ${show(sym)} ${if isUpper then "<:" else ">:"} ${show(b)} $frozenNotice, GADT constraint = ${show(ctx.gadt)}") {
31243124
super.gadtAddBound(sym, b, isUpper)
31253125
}
31263126

compiler/src/dotty/tools/dotc/core/tasty/TastyPrinter.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ class TastyPrinter(bytes: Array[Byte]) {
109109
val length = treeStr("%5d".format(index(currentAddr) - index(startAddr)))
110110
sb.append(s"\n $length:" + " " * indent)
111111
}
112+
def printInt() = sb.append(treeStr(" " + readInt()))
112113
def printNat() = sb.append(treeStr(" " + readNat()))
113114
def printName() = {
114115
val idx = readNat()
@@ -139,6 +140,8 @@ class TastyPrinter(bytes: Array[Byte]) {
139140
printTrees()
140141
case PARAMtype =>
141142
printNat(); printNat()
143+
case CONSTRAINT =>
144+
printInt(); until(end) { printNat(); printTree(); printTree() }
142145
case _ =>
143146
printTrees()
144147
}

compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,21 @@ class TreePickler(pickler: TastyPickler) {
472472
writeByte(BLOCK)
473473
stats.foreach(preRegister)
474474
withLength { pickleTree(expr); stats.foreach(pickleTree) }
475+
case GadtExpr(gadt, expr) =>
476+
writeByte(GADTEXPR)
477+
withLength {
478+
pickleTree(expr)
479+
for (symbols, nestingLevel) <- gadt.inputs do
480+
writeByte(CONSTRAINT)
481+
withLength {
482+
writeInt(nestingLevel)
483+
for sym <- symbols do
484+
val TypeBounds(lo, hi) = gadt.fullBounds(sym).nn
485+
pickleSymRef(sym)
486+
pickleType(lo)
487+
pickleType(hi)
488+
}
489+
}
475490
case tree @ If(cond, thenp, elsep) =>
476491
writeByte(IF)
477492
withLength {

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ class TreeUnpickler(reader: TastyReader,
167167
val start = currentAddr
168168
val tag = readByte()
169169
tag match {
170-
case VALDEF | DEFDEF | TYPEDEF | TYPEPARAM | PARAM | TEMPLATE =>
170+
case VALDEF | DEFDEF | TYPEDEF | TYPEPARAM | PARAM | TEMPLATE | CONSTRAINT =>
171171
val end = readEnd()
172172
for (i <- 0 until numRefs(tag)) readNat()
173173
if (tag == TEMPLATE) {
@@ -1224,6 +1224,25 @@ class TreeUnpickler(reader: TastyReader,
12241224
skipTree()
12251225
readStats(ctx.owner, end,
12261226
(stats, ctx) => Block(stats, exprReader.readTerm()(using ctx)))
1227+
case GADTEXPR =>
1228+
val gadt = ctx.gadt.fresh
1229+
val expr = readTerm()
1230+
until(end) {
1231+
readByte()
1232+
val end = readEnd()
1233+
val nestingLevel = readInt()
1234+
val constraints = until(end)((readSymRef(), readType(), readType()))
1235+
gadt.addToConstraint(constraints.map(_._1), nestingLevel)
1236+
def addBound(tp1: Type, tp2: Type) = (tp1, tp2) match
1237+
case (_, tp2: TypeRef) if gadt.contains(tp2.symbol) => gadt.addBound(tp2.symbol, tp1, isUpper = false)
1238+
case (tp1: TypeRef, _) => gadt.addBound(tp1.symbol, tp2, isUpper = true)
1239+
case x => unreachable(x)
1240+
for (sym, lo, hi) <- constraints do
1241+
val tp = sym.typeRef
1242+
addBound(tp, hi)
1243+
addBound(lo, tp)
1244+
}
1245+
GadtExpr(gadt, expr)
12271246
case INLINED =>
12281247
val exprReader = fork
12291248
skipTree()

compiler/src/dotty/tools/dotc/inlines/Inliner.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,10 @@ class Inliner(val call: tpd.Tree)(using Context):
847847
}
848848
val selType = if (sel.isEmpty) wideSelType else selTyped(sel)
849849
reduceInlineMatch(sel, selType, cases.asInstanceOf[List[CaseDef]], this) match {
850-
case Some((caseBindings, rhs0)) =>
850+
case Some((caseBindings, rhs9)) =>
851+
val rhs0 = rhs9 match
852+
case GadtExpr(_, expr) => expr
853+
case _ => rhs9
851854
// drop type ascriptions/casts hiding pattern-bound types (which are now aliases after reducing the match)
852855
// note that any actually necessary casts will be reinserted by the typing pass below
853856
val rhs1 = rhs0 match {

compiler/src/dotty/tools/dotc/inlines/Inlines.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ object Inlines:
5353
/** Should call be inlined in this context? */
5454
def needsInlining(tree: Tree)(using Context): Boolean = tree match {
5555
case Block(_, expr) => needsInlining(expr)
56+
//case GadtExpr(_, expr) => needsInlining(expr) // breaks tests/pos/gadt-infer-ascription.scala
5657
case _ =>
5758
isInlineable(tree.symbol)
5859
&& !tree.tpe.widenTermRefExpr.isInstanceOf[MethodOrPoly]
@@ -113,6 +114,8 @@ object Inlines:
113114
case Block(stats, expr) =>
114115
bindings ++= stats.map(liftPos)
115116
liftBindings(expr, liftPos)
117+
case GadtExpr(_, expr) =>
118+
liftBindings(expr, liftPos)
116119
case Inlined(call, stats, expr) =>
117120
bindings ++= stats.map(liftPos)
118121
val lifter = liftFromInlined(call)

0 commit comments

Comments
 (0)