Skip to content

Commit 804ac94

Browse files
committed
Optimize TypeMap and TypeAccumulator
1 parent 5e0a433 commit 804ac94

File tree

5 files changed

+100
-70
lines changed

5 files changed

+100
-70
lines changed

compiler/src/dotty/tools/dotc/core/Decorators.scala

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ object Decorators {
171171
def & (ys: List[T]): List[T] = xs filter (ys contains _)
172172
}
173173

174-
extension [T, U](xss: List[List[T]]):
174+
extension [T, U](xss: List[List[T]])
175175
def nestedMap(f: T => U): List[List[U]] =
176176
xss.map(_.map(f))
177177
def nestedMapConserve(f: T => U): List[List[U]] =
@@ -180,14 +180,14 @@ object Decorators {
180180
xss.zipWithConserve(yss)((xs, ys) => xs.zipWithConserve(ys)(f))
181181
end extension
182182

183-
extension (text: Text):
183+
extension (text: Text)
184184
def show(using Context): String = text.mkString(ctx.settings.pageWidth.value, ctx.settings.printLines.value)
185185

186186
/** Test whether a list of strings representing phases contains
187187
* a given phase. See [[config.CompilerCommand#explainAdvanced]] for the
188188
* exact meaning of "contains" here.
189189
*/
190-
extension (names: List[String]) {
190+
extension (names: List[String])
191191
def containsPhase(phase: Phase): Boolean =
192192
names.nonEmpty && {
193193
phase match {
@@ -203,18 +203,16 @@ object Decorators {
203203
}
204204
}
205205
}
206-
}
207206

208-
extension [T](x: T) {
207+
extension [T](x: T)
209208
def reporting(
210209
op: WrappedResult[T] ?=> String,
211210
printer: config.Printers.Printer = config.Printers.default): T = {
212211
printer.println(op(using WrappedResult(x)))
213212
x
214213
}
215-
}
216214

217-
extension [T](x: T) {
215+
extension [T](x: T)
218216
def assertingErrorsReported(using Context): T = {
219217
assert(ctx.reporter.errorsReported)
220218
x
@@ -223,9 +221,12 @@ object Decorators {
223221
assert(ctx.reporter.errorsReported, msg)
224222
x
225223
}
226-
}
227224

228-
extension (sc: StringContext) {
225+
extension [T <: AnyRef](xs: ::[T])
226+
def derivedCons(x1: T, xs1: List[T]) =
227+
if (xs.head eq x1) && (xs.tail eq xs1) then xs else x1 :: xs1
228+
229+
extension (sc: StringContext)
229230
/** General purpose string formatting */
230231
def i(args: Any*)(using Context): String =
231232
new StringFormatter(sc).assemble(args)
@@ -241,9 +242,8 @@ object Decorators {
241242
*/
242243
def ex(args: Any*)(using Context): String =
243244
explained(em(args: _*))
244-
}
245245

246-
extension [T <: AnyRef](arr: Array[T]):
246+
extension [T <: AnyRef](arr: Array[T])
247247
def binarySearch(x: T): Int = java.util.Arrays.binarySearch(arr.asInstanceOf[Array[Object]], x)
248248

249249
}

compiler/src/dotty/tools/dotc/core/Substituters.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package dotty.tools.dotc.core
22

3-
import Types._, Symbols._, Contexts._
3+
import Types._, Symbols._, Contexts._, Decorators._
44

55
/** Substitution operations on types. See the corresponding `subst` and
66
* `substThis` methods on class Type for an explanation.
@@ -16,6 +16,8 @@ object Substituters:
1616
else tp.derivedSelect(subst(tp.prefix, from, to, theMap))
1717
case _: ThisType =>
1818
tp
19+
case tp: AppliedType =>
20+
tp.map(subst(_, from, to, theMap))
1921
case _ =>
2022
(if (theMap != null) theMap else new SubstBindingMap(from, to))
2123
.mapOver(tp)
@@ -94,7 +96,7 @@ object Substituters:
9496
ts = ts.tail
9597
}
9698
tp
97-
case _: ThisType | _: BoundType =>
99+
case _: BoundType =>
98100
tp
99101
case _ =>
100102
(if (theMap != null) theMap else new SubstSymMap(from, to))
@@ -152,6 +154,8 @@ object Substituters:
152154
else tp.derivedSelect(substParams(tp.prefix, from, to, theMap))
153155
case _: ThisType =>
154156
tp
157+
case tp: AppliedType =>
158+
tp.map(substParams(_, from, to, theMap))
155159
case _ =>
156160
(if (theMap != null) theMap else new SubstParamsMap(from, to))
157161
.mapOver(tp)

compiler/src/dotty/tools/dotc/core/TypeOps.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ object TypeOps:
100100
val sym = tp.symbol
101101
if (sym.isStatic && !sym.maybeOwner.seesOpaques || (tp.prefix `eq` NoPrefix)) tp
102102
else derivedSelect(tp, atVariance(variance max 0)(this(tp.prefix)))
103+
case tp: LambdaType =>
104+
mapOverLambda(tp) // special cased common case
103105
case tp: ThisType =>
104106
toPrefix(pre, cls, tp.cls)
105107
case _: BoundType =>
@@ -136,6 +138,9 @@ object TypeOps:
136138
tp2
137139
case tp1 => tp1
138140
}
141+
case tp: AppliedType =>
142+
val normed = tp.tryNormalize
143+
if normed.exists then normed else tp.map(simplify(_, theMap))
139144
case tp: TypeParamRef =>
140145
val tvar = ctx.typerState.constraint.typeVarOfParam(tp)
141146
if (tvar.exists) tvar else tp
@@ -147,7 +152,7 @@ object TypeOps:
147152
simplify(l, theMap) & simplify(r, theMap)
148153
case OrType(l, r) if !ctx.mode.is(Mode.Type) =>
149154
simplify(l, theMap) | simplify(r, theMap)
150-
case _: AppliedType | _: MatchType =>
155+
case _: MatchType =>
151156
val normed = tp.tryNormalize
152157
if (normed.exists) normed else mapOver
153158
case tp: MethodicType =>

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 73 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ object Types {
112112
def isProvisional(using Context): Boolean = mightBeProvisional && testProvisional
113113

114114
private def testProvisional(using Context): Boolean =
115+
class ProAcc extends TypeAccumulator[Boolean]:
116+
override def apply(x: Boolean, t: Type) = x || test(t, this)
115117
def test(t: Type, theAcc: TypeAccumulator[Boolean]): Boolean =
116118
if t.mightBeProvisional then
117119
t.mightBeProvisional = t match
@@ -127,16 +129,14 @@ object Types {
127129
}
128130
case t: TermRef =>
129131
!t.currentSymbol.isStatic && test(t.prefix, theAcc)
132+
case t: AppliedType =>
133+
t.fold(false, (x, tp) => x || test(tp, theAcc))
130134
case t: TypeVar =>
131135
!t.inst.exists || test(t.inst, theAcc)
132136
case t: LazyRef =>
133137
!t.completed || test(t.ref, theAcc)
134138
case _ =>
135-
val acc =
136-
if theAcc != null then theAcc
137-
else new TypeAccumulator[Boolean]:
138-
override def apply(x: Boolean, t: Type) = x || test(t, this)
139-
acc.foldOver(false, t)
139+
(if theAcc != null then theAcc else ProAcc()).foldOver(false, t)
140140
end if
141141
t.mightBeProvisional
142142
end test
@@ -3218,17 +3218,21 @@ object Types {
32183218

32193219
def newLikeThis(paramNames: List[ThisName], paramInfos: List[PInfo], resType: Type)(using Context): This =
32203220
def substParams(pinfos: List[PInfo], to: This): List[PInfo] = pinfos match
3221-
case pinfo :: rest =>
3222-
val pinfo1 = pinfo.subst(this, to).asInstanceOf[PInfo]
3223-
val rest1 = substParams(rest, to)
3224-
if (pinfo1 eq pinfo) && (rest1 eq rest) then pinfos
3225-
else pinfo1 :: rest1
3221+
case pinfos @ (pinfo :: rest) =>
3222+
pinfos.derivedCons(pinfo.subst(this, to).asInstanceOf[PInfo], substParams(rest, to))
32263223
case nil =>
32273224
nil
32283225
companion(paramNames)(
32293226
x => substParams(paramInfos, x),
32303227
x => resType.subst(this, x))
32313228

3229+
inline def map(inline op: Type => Type)(using Context) =
3230+
def mapParams(pinfos: List[PInfo]): List[PInfo] = pinfos match
3231+
case pinfos @ (pinfo :: rest) =>
3232+
pinfos.derivedCons(op(pinfo).asInstanceOf[PInfo], mapParams(rest))
3233+
case nil => nil
3234+
derivedLambdaType(paramNames, mapParams(paramInfos), op(resType))
3235+
32323236
protected def prefixString: String
32333237
override def toString: String = s"$prefixString($paramNames, $paramInfos, $resType)"
32343238
}
@@ -3287,6 +3291,8 @@ object Types {
32873291
private var myParamDependencyStatus: DependencyStatus = Unknown
32883292

32893293
private def depStatus(initial: DependencyStatus, tp: Type)(using Context): DependencyStatus =
3294+
class DepAcc extends TypeAccumulator[DependencyStatus]:
3295+
def apply(status: DependencyStatus, tp: Type) = compute(status, tp, this)
32903296
def combine(x: DependencyStatus, y: DependencyStatus) =
32913297
val status = (x & StatusMask) max (y & StatusMask)
32923298
val provisional = (x | y) & Provisional
@@ -3305,16 +3311,13 @@ object Types {
33053311
case _ =>
33063312
status1
33073313
}
3314+
case tp: TermRef => applyPrefix(tp)
3315+
case tp: AppliedType => tp.fold(status, compute(_, _, theAcc))
33083316
case tp: TypeVar if !tp.isInstantiated => combine(status, Provisional)
33093317
case TermParamRef(`thisLambdaType`, _) => TrueDeps
3310-
case tp: TermRef => applyPrefix(tp)
33113318
case _: ThisType | _: BoundType | NoPrefix => status
33123319
case _ =>
3313-
val acc =
3314-
if theAcc != null then theAcc
3315-
else new TypeAccumulator[DependencyStatus]:
3316-
def apply(status: DependencyStatus, tp: Type) = compute(status, tp, this)
3317-
acc.foldOver(status, tp)
3320+
(if theAcc != null then theAcc else DepAcc()).foldOver(status, tp)
33183321
compute(initial, tp, null)
33193322
end depStatus
33203323

@@ -3849,6 +3852,18 @@ object Types {
38493852
superType
38503853
}
38513854

3855+
inline def map(inline op: Type => Type)(using Context) =
3856+
def mapArgs(args: List[Type]): List[Type] = args match
3857+
case args @ (arg :: rest) => args.derivedCons(op(arg), mapArgs(rest))
3858+
case nil => nil
3859+
derivedAppliedType(op(tycon), mapArgs(args))
3860+
3861+
inline def fold[T](x: T, inline op: (T, Type) => T)(using Context): T =
3862+
def foldArgs(x: T, args: List[Type]): T = args match
3863+
case arg :: rest => foldArgs(op(x, arg), rest)
3864+
case nil => x
3865+
foldArgs(op(x, tycon), args)
3866+
38523867
override def tryNormalize(using Context): Type = tycon match {
38533868
case tycon: TypeRef =>
38543869
def tryMatchAlias = tycon.info match {
@@ -4954,10 +4969,29 @@ object Types {
49544969
protected def derivedLambdaType(tp: LambdaType)(formals: List[tp.PInfo], restpe: Type): Type =
49554970
tp.derivedLambdaType(tp.paramNames, formals, restpe)
49564971

4972+
protected def mapArgs(args: List[Type], tparams: List[ParamInfo]): List[Type] = args match
4973+
case arg :: otherArgs if tparams.nonEmpty =>
4974+
val arg1 = arg match
4975+
case arg: TypeBounds => this(arg)
4976+
case arg => atVariance(variance * tparams.head.paramVarianceSign)(this(arg))
4977+
val otherArgs1 = mapArgs(otherArgs, tparams.tail)
4978+
if ((arg1 eq arg) && (otherArgs1 eq otherArgs)) args
4979+
else arg1 :: otherArgs1
4980+
case nil =>
4981+
nil
4982+
4983+
protected def mapOverLambda(tp: LambdaType) =
4984+
val restpe = tp.resultType
4985+
val saved = variance
4986+
variance = if (defn.MatchCase.isInstance(restpe)) 0 else -variance
4987+
val ptypes1 = tp.paramInfos.mapConserve(this).asInstanceOf[List[tp.PInfo]]
4988+
variance = saved
4989+
derivedLambdaType(tp)(ptypes1, this(restpe))
4990+
49574991
/** Map this function over given type */
49584992
def mapOver(tp: Type): Type = {
4959-
record(s"mapOver ${getClass}")
4960-
record("mapOver total")
4993+
record(s"TypeMap mapOver ${getClass}")
4994+
record("TypeMap mapOver total")
49614995
val ctx = this.mapCtx // optimization for performance
49624996
given Context = ctx
49634997
tp match {
@@ -4972,27 +5006,12 @@ object Types {
49725006
// if `p <: q` then `p.A <: q.A`, and well-formedness requires that `A` is a member
49735007
// of `p`'s upper bound.
49745008
derivedSelect(tp, prefix1)
4975-
case _: ThisType
4976-
| _: BoundType
4977-
| NoPrefix => tp
49785009

49795010
case tp: AppliedType =>
4980-
def mapArgs(args: List[Type], tparams: List[ParamInfo]): List[Type] = args match {
4981-
case arg :: otherArgs if tparams.nonEmpty =>
4982-
val arg1 = arg match {
4983-
case arg: TypeBounds => this(arg)
4984-
case arg => atVariance(variance * tparams.head.paramVarianceSign)(this(arg))
4985-
}
4986-
val otherArgs1 = mapArgs(otherArgs, tparams.tail)
4987-
if ((arg1 eq arg) && (otherArgs1 eq otherArgs)) args
4988-
else arg1 :: otherArgs1
4989-
case nil =>
4990-
nil
4991-
}
49925011
derivedAppliedType(tp, this(tp.tycon), mapArgs(tp.args, tp.tyconTypeParams))
49935012

4994-
case tp: RefinedType =>
4995-
derivedRefinedType(tp, this(tp.parent), this(tp.refinedInfo))
5013+
case tp: LambdaType =>
5014+
mapOverLambda(tp)
49965015

49975016
case tp: AliasingBounds =>
49985017
derivedAlias(tp, atVariance(0)(this(tp.alias)))
@@ -5003,26 +5022,32 @@ object Types {
50035022
variance = -variance
50045023
derivedTypeBounds(tp, lo1, this(tp.hi))
50055024

5006-
case tp: RecType =>
5007-
derivedRecType(tp, this(tp.parent))
5008-
50095025
case tp: TypeVar =>
50105026
val inst = tp.instanceOpt
50115027
if (inst.exists) apply(inst) else tp
50125028

50135029
case tp: ExprType =>
50145030
derivedExprType(tp, this(tp.resultType))
50155031

5016-
case tp: LambdaType =>
5017-
def mapOverLambda = {
5018-
val restpe = tp.resultType
5019-
val saved = variance
5020-
variance = if (defn.MatchCase.isInstance(restpe)) 0 else -variance
5021-
val ptypes1 = tp.paramInfos.mapConserve(this).asInstanceOf[List[tp.PInfo]]
5022-
variance = saved
5023-
derivedLambdaType(tp)(ptypes1, this(restpe))
5024-
}
5025-
mapOverLambda
5032+
case tp @ AnnotatedType(underlying, annot) =>
5033+
val underlying1 = this(underlying)
5034+
if (underlying1 eq underlying) tp
5035+
else derivedAnnotatedType(tp, underlying1, mapOver(annot))
5036+
5037+
case _: ThisType
5038+
| _: BoundType
5039+
| NoPrefix =>
5040+
tp
5041+
5042+
case tp: ProtoType =>
5043+
tp.map(this)
5044+
5045+
case tp: RefinedType =>
5046+
derivedRefinedType(tp, this(tp.parent), this(tp.refinedInfo))
5047+
5048+
case tp: RecType =>
5049+
record("TypeMap.RecType")
5050+
derivedRecType(tp, this(tp.parent))
50265051

50275052
case tp @ SuperType(thistp, supertp) =>
50285053
derivedSuperType(tp, this(thistp), this(supertp))
@@ -5056,20 +5081,12 @@ object Types {
50565081
case tp: SkolemType =>
50575082
derivedSkolemType(tp, this(tp.info))
50585083

5059-
case tp @ AnnotatedType(underlying, annot) =>
5060-
val underlying1 = this(underlying)
5061-
if (underlying1 eq underlying) tp
5062-
else derivedAnnotatedType(tp, underlying1, mapOver(annot))
5063-
50645084
case tp: WildcardType =>
50655085
derivedWildcardType(tp, mapOver(tp.optBounds))
50665086

50675087
case tp: JavaArrayType =>
50685088
derivedJavaArrayType(tp, this(tp.elemType))
50695089

5070-
case tp: ProtoType =>
5071-
tp.map(this)
5072-
50735090
case _ =>
50745091
tp
50755092
}

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,10 @@ object ProtoTypes {
617617
wildApprox(tp.refinedInfo, theMap, seen, internal))
618618
case tp: AliasingBounds => // default case, inlined for speed
619619
tp.derivedAlias(wildApprox(tp.alias, theMap, seen, internal))
620+
case tp: TypeBounds =>
621+
tp.derivedTypeBounds(
622+
wildApprox(tp.lo, theMap, seen, internal),
623+
wildApprox(tp.hi, theMap, seen, internal))
620624
case tp @ TypeParamRef(tl, _) if internal.contains(tl) => tp
621625
case tp @ TypeParamRef(poly, pnum) =>
622626
def wildApproxBounds(bounds: TypeBounds) =

0 commit comments

Comments
 (0)