Skip to content

Commit 7cadc36

Browse files
committed
Implement AssumeInfo & AssumeInfoMap
1 parent 5c2efc5 commit 7cadc36

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+503
-201
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,8 @@ object desugar {
346346
// Propagate down the expected type to the leafs of the expression
347347
case Block(stats, expr) =>
348348
cpy.Block(tree)(stats, adaptToExpectedTpt(expr))
349+
case AssumeInfo(sym, info, body) =>
350+
cpy.AssumeInfo(tree)(sym, info, adaptToExpectedTpt(body))
349351
case If(cond, thenp, elsep) =>
350352
cpy.If(tree)(cond, adaptToExpectedTpt(thenp), adaptToExpectedTpt(elsep))
351353
case untpd.Parens(expr) =>
@@ -1645,6 +1647,7 @@ object desugar {
16451647
case Tuple(trees) => (pats corresponds trees)(isIrrefutable)
16461648
case Parens(rhs1) => matchesTuple(pats, rhs1)
16471649
case Block(_, rhs1) => matchesTuple(pats, rhs1)
1650+
case AssumeInfo(_, _, rhs1) => matchesTuple(pats, rhs1)
16481651
case If(_, thenp, elsep) => matchesTuple(pats, thenp) && matchesTuple(pats, elsep)
16491652
case Match(_, cases) => cases forall (matchesTuple(pats, _))
16501653
case CaseDef(_, _, rhs1) => matchesTuple(pats, rhs1)

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
330330
case If(_, thenp, elsep) => forallResults(thenp, p) && forallResults(elsep, p)
331331
case Match(_, cases) => cases forall (c => forallResults(c.body, p))
332332
case Block(_, expr) => forallResults(expr, p)
333+
case AssumeInfo(_, _, body) => forallResults(body, p)
333334
case _ => p(tree)
334335
}
335336

@@ -1088,6 +1089,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
10881089
case Typed(expr, _) => unapply(expr)
10891090
case Inlined(_, Nil, expr) => unapply(expr)
10901091
case Block(Nil, expr) => unapply(expr)
1092+
case AssumeInfo(_, _, body) => unapply(body)
10911093
case _ =>
10921094
tree.tpe.widenTermRefExpr.dealias.normalized match
10931095
case ConstantType(Constant(x)) => Some(x)

compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,20 @@ class TreeTypeMap(
135135
cpy.LambdaTypeTree(tdef)(tparams1, tmap1.transform(body))
136136
case inlined: Inlined =>
137137
transformInlined(inlined)
138+
case tree: AssumeInfo =>
139+
def mapBody(body: Tree) = body match
140+
case tree @ AssumeInfo(_, _, _) =>
141+
val tree1 = treeMap(tree)
142+
tree1.withType(mapType(tree1.tpe))
143+
case _ => body
144+
tree.fold(transform, mapBody) { case (assumeInfo @ AssumeInfo(sym, info, _), body) =>
145+
mapType(sym.typeRef) match
146+
case tp: TypeRef if tp eq sym.typeRef =>
147+
val sym1 = sym.subst(substFrom, substTo)
148+
val info1 = mapType(info)
149+
cpy.AssumeInfo(assumeInfo)(sym = sym1, info = info1, body = body)
150+
case _ => body // if the AssumeInfo symbol maps (as a type) to another type, we lose the associated info
151+
}
138152
case cdef @ CaseDef(pat, guard, rhs) =>
139153
val tmap = withMappedSyms(patVars(pat))
140154
val pat1 = tmap.transform(pat)

compiler/src/dotty/tools/dotc/ast/Trees.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,20 @@ object Trees {
567567
override def isTerm: Boolean = !isType // this will classify empty trees as terms, which is necessary
568568
}
569569

570+
case class AssumeInfo[+T <: Untyped] private[ast] (sym: Symbol, info: Type, body: Tree[T])(implicit @constructorOnly src: SourceFile)
571+
extends ProxyTree[T] {
572+
type ThisTree[+T <: Untyped] <: AssumeInfo[T]
573+
def forwardTo: Tree[T] = body
574+
575+
def fold[U >: T <: Untyped, A](
576+
start: Context ?=> Tree[U] => A, mapBody: Tree[U] => Tree[U] = (body: Tree[U]) => body,
577+
)(combine: Context ?=> (AssumeInfo[U], A) => A)(using Context): A =
578+
val body1 = mapBody(body)
579+
inContext(ctx.withAssumeInfo(ctx.assumeInfo.add(sym, info))) {
580+
combine(this, start(body1))
581+
}
582+
}
583+
570584
/** if cond then thenp else elsep */
571585
case class If[+T <: Untyped] private[ast] (cond: Tree[T], thenp: Tree[T], elsep: Tree[T])(implicit @constructorOnly src: SourceFile)
572586
extends TermTree[T] {
@@ -1074,6 +1088,7 @@ object Trees {
10741088
type NamedArg = Trees.NamedArg[T]
10751089
type Assign = Trees.Assign[T]
10761090
type Block = Trees.Block[T]
1091+
type AssumeInfo = Trees.AssumeInfo[T]
10771092
type If = Trees.If[T]
10781093
type InlineIf = Trees.InlineIf[T]
10791094
type Closure = Trees.Closure[T]
@@ -1212,6 +1227,9 @@ object Trees {
12121227
case tree: Block if (stats eq tree.stats) && (expr eq tree.expr) => tree
12131228
case _ => finalize(tree, untpd.Block(stats, expr)(sourceFile(tree)))
12141229
}
1230+
def AssumeInfo(tree: Tree)(sym: Symbol, info: Type, body: Tree)(using Context): AssumeInfo = tree match
1231+
case tree: AssumeInfo if (sym eq tree.sym) && (info eq tree.info) && (body eq tree.body) => tree
1232+
case _ => finalize(tree, untpd.AssumeInfo(sym, info, body)(sourceFile(tree)))
12151233
def If(tree: Tree)(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If = tree match {
12161234
case tree: If if (cond eq tree.cond) && (thenp eq tree.thenp) && (elsep eq tree.elsep) => tree
12171235
case tree: InlineIf => finalize(tree, untpd.InlineIf(cond, thenp, elsep)(sourceFile(tree)))
@@ -1344,6 +1362,8 @@ object Trees {
13441362

13451363
// Copier methods with default arguments; these demand that the original tree
13461364
// is of the same class as the copy. We only include trees with more than 2 elements here.
1365+
def AssumeInfo(tree: AssumeInfo)(sym: Symbol = tree.sym, info: Type = tree.info, body: Tree = tree.body)(using Context): AssumeInfo =
1366+
AssumeInfo(tree: Tree)(sym, info, body)
13471367
def If(tree: If)(cond: Tree = tree.cond, thenp: Tree = tree.thenp, elsep: Tree = tree.elsep)(using Context): If =
13481368
If(tree: Tree)(cond, thenp, elsep)
13491369
def Closure(tree: Closure)(env: List[Tree] = tree.env, meth: Tree = tree.meth, tpt: Tree = tree.tpt)(using Context): Closure =
@@ -1433,6 +1453,10 @@ object Trees {
14331453
cpy.Closure(tree)(transform(env), transform(meth), transform(tpt))
14341454
case Match(selector, cases) =>
14351455
cpy.Match(tree)(transform(selector), transformSub(cases))
1456+
case tree @ AssumeInfo(sym, info, body) =>
1457+
tree.fold(transform) { (assumeInfo, body) =>
1458+
cpy.AssumeInfo(assumeInfo)(body = body)
1459+
}
14361460
case CaseDef(pat, guard, body) =>
14371461
cpy.CaseDef(tree)(transform(pat), transform(guard), transform(body))
14381462
case Labeled(bind, expr) =>
@@ -1569,6 +1593,8 @@ object Trees {
15691593
this(this(this(x, env), meth), tpt)
15701594
case Match(selector, cases) =>
15711595
this(this(x, selector), cases)
1596+
case tree @ AssumeInfo(sym, info, body) =>
1597+
tree.fold(this(x, _))((_, x) => x)
15721598
case CaseDef(pat, guard, body) =>
15731599
this(this(this(x, pat), guard), body)
15741600
case Labeled(bind, expr) =>

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
9898
Block(stats, expr)
9999
}
100100

101+
def AssumeInfo(sym: Symbol, info: Type, body: Tree)(using Context): AssumeInfo =
102+
ta.assignType(untpd.AssumeInfo(sym, info, body), body)
103+
101104
def If(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If =
102105
ta.assignType(untpd.If(cond, thenp, elsep), thenp, elsep)
103106

@@ -683,6 +686,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
683686
}
684687
}
685688

689+
override def AssumeInfo(tree: Tree)(sym: Symbol, info: Type, body: Tree)(using Context): AssumeInfo =
690+
val tree1 = untpdCpy.AssumeInfo(tree)(sym, info, body)
691+
tree match
692+
case tree: AssumeInfo if body.tpe eq tree.body.tpe => tree1.withTypeUnchecked(tree.tpe)
693+
case _ => ta.assignType(tree1, body)
694+
686695
override def If(tree: Tree)(cond: Tree, thenp: Tree, elsep: Tree)(using Context): If = {
687696
val tree1 = untpdCpy.If(tree)(cond, thenp, elsep)
688697
tree match {
@@ -767,6 +776,8 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
767776
}
768777
}
769778

779+
override def AssumeInfo(tree: AssumeInfo)(sym: Symbol = tree.sym, info: Type = tree.info, body: Tree = tree.body)(using Context): AssumeInfo =
780+
AssumeInfo(tree: Tree)(sym, info, body)
770781
override def If(tree: If)(cond: Tree = tree.cond, thenp: Tree = tree.thenp, elsep: Tree = tree.elsep)(using Context): If =
771782
If(tree: Tree)(cond, thenp, elsep)
772783
override def Closure(tree: Closure)(env: List[Tree] = tree.env, meth: Tree = tree.meth, tpt: Tree = tree.tpt)(using Context): Closure =

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
388388
def NamedArg(name: Name, arg: Tree)(implicit src: SourceFile): NamedArg = new NamedArg(name, arg)
389389
def Assign(lhs: Tree, rhs: Tree)(implicit src: SourceFile): Assign = new Assign(lhs, rhs)
390390
def Block(stats: List[Tree], expr: Tree)(implicit src: SourceFile): Block = new Block(stats, expr)
391+
def AssumeInfo(sym: Symbol, info: Type, body: Tree)(implicit src: SourceFile): AssumeInfo = new AssumeInfo(sym, info, body)
391392
def If(cond: Tree, thenp: Tree, elsep: Tree)(implicit src: SourceFile): If = new If(cond, thenp, elsep)
392393
def InlineIf(cond: Tree, thenp: Tree, elsep: Tree)(implicit src: SourceFile): If = new InlineIf(cond, thenp, elsep)
393394
def Closure(env: List[Tree], meth: Tree, tpt: Tree)(implicit src: SourceFile): Closure = new Closure(env, meth, tpt)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package dotty.tools
2+
package dotc
3+
package core
4+
5+
import Contexts.*, Decorators.*, NameKinds.*, Symbols.*, Types.*
6+
import ast.*, Trees.*
7+
import printing.*, Texts.*
8+
9+
import scala.annotation.internal.sharable
10+
import util.{SimpleIdentitySet, SimpleIdentityMap}
11+
12+
object AssumeInfoMap:
13+
@sharable val empty: AssumeInfoMap = AssumeInfoMap(SimpleIdentityMap.empty)
14+
15+
class AssumeInfoMap private (
16+
private val map: SimpleIdentityMap[Symbol, Type],
17+
) extends Showable:
18+
def info(sym: Symbol)(using Context): Type | Null = map(sym)
19+
20+
def add(sym: Symbol, info: Type) = new AssumeInfoMap(map.updated(sym, info))
21+
22+
override def toText(p: Printer): Text =
23+
given Context = p match
24+
case p: PlainPrinter => p.printerContext
25+
case _ => Contexts.NoContext
26+
val deps = for (sym, info) <- map.toList yield
27+
(p.toText(sym.typeRef) ~ p.toText(info)).close
28+
("AssumeInfo(" ~ Text(deps, ", ") ~ ")").close

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ object Contexts {
143143
def typerState: TyperState
144144
def gadt: GadtConstraint = gadtState.gadt
145145
def gadtState: GadtState
146+
def assumeInfo: AssumeInfoMap
146147
def searchHistory: SearchHistory
147148
def source: SourceFile
148149

@@ -470,6 +471,15 @@ object Contexts {
470471
case None => fresh.dropProperty(key)
471472
}
472473

474+
final def withGadt(gadt: GadtConstraint): Context =
475+
if this.gadt eq gadt then this else fresh.setGadtState(GadtState(gadt))
476+
477+
final def withGadtState(gadt: GadtState): Context =
478+
if this.gadtState eq gadt then this else fresh.setGadtState(gadt)
479+
480+
final def withAssumeInfo(assumeInfo: AssumeInfoMap): Context =
481+
if this.assumeInfo eq assumeInfo then this else fresh.setAssumeInfo(assumeInfo)
482+
473483
def typer: Typer = this.typeAssigner match {
474484
case typer: Typer => typer
475485
case _ => new Typer
@@ -545,6 +555,9 @@ object Contexts {
545555
private var _gadtState: GadtState = uninitialized
546556
final def gadtState: GadtState = _gadtState
547557

558+
private var _assumeInfo: AssumeInfoMap = uninitialized
559+
final def assumeInfo: AssumeInfoMap = _assumeInfo
560+
548561
private var _searchHistory: SearchHistory = uninitialized
549562
final def searchHistory: SearchHistory = _searchHistory
550563

@@ -569,6 +582,7 @@ object Contexts {
569582
_tree = origin.tree
570583
_scope = origin.scope
571584
_gadtState = origin.gadtState
585+
_assumeInfo = origin.assumeInfo
572586
_searchHistory = origin.searchHistory
573587
_source = origin.source
574588
_moreProperties = origin.moreProperties
@@ -632,6 +646,10 @@ object Contexts {
632646
def setFreshGADTBounds: this.type =
633647
setGadtState(gadtState.fresh)
634648

649+
def setAssumeInfo(assumeInfo: AssumeInfoMap): this.type =
650+
this._assumeInfo= assumeInfo
651+
this
652+
635653
def setSearchHistory(searchHistory: SearchHistory): this.type =
636654
util.Stats.record("Context.setSearchHistory")
637655
this._searchHistory = searchHistory
@@ -723,6 +741,7 @@ object Contexts {
723741
.updated(compilationUnitLoc, NoCompilationUnit)
724742
c._searchHistory = new SearchRoot
725743
c._gadtState = GadtState(GadtConstraint.empty)
744+
c._assumeInfo = AssumeInfoMap.empty
726745
c
727746
end FreshContext
728747

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package core
44

55
import Contexts.*, Decorators.*, Symbols.*, Types.*
66
import NameKinds.UniqueName
7+
import ast.*, Trees.*
78
import config.Printers.{gadts, gadtsConstr}
89
import util.{SimpleIdentitySet, SimpleIdentityMap}
910
import printing._
@@ -27,6 +28,7 @@ class GadtConstraint private (
2728
def symbols: List[Symbol] = mapping.keys
2829
def withConstraint(c: Constraint) = copy(myConstraint = c)
2930
def withWasConstrained = copy(wasConstrained = true)
31+
def isEmpty: Boolean = mapping.isEmpty
3032

3133
def add(sym: Symbol, tv: TypeVar): GadtConstraint = copy(
3234
mapping = mapping.updated(sym, tv),
@@ -136,6 +138,13 @@ class GadtConstraint private (
136138

137139
override def toText(printer: Printer): Texts.Text = printer.toText(this)
138140

141+
def eql(that: GadtConstraint): Boolean = (this eq that) || {
142+
myConstraint == that.myConstraint
143+
&& mapping == that.mapping
144+
&& reverseMapping == that.reverseMapping
145+
&& wasConstrained == that.wasConstrained
146+
}
147+
139148
/** Provides more information than toText, by showing the underlying Constraint details. */
140149
def debugBoundsDescription(using Context): String = i"$this\n$constraint"
141150

@@ -166,7 +175,8 @@ sealed trait GadtState {
166175
* @see [[ConstraintHandling.addToConstraint]]
167176
*/
168177
def addToConstraint(sym: Symbol)(using Context): Boolean = addToConstraint(sym :: Nil)
169-
def addToConstraint(params: List[Symbol])(using Context): Boolean = {
178+
def addToConstraint(syms: List[Symbol])(using Context): Boolean = addToConstraint(syms, ctx.nestingLevel)
179+
def addToConstraint(params: List[Symbol], nestingLevel: Int)(using Context): Boolean = {
170180
import NameKinds.DepParamName
171181

172182
val poly1 = PolyType(params.map { sym => DepParamName.fresh(sym.name.toTypeName) })(
@@ -201,7 +211,7 @@ sealed trait GadtState {
201211
)
202212

203213
val tvars = params.lazyZip(poly1.paramRefs).map { (sym, paramRef) =>
204-
val tv = TypeVar(paramRef, creatorState = null)
214+
val tv = TypeVar(paramRef, creatorState = null, nestingLevel)
205215
gadt = gadt.add(sym, tv)
206216
tv
207217
}
@@ -277,6 +287,8 @@ sealed trait GadtState {
277287
override def fullLowerBound(param: TypeParamRef)(using Context): Type = gadt.fullLowerBound(param)
278288
override def fullUpperBound(param: TypeParamRef)(using Context): Type = gadt.fullUpperBound(param)
279289

290+
def symbols: List[Symbol] = gadt.symbols
291+
280292
// ---- Debug ------------------------------------------------------------
281293

282294
override def constr = gadtsConstr

0 commit comments

Comments
 (0)