Skip to content

Commit bf30f20

Browse files
committed
Enable one encoding of recursive gadt to work with inline match
If the recursive part is fixed to a subtype of the union of the cases of the enum, enable inline match to reduce cases. Notes: this encoding could be supported by a compiletime.Refract[S] type to split cases of a sum type.
1 parent b7be482 commit bf30f20

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,22 @@ object Inliner {
337337
def codeOf(arg: Tree, pos: SrcPos)(using Context): Tree =
338338
Literal(Constant(arg.show)).withSpan(pos.span)
339339
}
340+
341+
extension (tp: Type) {
342+
/** same as widenTermRefExpr, but preserves modules and singleton enum values */
343+
private final def widenInlineScrutinee(using Context): Type = tp.stripTypeVar match {
344+
case tp as ModuleOrEnumValueRef() => tp
345+
case tp: TermRef if !tp.isOverloaded => tp.underlying.widenExpr.widenInlineScrutinee
346+
case _ => tp
347+
}
348+
}
349+
350+
private object ModuleOrEnumValueRef {
351+
def unapply(tp: TermRef)(using Context): Boolean =
352+
val sym = tp.termSymbol
353+
sym.isAllOf(EnumCase, butNot=JavaDefined) || sym.is(Module)
354+
}
355+
340356
}
341357

342358
/** Produces an inlined version of `call` via its `inlined` method.
@@ -1003,7 +1019,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
10031019
* scrutinee as RHS and type that corresponds to RHS.
10041020
*/
10051021
def newTermBinding(sym: TermSymbol, rhs: Tree): Unit = {
1006-
val copied = sym.copy(info = rhs.tpe.widenTermRefExpr, coord = sym.coord, flags = sym.flags &~ Case).asTerm
1022+
val copied = sym.copy(info = rhs.tpe.widenInlineScrutinee, coord = sym.coord, flags = sym.flags &~ Case).asTerm
10071023
caseBindingMap += ((sym, ValDef(copied, constToLiteral(rhs)).withSpan(sym.span)))
10081024
}
10091025

@@ -1121,7 +1137,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
11211137
def reduceSubPatterns(pats: List[Tree], selectors: List[Tree]): Boolean = (pats, selectors) match {
11221138
case (Nil, Nil) => true
11231139
case (pat :: pats1, selector :: selectors1) =>
1124-
val elem = newSym(InlineBinderName.fresh(), Synthetic, selector.tpe.widenTermRefExpr).asTerm
1140+
val elem = newSym(InlineBinderName.fresh(), Synthetic, selector.tpe.widenInlineScrutinee).asTerm
11251141
val rhs = constToLiteral(selector)
11261142
elem.defTree = rhs
11271143
caseBindingMap += ((NoSymbol, ValDef(elem, rhs).withSpan(elem.span)))

tests/run/enum-nat.scala

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import Nat._
2+
import compiletime._
3+
4+
enum Nat:
5+
case Zero
6+
case Succ[N <: Nat.Refract](n: N)
7+
8+
object Nat:
9+
type Refract = Zero.type | Succ[_]
10+
11+
inline def toIntTypeLevel[N <: Nat]: Int = inline erasedValue[N] match
12+
case _: Zero.type => 0
13+
case _: Succ[n] => toIntTypeLevel[n] + 1
14+
15+
inline def toInt[N <: Nat.Refract](inline nat: N): Int = inline nat match
16+
case nat: Zero.type => 0
17+
case nat: Succ[n] => toInt(nat.n) + 1
18+
19+
inline def toIntUnapply[N <: Nat.Refract](inline nat: N): Int = inline nat match
20+
case Zero => 0
21+
case Succ(n) => toIntUnapply(n) + 1
22+
23+
inline def toIntTypeTailRec[N <: Nat, Acc <: Int]: Int = inline erasedValue[N] match
24+
case _: Zero.type => constValue[Acc]
25+
case _: Succ[n] => toIntTypeTailRec[n, S[Acc]]
26+
27+
inline def toIntErased[N <: Nat.Refract](inline nat: N): Int = toIntTypeTailRec[N, 0]
28+
29+
@main def Test: Unit =
30+
println("erased value:")
31+
assert(toIntTypeLevel[Succ[Succ[Succ[Zero.type]]]] == 3)
32+
println("type test:")
33+
assert(toInt(Succ(Succ(Succ(Zero)))) == 3)
34+
println("unapply:")
35+
assert(toIntUnapply(Succ(Succ(Succ(Zero)))) == 3)
36+
println("infer erased:")
37+
assert(toIntErased(Succ(Succ(Succ(Zero)))) == 3)

0 commit comments

Comments
 (0)