Skip to content

Commit d3f08a7

Browse files
committed
Improve invariant checks for erased terms
Instead of erasing an erased term to `???` we erase it to `erasedValue[T]`. This has 2 advantages, first, the term does not lose its type, and second, the term is still marked as erased. The second implies that if there is a bug in the compiler or a macro where the term might end outside an erased context, the code will not compiler. Currently, the code compiles and then throws when calling the spurious `???`. See scala#11996.
1 parent 280109e commit d3f08a7

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -243,10 +243,8 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
243243
private object dropInlines extends TreeMap {
244244
override def transform(tree: Tree)(using Context): Tree = tree match {
245245
case Inlined(call, _, expansion) =>
246-
val newExpansion = tree.tpe match
247-
case ConstantType(c) => Literal(c)
248-
case _ => Typed(ref(defn.Predef_undefined), TypeTree(tree.tpe))
249-
cpy.Inlined(tree)(call, Nil, newExpansion.withSpan(tree.span))
246+
val newExpansion = PruneErasedDefs.trivialErasedTree(tree)
247+
cpy.Inlined(tree)(call, Nil, newExpansion)
250248
case _ => super.transform(tree)
251249
}
252250
}
@@ -282,7 +280,8 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
282280
tpd.cpy.Apply(tree)(
283281
tree.fun,
284282
tree.args.mapConserve(arg =>
285-
if (methType.isImplicitMethod && arg.span.isSynthetic) ref(defn.Predef_undefined)
283+
if (methType.isImplicitMethod && arg.span.isSynthetic)
284+
PruneErasedDefs.trivialErasedTree(arg)
286285
else dropInlines.transform(arg)))
287286
else
288287
tree
@@ -414,12 +413,12 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
414413
// case x: (_: Tree[?])
415414
case m @ MatchTypeTree(bounds, selector, cases) =>
416415
// Analog to the case above for match types
417-
def tranformIgnoringBoundsCheck(x: CaseDef): CaseDef =
416+
def transformIgnoringBoundsCheck(x: CaseDef): CaseDef =
418417
withMode(Mode.Pattern)(super.transform(x)).asInstanceOf[CaseDef]
419418
cpy.MatchTypeTree(tree)(
420419
super.transform(bounds),
421420
super.transform(selector),
422-
cases.mapConserve(tranformIgnoringBoundsCheck)
421+
cases.mapConserve(transformIgnoringBoundsCheck)
423422
)
424423
case Block(_, Closure(_, _, tpt)) if ExpandSAMs.needsWrapperClass(tpt.tpe) =>
425424
superAcc.withInvalidCurrentClass(super.transform(tree))

compiler/src/dotty/tools/dotc/transform/PruneErasedDefs.scala

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import ast.tpd
2626
*/
2727
class PruneErasedDefs extends MiniPhase with SymTransformer { thisTransform =>
2828
import tpd._
29+
import PruneErasedDefs._
2930

3031
override def phaseName: String = PruneErasedDefs.name
3132

@@ -57,24 +58,24 @@ class PruneErasedDefs extends MiniPhase with SymTransformer { thisTransform =>
5758
override def transformValDef(tree: ValDef)(using Context): Tree =
5859
val sym = tree.symbol
5960
if tree.symbol.isEffectivelyErased && !tree.rhs.isEmpty then
60-
cpy.ValDef(tree)(rhs = trivialErasedTree(tree))
61+
cpy.ValDef(tree)(rhs = trivialErasedTree(tree.rhs))
6162
else if hasUninitializedRHS(tree) then
6263
cpy.ValDef(tree)(rhs = cpy.Ident(tree.rhs)(nme.WILDCARD).withType(tree.tpt.tpe))
6364
else
6465
tree
6566

6667
override def transformDefDef(tree: DefDef)(using Context): Tree =
6768
if (tree.symbol.isEffectivelyErased && !tree.rhs.isEmpty)
68-
cpy.DefDef(tree)(rhs = trivialErasedTree(tree))
69+
cpy.DefDef(tree)(rhs = trivialErasedTree(tree.rhs))
6970
else tree
7071

71-
private def trivialErasedTree(tree: Tree)(using Context): Tree =
72-
tree.tpe.widenTermRefExpr.dealias.normalized match
73-
case ConstantType(c) => Literal(c)
74-
case _ => ref(defn.Predef_undefined)
75-
7672
}
7773

7874
object PruneErasedDefs {
75+
import tpd._
76+
7977
val name: String = "pruneErasedDefs"
78+
79+
def trivialErasedTree(tree: Tree)(using Context): Tree =
80+
ref(defn.Compiletime_erasedValue).appliedToType(tree.tpe).withSpan(tree.span)
8081
}

0 commit comments

Comments
 (0)