Skip to content

Commit a8bdfe5

Browse files
committed
Make inline override methods available at runtime
1 parent 9247e72 commit a8bdfe5

23 files changed

+185
-96
lines changed

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ object Contexts {
5353
/** The current context */
5454
def curCtx(using ctx: Context): Context = ctx
5555

56+
type Ctx[+T] = Context ?=> T
57+
5658
/** A context is passed basically everywhere in dotc.
5759
* This is convenient but carries the risk of captured contexts in
5860
* objects that turn into space leaks. To combat this risk, here are some

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ object SymDenotations {
368368
annots.iterator.foreach(addAnnotation)
369369

370370
@tailrec
371-
private def dropOtherAnnotations(anns: List[Annotation], cls: Symbol)(implicit ctx: Context): List[Annotation] = anns match {
371+
final def dropOtherAnnotations(anns: List[Annotation], cls: Symbol)(implicit ctx: Context): List[Annotation] = anns match {
372372
case ann :: rest => if (ann matches cls) anns else dropOtherAnnotations(rest, cls)
373373
case Nil => Nil
374374
}
@@ -938,13 +938,15 @@ object SymDenotations {
938938
def isInlineMethod(implicit ctx: Context): Boolean =
939939
isAllOf(InlineMethod, butNot = Accessor)
940940

941+
def isInlineRetained: Ctx[Boolean] = is(Override)
942+
941943
/** Is this a Scala 2 macro */
942944
final def isScala2Macro(implicit ctx: Context): Boolean = is(Macro) && symbol.owner.is(Scala2x)
943945

944946
/** An erased value or an inline method.
945947
*/
946948
def isEffectivelyErased(implicit ctx: Context): Boolean =
947-
is(Erased) || isInlineMethod
949+
is(Erased) || isInlineMethod && !isInlineRetained
948950

949951
/** ()T and => T types should be treated as equivalent for this symbol.
950952
* Note: For the moment, we treat Scala-2 compiled symbols as loose matching,

compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,12 @@ class TreePickler(pickler: TastyPickler) {
517517
def pickleAllParams =
518518
pickleParams(tree.tparams)
519519
pickleParamss(tree.vparamss)
520-
pickleDef(DEFDEF, tree.symbol, tree.tpt, tree.rhs, pickleAllParams)
520+
val meth = tree.symbol
521+
val rhs =
522+
if meth.isInlineMethod && meth.isInlineRetained
523+
then Alternative(List(Inliner.bodyToInline(meth), tree.rhs))
524+
else tree.rhs
525+
pickleDef(DEFDEF, meth, tree.tpt, rhs, pickleAllParams)
521526
case tree: TypeDef =>
522527
pickleDef(TYPEDEF, tree.symbol, tree.rhs)
523528
case tree: Template =>

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -583,9 +583,9 @@ class TreeUnpickler(reader: TastyReader,
583583
else if (sym.isInlineMethod)
584584
sym.addAnnotation(LazyBodyAnnotation { ctx0 =>
585585
val ctx1 = localContext(sym)(ctx0).addMode(Mode.ReadPositions)
586-
implicit val ctx: Context = sourceChangeContext(Addr(0))(ctx1)
586+
given Context = sourceChangeContext(Addr(0))(ctx1)
587587
// avoids space leaks by not capturing the current context
588-
forkAt(rhsStart).readTerm()
588+
forkAt(rhsStart).readInlineBodyPart(compileTime = true)
589589
})
590590
goto(start)
591591
sym
@@ -775,10 +775,13 @@ class TreeUnpickler(reader: TastyReader,
775775
if (nothingButMods(end))
776776
EmptyTree
777777
else if (sym.isInlineMethod)
778-
// The body of an inline method is stored in an annotation, so no need to unpickle it again
779-
new Trees.Lazy[Tree] {
780-
def complete(implicit ctx: Context) = typer.Inliner.bodyToInline(sym)
781-
}
778+
if sym.isInlineRetained then
779+
readLater(end, rdr => ctx =>
780+
rdr.readInlineBodyPart(compileTime = false)(using ctx.retractMode(Mode.InSuperCall)))
781+
else // The body of an inline method is stored in an annotation, so no need to unpickle it again
782+
new Trees.Lazy[Tree] {
783+
def complete(implicit ctx: Context) = typer.Inliner.bodyToInline(sym)
784+
}
782785
else
783786
readLater(end, rdr => ctx => rdr.readTerm()(ctx.retractMode(Mode.InSuperCall)))
784787

@@ -1281,6 +1284,14 @@ class TreeUnpickler(reader: TastyReader,
12811284
setSpan(start, CaseDef(pat, guard, rhs))
12821285
}
12831286

1287+
def readInlineBodyPart(compileTime: Boolean): Ctx[Tree] =
1288+
if nextByte == ALTERNATIVE then
1289+
readByte()
1290+
readEnd()
1291+
if !compileTime then readTerm()
1292+
else assert(compileTime)
1293+
readTerm()
1294+
12841295
def readLater[T <: AnyRef](end: Addr, op: TreeReader => Context => T)(implicit ctx: Context): Trees.Lazy[T] =
12851296
readLaterWithOwner(end, op)(ctx)(ctx.owner)
12861297

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import Denotations._
1515
import SymDenotations._
1616
import StdNames.{nme, tpnme}
1717
import ast.{Trees, untpd}
18-
import typer.{Implicits, Namer, Applications}
18+
import typer.{Implicits, Namer, Applications, Inliner}
1919
import typer.ProtoTypes._
2020
import Trees._
2121
import TypeApplications._
@@ -726,7 +726,8 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
726726
if (sym.privateWithin.exists) sym.privateWithin.asType.name else tpnme.EMPTY,
727727
sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(_.tree))
728728

729-
protected def dropAnnotForModText(sym: Symbol): Boolean = sym == defn.BodyAnnot
729+
protected def dropAnnotForModText(sym: Symbol): Boolean =
730+
sym == defn.BodyAnnot && !printDebug
730731

731732
protected def optAscription[T >: Untyped](tpt: Tree[T]): Text = optText(tpt)(": " ~ _)
732733

@@ -779,16 +780,20 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
779780
protected def defDefToText[T >: Untyped](tree: DefDef[T]): Text = {
780781
import untpd.{modsDeco => _}
781782
dclTextOr(tree) {
782-
val defKeyword = modText(tree.mods, tree.symbol, keywordStr("def"), isType = false)
783-
val isExtension = tree.hasType && tree.symbol.is(Extension)
783+
val meth = tree.symbol
784+
val defKeyword = modText(tree.mods, meth, keywordStr("def"), isType = false)
785+
val isExtension = tree.hasType && meth.is(Extension)
784786
withEnclosingDef(tree) {
785787
val (prefix, vparamss) =
786788
if (isExtension) (defKeyword ~~ paramsText(tree.vparamss.head) ~~ valDefText(nameIdText(tree)), tree.vparamss.tail)
787789
else (defKeyword ~~ valDefText(nameIdText(tree)), tree.vparamss)
788-
790+
val rhs =
791+
if meth.isInlineMethod && meth.isInlineRetained && !printDebug
792+
then Inliner.bodyToInline(meth)
793+
else tree.rhs
789794
addVparamssText(prefix ~ tparamsText(tree.tparams), vparamss) ~
790795
optAscription(tree.tpt) ~
791-
optText(tree.rhs)(" = " ~ _)
796+
optText(rhs)(" = " ~ _)
792797
}
793798
}
794799
}

compiler/src/dotty/tools/dotc/transform/Erasure.scala

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,17 +78,32 @@ class Erasure extends Phase with DenotTransformer {
7878
val oldInfo = ref.info
7979
val newInfo = transformInfo(oldSymbol, oldInfo)
8080
val oldFlags = ref.flags
81-
val newFlags =
81+
var newFlags =
8282
if (oldSymbol.is(Flags.TermParam) && isCompacted(oldSymbol.owner)) oldFlags &~ Flags.Param
8383
else oldFlags &~ Flags.HasDefaultParamsFlags // HasDefaultParamsFlags needs to be dropped because overriding might become overloading
84-
84+
val oldAnnotations = ref.annotations
85+
var newAnnotations = oldAnnotations
86+
if oldSymbol.isInlineMethod && oldSymbol.isInlineRetained then
87+
newFlags = newFlags &~ Flags.Inline
88+
newAnnotations = newAnnotations.filterConserve(_.symbol != defn.BodyAnnot)
8589
// TODO: define derivedSymDenotation?
86-
if ((oldSymbol eq newSymbol) && (oldOwner eq newOwner) && (oldName eq newName) && (oldInfo eq newInfo) && (oldFlags == newFlags))
90+
if (oldSymbol eq newSymbol)
91+
&& (oldOwner eq newOwner)
92+
&& (oldName eq newName)
93+
&& (oldInfo eq newInfo)
94+
&& (oldFlags == newFlags)
95+
&& (oldAnnotations eq newAnnotations)
96+
then
8797
ref
88-
else {
98+
else
8999
assert(!ref.is(Flags.PackageClass), s"trans $ref @ ${ctx.phase} oldOwner = $oldOwner, newOwner = $newOwner, oldInfo = $oldInfo, newInfo = $newInfo ${oldOwner eq newOwner} ${oldInfo eq newInfo}")
90-
ref.copySymDenotation(symbol = newSymbol, owner = newOwner, name = newName, initFlags = newFlags, info = newInfo)
91-
}
100+
ref.copySymDenotation(
101+
symbol = newSymbol,
102+
owner = newOwner,
103+
name = newName,
104+
initFlags = newFlags,
105+
info = newInfo,
106+
annotations = newAnnotations)
92107
}
93108
case ref: JointRefDenotation =>
94109
new UniqueRefDenotation(

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,26 @@ object Inliner {
124124
)
125125
}
126126

127+
/** For a retained inline method: The inline expansion of the call to itself with its
128+
* parameters as arguments.
129+
*/
130+
def inlineSelfCall(mdef: DefDef)(using ctx: Context): Tree =
131+
val meth = mdef.symbol
132+
val origParams: List[Symbol] = mdef.tparams.map(_.symbol) ::: mdef.vparamss.flatten.map(_.symbol)
133+
val callParams: List[Symbol] = origParams.map(_.copy())
134+
val callOwner = meth.copy(flags = meth.flags &~ Inline)
135+
assert(ctx.owner != meth)
136+
inlineCall(
137+
ref(meth)
138+
.appliedToTypes(mdef.tparams.tpes)
139+
.appliedToArgss(mdef.vparamss.map(_.map(param => ref(param.symbol))))
140+
.withSpan(mdef.rhs.span)
141+
.subst(origParams, callParams)
142+
)(using ctx.withOwner(callOwner))
143+
.subst(callParams, origParams)
144+
.changeOwner(callOwner, meth)
145+
.reporting(i"icall for $meth: $result", inlining)
146+
127147
/** Replace `Inlined` node by a block that contains its bindings and expansion */
128148
def dropInlined(inlined: Inlined)(implicit ctx: Context): Tree =
129149
if (enclosingInlineds.nonEmpty) inlined // Remove in the outer most inlined call

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2418,7 +2418,10 @@ class Typer extends Namer
24182418
* Overwritten in Retyper to return `mdef` unchanged.
24192419
*/
24202420
protected def inlineExpansion(mdef: DefDef)(implicit ctx: Context): Tree =
2421-
tpd.cpy.DefDef(mdef)(rhs = Inliner.bodyToInline(mdef.symbol))
2421+
val meth = mdef.symbol
2422+
tpd.cpy.DefDef(mdef)(rhs =
2423+
if meth.isInlineRetained then Inliner.inlineSelfCall(mdef)
2424+
else Inliner.bodyToInline(meth)) // TODO: make that an empty tree
24222425

24232426
def typedExpr(tree: untpd.Tree, pt: Type = WildcardType)(implicit ctx: Context): Tree =
24242427
typed(tree, pt)(ctx retractMode Mode.PatternOrTypeBits)

tasty/src/dotty/tools/tasty/TastyFormat.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ Standard-Section: "ASTs" TopLevelStat*
5959
TYPEDEF Length NameRef (type_Term | Template) Modifier* -- modifiers type name (= type | bounds) | moifiers class name template
6060
IMPORT Length qual_Term Selector* -- import qual selectors
6161
ValOrDefDef = VALDEF Length NameRef type_Term rhs_Term? Modifier* -- modifiers val name : type (= rhs)?
62-
DEFDEF Length NameRef TypeParam* Params* returnType_Term rhs_Term?
63-
Modifier* -- modifiers def name [typeparams] paramss : returnType (= rhs)?
62+
DEFDEF Length NameRef TypeParam* Params* returnType_Term
63+
rhs_Term? Modifier* -- modifiers def name [typeparams] paramss : returnType (= rhs)?
6464
Selector = IMPORTED name_NameRef -- name, "_" for normal wildcards, "" for given wildcards
6565
RENAMED to_NameRef -- => name
6666
BOUNDED type_Term -- type bound
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package test
2+
import scala.util.FromDigits
3+
import scala.quoted._
4+
5+
object BigFloatFromDigitsImpl:
6+
def apply(digits: Expr[String])(using ctx: QuoteContext): Expr[BigFloat] =
7+
digits match
8+
case Const(ds) =>
9+
try
10+
val BigFloat(m, e) = BigFloat(ds)
11+
'{BigFloat(${Expr(m)}, ${Expr(e)})}
12+
catch case ex: FromDigits.FromDigitsException =>
13+
ctx.error(ex.getMessage)
14+
'{BigFloat(0, 0)}
15+
case digits =>
16+
'{BigFloat($digits)}

tests/neg/BigFloat/BigFloat_1.scala renamed to tests/neg-macros/BigFloat/BigFloat_1.scala

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,29 +30,13 @@ object BigFloat extends App {
3030
BigFloat(BigInt(intPart), exponent)
3131
}
3232

33-
private def fromDigitsImpl(digits: Expr[String])(using ctx: QuoteContext): Expr[BigFloat] =
34-
digits match {
35-
case Const(ds) =>
36-
try {
37-
val BigFloat(m, e) = apply(ds)
38-
'{BigFloat(${Expr(m)}, ${Expr(e)})}
39-
}
40-
catch {
41-
case ex: FromDigits.FromDigitsException =>
42-
ctx.error(ex.getMessage)
43-
'{BigFloat(0, 0)}
44-
}
45-
case digits =>
46-
'{apply($digits)}
47-
}
48-
4933
class BigFloatFromDigits extends FromDigits.Floating[BigFloat] {
5034
def fromDigits(digits: String) = apply(digits)
5135
}
5236

5337
given BigFloatFromDigits {
5438
override inline def fromDigits(digits: String) = ${
55-
fromDigitsImpl('digits)
39+
BigFloatFromDigitsImpl('digits)
5640
}
5741
}
5842

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import scala.util.FromDigits
2+
import scala.quoted._
3+
import Even._
4+
5+
object EvenFromDigitsImpl:
6+
def apply(digits: Expr[String])(using ctx: QuoteContext): Expr[Even] = digits match {
7+
case Const(ds) =>
8+
val ev =
9+
try evenFromDigits(ds)
10+
catch {
11+
case ex: FromDigits.FromDigitsException =>
12+
ctx.error(ex.getMessage)
13+
Even(0)
14+
}
15+
'{Even(${Expr(ev.n)})}
16+
case _ =>
17+
'{evenFromDigits($digits)}
18+
}

tests/neg-macros/GenericNumLits/Even_1.scala

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,33 +5,19 @@ import scala.quoted._
55
case class Even(n: Int)
66
object Even {
77

8-
private def evenFromDigits(digits: String): Even = {
8+
def evenFromDigits(digits: String): Even = {
99
val intValue = FromDigits.intFromDigits(digits)
1010
if (intValue % 2 == 0) Even(intValue)
1111
else throw FromDigits.MalformedNumber(s"$digits is odd")
1212
}
1313

14-
private def evenFromDigitsImpl(digits: Expr[String])(using ctx: QuoteContext): Expr[Even] = digits match {
15-
case Const(ds) =>
16-
val ev =
17-
try evenFromDigits(ds)
18-
catch {
19-
case ex: FromDigits.FromDigitsException =>
20-
ctx.error(ex.getMessage)
21-
Even(0)
22-
}
23-
'{Even(${Expr(ev.n)})}
24-
case _ =>
25-
'{evenFromDigits($digits)}
26-
}
27-
2814
class EvenFromDigits extends FromDigits[Even] {
2915
def fromDigits(digits: String) = evenFromDigits(digits)
3016
}
3117

3218
given EvenFromDigits {
3319
override inline def fromDigits(digits: String) = ${
34-
evenFromDigitsImpl('digits)
20+
EvenFromDigitsImpl('digits)
3521
}
3622
}
3723
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import scala.util.FromDigits
2+
import scala.quoted._
3+
import Even._
4+
5+
object EvenFromDigitsImpl:
6+
def apply(digits: Expr[String])(using ctx: QuoteContext): Expr[Even] = digits match {
7+
case Const(ds) =>
8+
val ev =
9+
try evenFromDigits(ds)
10+
catch {
11+
case ex: FromDigits.FromDigitsException =>
12+
ctx.error(ex.getMessage)
13+
Even(0)
14+
}
15+
'{Even(${Expr(ev.n)})}
16+
case _ =>
17+
'{evenFromDigits($digits)}
18+
}

tests/neg-with-compiler/GenericNumLits/Even_1.scala

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,33 +5,19 @@ import scala.quoted._
55
case class Even(n: Int)
66
object Even {
77

8-
private def evenFromDigits(digits: String): Even = {
8+
def evenFromDigits(digits: String): Even = {
99
val intValue = FromDigits.intFromDigits(digits)
1010
if (intValue % 2 == 0) Even(intValue)
1111
else throw FromDigits.MalformedNumber(s"$digits is odd")
1212
}
1313

14-
private def evenFromDigitsImpl(digits: Expr[String])(using ctx: QuoteContext): Expr[Even] = digits match {
15-
case Const(ds) =>
16-
val ev =
17-
try evenFromDigits(ds)
18-
catch {
19-
case ex: FromDigits.FromDigitsException =>
20-
ctx.error(ex.getMessage)
21-
Even(0)
22-
}
23-
'{Even(${Expr(ev.n)})}
24-
case _ =>
25-
'{evenFromDigits($digits)}
26-
}
27-
2814
class EvenFromDigits extends FromDigits[Even] {
2915
def fromDigits(digits: String) = evenFromDigits(digits)
3016
}
3117

3218
given EvenFromDigits {
3319
override inline def fromDigits(digits: String) = ${
34-
evenFromDigitsImpl('digits)
20+
EvenFromDigitsImpl('digits)
3521
}
3622
}
3723
}

0 commit comments

Comments
 (0)