Skip to content

Commit ac5d1e5

Browse files
author
EnzeXing
committed
Refactor pattern matching, skipping cases when safe to do so
1 parent 37206cc commit ac5d1e5

File tree

1 file changed

+38
-18
lines changed

1 file changed

+38
-18
lines changed

compiler/src/dotty/tools/dotc/transform/init/Objects.scala

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,12 @@ class Objects(using Context @constructorOnly):
609609
case (ValueSet(values), b : ValueElement) => ValueSet(values + b)
610610
case (a : ValueElement, b : ValueElement) => ValueSet(ListSet(a, b))
611611

612+
def remove(b: Value): Value = (a, b) match
613+
case (ValueSet(values1), b: ValueElement) => ValueSet(values1 - b)
614+
case (ValueSet(values1), ValueSet(values2)) => ValueSet(values1.removedAll(values2))
615+
case (a: Ref, b: Ref) if a.equals(b) => Bottom
616+
case _ => a
617+
612618
def widen(height: Int)(using Context): Value =
613619
if height == 0 then Cold
614620
else
@@ -1348,29 +1354,25 @@ class Objects(using Context @constructorOnly):
13481354
def getMemberMethod(receiver: Type, name: TermName, tp: Type): Denotation =
13491355
receiver.member(name).suchThat(receiver.memberInfo(_) <:< tp)
13501356

1351-
def evalCase(caseDef: CaseDef): Value =
1352-
evalPattern(scrutinee, caseDef.pat)
1353-
eval(caseDef.guard, thisV, klass)
1354-
eval(caseDef.body, thisV, klass)
1355-
13561357
/** Abstract evaluation of patterns.
13571358
*
13581359
* It augments the local environment for bound pattern variables. As symbols are globally
13591360
* unique, we can put them in a single environment.
13601361
*
13611362
* Currently, we assume all cases are reachable, thus all patterns are assumed to match.
13621363
*/
1363-
def evalPattern(scrutinee: Value, pat: Tree): Value = log("match " + scrutinee.show + " against " + pat.show, printer, (_: Value).show):
1364+
def evalPattern(scrutinee: Value, pat: Tree): (Type, Value) = log("match " + scrutinee.show + " against " + pat.show, printer, (_: (Type, Value))._2.show):
13641365
val trace2 = Trace.trace.add(pat)
13651366
pat match
13661367
case Alternative(pats) =>
1367-
for pat <- pats do evalPattern(scrutinee, pat)
1368-
scrutinee
1368+
val (types, values) = pats.map(evalPattern(scrutinee, _)).unzip()
1369+
val orType = types.fold(defn.NothingType)(OrType(_, _, false))
1370+
(orType, values.join)
13691371

13701372
case bind @ Bind(_, pat) =>
1371-
val value = evalPattern(scrutinee, pat)
1373+
val (tpe, value) = evalPattern(scrutinee, pat)
13721374
initLocal(bind.symbol, value)
1373-
scrutinee
1375+
(tpe, value)
13741376

13751377
case UnApply(fun, implicits, pats) =>
13761378
given Trace = trace2
@@ -1379,6 +1381,10 @@ class Objects(using Context @constructorOnly):
13791381
val funRef = fun1.tpe.asInstanceOf[TermRef]
13801382
val unapplyResTp = funRef.widen.finalResultType
13811383

1384+
val receiverType = fun1 match
1385+
case ident: Ident => funRef.prefix
1386+
case select: Select => select.qualifier.tpe
1387+
13821388
val receiver = fun1 match
13831389
case ident: Ident =>
13841390
evalType(funRef.prefix, thisV, klass)
@@ -1467,17 +1473,18 @@ class Objects(using Context @constructorOnly):
14671473
end if
14681474
end if
14691475
end if
1470-
scrutinee
1476+
(receiverType, scrutinee.filterType(receiverType))
14711477

14721478
case Ident(nme.WILDCARD) | Ident(nme.WILDCARD_STAR) =>
1473-
scrutinee
1479+
(defn.ThrowableType, scrutinee)
14741480

1475-
case Typed(pat, _) =>
1476-
evalPattern(scrutinee, pat)
1481+
case Typed(pat, typeTree) =>
1482+
val (_, value) = evalPattern(scrutinee.filterType(typeTree.tpe), pat)
1483+
(typeTree.tpe, value)
14771484

14781485
case tree =>
14791486
// For all other trees, the semantics is normal.
1480-
eval(tree, thisV, klass)
1487+
(defn.ThrowableType, eval(tree, thisV, klass))
14811488

14821489
end evalPattern
14831490

@@ -1501,12 +1508,12 @@ class Objects(using Context @constructorOnly):
15011508
if isWildcardStarArgList(pats) then
15021509
if pats.size == 1 then
15031510
// call .toSeq
1504-
val toSeqDenot = getMemberMethod(scrutineeType, nme.toSeq, toSeqType(elemType))
1511+
val toSeqDenot = scrutineeType.member(nme.toSeq).suchThat(_.info.isParameterless)
15051512
val toSeqRes = call(scrutinee, toSeqDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
15061513
evalPattern(toSeqRes, pats.head)
15071514
else
15081515
// call .drop
1509-
val dropDenot = getMemberMethod(scrutineeType, nme.drop, dropType(elemType))
1516+
val dropDenot = getMemberMethod(scrutineeType, nme.drop, applyType(elemType))
15101517
val dropRes = call(scrutinee, dropDenot.symbol, ArgInfo(Bottom, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
15111518
for pat <- pats.init do evalPattern(applyRes, pat)
15121519
evalPattern(dropRes, pats.last)
@@ -1517,8 +1524,21 @@ class Objects(using Context @constructorOnly):
15171524
end if
15181525
end evalSeqPatterns
15191526

1527+
def canSkipCase(remainingScrutinee: Value, catchValue: Value) =
1528+
(remainingScrutinee == Bottom && scrutinee != Bottom) ||
1529+
(catchValue == Bottom && remainingScrutinee != Bottom)
15201530

1521-
cases.map(evalCase).join
1531+
var remainingScrutinee = scrutinee
1532+
val caseResults: mutable.ArrayBuffer[Value] = mutable.ArrayBuffer()
1533+
for caseDef <- cases do
1534+
val (tpe, value) = evalPattern(remainingScrutinee, caseDef.pat)
1535+
eval(caseDef.guard, thisV, klass)
1536+
if !canSkipCase(remainingScrutinee, value) then
1537+
caseResults.addOne(eval(caseDef.body, thisV, klass))
1538+
if catchesAllOf(caseDef, tpe) then
1539+
remainingScrutinee = remainingScrutinee.remove(value)
1540+
1541+
caseResults.join
15221542
end patternMatch
15231543

15241544
/** Handle semantics of leaf nodes

0 commit comments

Comments
 (0)