Skip to content

Commit 5bdee62

Browse files
committed
Cache results of attempts to reduce match types
1 parent d17c292 commit 5bdee62

File tree

4 files changed

+74
-21
lines changed

4 files changed

+74
-21
lines changed

compiler/src/dotty/tools/dotc/config/Config.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ object Config {
66
final val cacheAsSeenFrom = true
77
final val cacheMemberNames = true
88
final val cacheImplicitScopes = true
9+
final val cacheMatchReduced = true
910

1011
final val checkCacheMembersNamed = false
1112

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ trait ConstraintHandling {
3838
/** Potentially a type lambda that is still instantiatable, even though the constraint
3939
* is generally frozen.
4040
*/
41-
protected var unfrozen: Type = NoType
41+
protected var caseLambda: Type = NoType
4242

4343
/** If set, align arguments `S1`, `S2`when taking the glb
4444
* `T1 { X = S1 } & T2 { X = S2 }` of a constraint upper bound for some type parameter.
@@ -52,7 +52,7 @@ trait ConstraintHandling {
5252
*/
5353
protected var comparedTypeLambdas: Set[TypeLambda] = Set.empty
5454

55-
private def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean): Boolean =
55+
protected def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean): Boolean =
5656
!constraint.contains(param) || {
5757
def occursIn(bound: Type): Boolean = {
5858
val b = bound.dealias
@@ -174,13 +174,13 @@ trait ConstraintHandling {
174174

175175
@forceInline final def inFrozenConstraint[T](op: => T): T = {
176176
val savedFrozen = frozenConstraint
177-
val savedUnfrozen = unfrozen
177+
val savedLambda = caseLambda
178178
frozenConstraint = true
179-
unfrozen = NoType
179+
caseLambda = NoType
180180
try op
181181
finally {
182182
frozenConstraint = savedFrozen
183-
unfrozen = savedUnfrozen
183+
caseLambda = savedLambda
184184
}
185185
}
186186

@@ -325,7 +325,7 @@ trait ConstraintHandling {
325325
}
326326

327327
/** The current bounds of type parameter `param` */
328-
final def bounds(param: TypeParamRef): TypeBounds = {
328+
def bounds(param: TypeParamRef): TypeBounds = {
329329
val e = constraint.entry(param)
330330
if (e.exists) e.bounds
331331
else {
@@ -361,7 +361,7 @@ trait ConstraintHandling {
361361

362362
/** Can `param` be constrained with new bounds? */
363363
final def canConstrain(param: TypeParamRef): Boolean =
364-
(!frozenConstraint || (unfrozen `eq` param.binder)) && constraint.contains(param)
364+
(!frozenConstraint || (caseLambda `eq` param.binder)) && constraint.contains(param)
365365

366366
/** Add constraint `param <: bound` if `fromBelow` is false, `param >: bound` otherwise.
367367
* `bound` is assumed to be in normalized form, as specified in `firstTry` and

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
103103
true
104104
}
105105

106+
protected def gadtBounds(sym: Symbol)(implicit ctx: Context) = ctx.gadt.bounds(sym)
107+
protected def gadtSetBounds(sym: Symbol, b: TypeBounds) = ctx.gadt.setBounds(sym, b)
108+
106109
// Subtype testing `<:<`
107110

108111
def topLevelSubType(tp1: Type, tp2: Type): Boolean = {
@@ -375,7 +378,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
375378
def thirdTryNamed(tp2: NamedType): Boolean = tp2.info match {
376379
case info2: TypeBounds =>
377380
def compareGADT: Boolean = {
378-
val gbounds2 = ctx.gadt.bounds(tp2.symbol)
381+
val gbounds2 = gadtBounds(tp2.symbol)
379382
(gbounds2 != null) &&
380383
(isSubTypeWhenFrozen(tp1, gbounds2.lo) ||
381384
narrowGADTBounds(tp2, tp1, approx, isUpper = false)) &&
@@ -601,7 +604,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
601604
tp1.info match {
602605
case TypeBounds(_, hi1) =>
603606
def compareGADT = {
604-
val gbounds1 = ctx.gadt.bounds(tp1.symbol)
607+
val gbounds1 = gadtBounds(tp1.symbol)
605608
(gbounds1 != null) &&
606609
(isSubTypeWhenFrozen(gbounds1.hi, tp2) ||
607610
narrowGADTBounds(tp1, tp2, approx, isUpper = true)) &&
@@ -1146,12 +1149,12 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
11461149
gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}")
11471150
if (bound.isRef(tparam)) false
11481151
else {
1149-
val oldBounds = ctx.gadt.bounds(tparam)
1152+
val oldBounds = gadtBounds(tparam)
11501153
val newBounds =
11511154
if (isUpper) TypeBounds(oldBounds.lo, oldBounds.hi & bound)
11521155
else TypeBounds(oldBounds.lo | bound, oldBounds.hi)
11531156
isSubType(newBounds.lo, newBounds.hi) &&
1154-
{ ctx.gadt.setBounds(tparam, newBounds); true }
1157+
{ gadtSetBounds(tparam, newBounds); true }
11551158
}
11561159
}
11571160
}
@@ -1737,11 +1740,33 @@ object TypeComparer {
17371740
class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
17381741
import state.constraint
17391742

1743+
val footprint = mutable.Set[Type]()
1744+
1745+
override def bounds(param: TypeParamRef): TypeBounds = {
1746+
if (param.binder `ne` caseLambda) footprint += param
1747+
super.bounds(param)
1748+
}
1749+
1750+
override def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean): Boolean = {
1751+
if (param.binder `ne` caseLambda) footprint += param
1752+
super.addOneBound(param, bound, isUpper)
1753+
}
1754+
1755+
override def gadtBounds(sym: Symbol)(implicit ctx: Context) = {
1756+
footprint += sym.typeRef
1757+
super.gadtBounds(sym)
1758+
}
1759+
1760+
override def gadtSetBounds(sym: Symbol, b: TypeBounds) = {
1761+
footprint += sym.typeRef
1762+
super.gadtSetBounds(sym, b)
1763+
}
1764+
17401765
def matchCase(scrut: Type, cas: Type, instantiate: Boolean)(implicit ctx: Context): Type = {
17411766

17421767
def paramInstances = new TypeAccumulator[Array[Type]] {
17431768
def apply(inst: Array[Type], t: Type) = t match {
1744-
case t @ TypeParamRef(b, n) if b `eq` unfrozen =>
1769+
case t @ TypeParamRef(b, n) if b `eq` caseLambda =>
17451770
inst(n) = instanceType(t, fromBelow = variance >= 0)
17461771
inst
17471772
case _ =>
@@ -1751,7 +1776,7 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
17511776

17521777
def instantiateParams(inst: Array[Type]) = new TypeMap {
17531778
def apply(t: Type) = t match {
1754-
case t @ TypeParamRef(b, n) if b `eq` unfrozen => inst(n)
1779+
case t @ TypeParamRef(b, n) if b `eq` caseLambda => inst(n)
17551780
case t: LazyRef => apply(t.ref)
17561781
case _ => mapOver(t)
17571782
}
@@ -1762,16 +1787,16 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
17621787
inFrozenConstraint {
17631788
val cas1 = cas match {
17641789
case cas: HKTypeLambda =>
1765-
unfrozen = constrained(cas)
1766-
unfrozen.resultType
1790+
caseLambda = constrained(cas)
1791+
caseLambda.resultType
17671792
case _ =>
17681793
cas
17691794
}
17701795
val defn.FunctionOf(pat :: Nil, body, _, _) = cas1
17711796
if (isSubType(scrut, pat))
1772-
unfrozen match {
1773-
case unfrozen: HKTypeLambda if instantiate =>
1774-
val instances = paramInstances(new Array(unfrozen.paramNames.length), pat)
1797+
caseLambda match {
1798+
case caseLambda: HKTypeLambda if instantiate =>
1799+
val instances = paramInstances(new Array(caseLambda.paramNames.length), pat)
17751800
instantiateParams(instances)(body)
17761801
case _ =>
17771802
body

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3545,7 +3545,7 @@ object Types {
35453545

35463546
def alternatives(implicit ctx: Context): List[Type] = cases.map(caseType)
35473547

3548-
private var myUnderlying: Type = null
3548+
private[this] var myUnderlying: Type = null
35493549

35503550
def underlying(implicit ctx: Context): Type = {
35513551
if (myUnderlying == null) myUnderlying = alternatives.reduceLeft(OrType(_, _))
@@ -3564,16 +3564,20 @@ object Types {
35643564
}
35653565
}
35663566

3567-
private var myApproxScrut: Type = null
3567+
private[this] var myApproxScrut: Type = null
35683568

35693569
def approximatedScrutinee(implicit ctx: Context): Type = {
35703570
if (myApproxScrut == null) myApproxScrut = wildApproxMap.apply(scrutinee)
35713571
myApproxScrut
35723572
}
35733573

3574+
private[this] var myReduced: Type = null
3575+
private[this] var reductionContext: mutable.Map[Type, TypeBounds] = null
3576+
35743577
def reduced(implicit ctx: Context): Type = {
35753578
val trackingCtx = ctx.fresh.setTypeComparerFn(new TrackingTypeComparer(_))
35763579
val cmp = trackingCtx.typeComparer.asInstanceOf[TrackingTypeComparer]
3580+
35773581
def recur(cases: List[Type])(implicit ctx: Context): Type = cases match {
35783582
case Nil => NoType
35793583
case cas :: cases1 =>
@@ -3582,7 +3586,30 @@ object Types {
35823586
else if (cmp.matchCase(approximatedScrutinee, cas, instantiate = false).exists) NoType
35833587
else recur(cases1)
35843588
}
3585-
recur(cases)(trackingCtx)
3589+
3590+
def contextBounds(tp: Type): TypeBounds = tp match {
3591+
case tp: TypeParamRef => ctx.typerState.constraint.fullBounds(tp)
3592+
case tp: TypeRef => ctx.gadt.bounds(tp.symbol)
3593+
}
3594+
3595+
def updateReductionContext() = {
3596+
reductionContext = new mutable.HashMap
3597+
for (tp <- cmp.footprint) reductionContext(tp) = contextBounds(tp)
3598+
}
3599+
3600+
def upToDate =
3601+
cmp.footprint.forall { tp =>
3602+
reductionContext.get(tp) match {
3603+
case Some(bounds) => bounds `eq` contextBounds(tp)
3604+
case None => false
3605+
}
3606+
}
3607+
3608+
if (!Config.cacheMatchReduced || myReduced == null || !upToDate) {
3609+
myReduced = recur(cases)(trackingCtx)
3610+
updateReductionContext()
3611+
}
3612+
myReduced
35863613
}
35873614
}
35883615

0 commit comments

Comments
 (0)