Skip to content

Commit c372fa1

Browse files
Merge pull request #10390 from dotty-staging/enum/recursive-gadt
Enable some recursive gadt to work with inline match
2 parents a129373 + bd6801b commit c372fa1

File tree

2 files changed

+54
-2
lines changed

2 files changed

+54
-2
lines changed

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,21 @@ object Inliner {
337337
def codeOf(arg: Tree, pos: SrcPos)(using Context): Tree =
338338
Literal(Constant(arg.show)).withSpan(pos.span)
339339
}
340+
341+
extension (tp: Type) {
342+
343+
/** same as widenTermRefExpr, but preserves modules and singleton enum values */
344+
private final def widenInlineScrutinee(using Context): Type = tp.stripTypeVar match {
345+
case tp: TermRef =>
346+
val sym = tp.termSymbol
347+
if sym.isAllOf(EnumCase, butNot=JavaDefined) || sym.is(Module) then tp
348+
else if !tp.isOverloaded then tp.underlying.widenExpr.widenInlineScrutinee
349+
else tp
350+
case _ => tp
351+
}
352+
353+
}
354+
340355
}
341356

342357
/** Produces an inlined version of `call` via its `inlined` method.
@@ -1003,7 +1018,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
10031018
* scrutinee as RHS and type that corresponds to RHS.
10041019
*/
10051020
def newTermBinding(sym: TermSymbol, rhs: Tree): Unit = {
1006-
val copied = sym.copy(info = rhs.tpe.widenTermRefExpr, coord = sym.coord, flags = sym.flags &~ Case).asTerm
1021+
val copied = sym.copy(info = rhs.tpe.widenInlineScrutinee, coord = sym.coord, flags = sym.flags &~ Case).asTerm
10071022
caseBindingMap += ((sym, ValDef(copied, constToLiteral(rhs)).withSpan(sym.span)))
10081023
}
10091024

@@ -1121,7 +1136,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
11211136
def reduceSubPatterns(pats: List[Tree], selectors: List[Tree]): Boolean = (pats, selectors) match {
11221137
case (Nil, Nil) => true
11231138
case (pat :: pats1, selector :: selectors1) =>
1124-
val elem = newSym(InlineBinderName.fresh(), Synthetic, selector.tpe.widenTermRefExpr).asTerm
1139+
val elem = newSym(InlineBinderName.fresh(), Synthetic, selector.tpe.widenInlineScrutinee).asTerm
11251140
val rhs = constToLiteral(selector)
11261141
elem.defTree = rhs
11271142
caseBindingMap += ((NoSymbol, ValDef(elem, rhs).withSpan(elem.span)))

tests/run/enum-nat.scala

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import Nat._
2+
import compiletime._
3+
4+
enum Nat:
5+
case Zero
6+
case Succ[N <: Nat.Refract](n: N)
7+
8+
object Nat:
9+
type Refract = Zero.type | Succ[_]
10+
11+
inline def toIntTypeLevel[N <: Nat]: Int = inline erasedValue[N] match
12+
case _: Zero.type => 0
13+
case _: Succ[n] => toIntTypeLevel[n] + 1
14+
15+
inline def toInt[N <: Nat.Refract](inline nat: N): Int = inline nat match
16+
case nat: Zero.type => 0
17+
case nat: Succ[n] => toInt(nat.n) + 1
18+
19+
inline def toIntUnapply[N <: Nat.Refract](inline nat: N): Int = inline nat match
20+
case Zero => 0
21+
case Succ(n) => toIntUnapply(n) + 1
22+
23+
inline def toIntTypeTailRec[N <: Nat, Acc <: Int]: Int = inline erasedValue[N] match
24+
case _: Zero.type => constValue[Acc]
25+
case _: Succ[n] => toIntTypeTailRec[n, S[Acc]]
26+
27+
inline def toIntErased[N <: Nat.Refract](inline nat: N): Int = toIntTypeTailRec[N, 0]
28+
29+
@main def Test: Unit =
30+
println("erased value:")
31+
assert(toIntTypeLevel[Succ[Succ[Succ[Zero.type]]]] == 3)
32+
println("type test:")
33+
assert(toInt(Succ(Succ(Succ(Zero)))) == 3)
34+
println("unapply:")
35+
assert(toIntUnapply(Succ(Succ(Succ(Zero)))) == 3)
36+
println("infer erased:")
37+
assert(toIntErased(Succ(Succ(Succ(Zero)))) == 3)

0 commit comments

Comments
 (0)