Skip to content

Commit faf9538

Browse files
Merge pull request #13755 from dotty-staging/fix-13747
Change order of proxy evaluation when inlining
2 parents 8947f38 + f7827eb commit faf9538

File tree

2 files changed

+52
-20
lines changed

2 files changed

+52
-20
lines changed

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ import quoted.QuoteUtils
3737
object Inliner {
3838
import tpd._
3939

40+
private type DefBuffer = mutable.ListBuffer[ValOrDefDef]
41+
4042
/** `sym` is an inline method with a known body to inline.
4143
*/
4244
def hasBodyToInline(sym: SymDenotation)(using Context): Boolean =
@@ -413,7 +415,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
413415

414416
private val methPart = funPart(call)
415417
private val callTypeArgs = typeArgss(call).flatten
416-
private val rawCallValueArgss = termArgss(call)
418+
private val callValueArgss = termArgss(call)
417419
private val inlinedMethod = methPart.symbol
418420
private val inlineCallPrefix =
419421
qualifier(methPart).orElse(This(inlinedMethod.enclosingClass.asClass))
@@ -465,14 +467,14 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
465467
/** A binding for the parameter of an inline method. This is a `val` def for
466468
* by-value parameters and a `def` def for by-name parameters. `val` defs inherit
467469
* inline annotations from their parameters. The generated `def` is appended
468-
* to `bindingsBuf`.
470+
* to `buf`.
469471
* @param name the name of the parameter
470472
* @param formal the type of the parameter
471473
* @param arg the argument corresponding to the parameter
472-
* @param bindingsBuf the buffer to which the definition should be appended
474+
* @param buf the buffer to which the definition should be appended
473475
*/
474476
private def paramBindingDef(name: Name, formal: Type, arg0: Tree,
475-
bindingsBuf: mutable.ListBuffer[ValOrDefDef])(using Context): ValOrDefDef = {
477+
buf: DefBuffer)(using Context): ValOrDefDef = {
476478
val isByName = formal.dealias.isInstanceOf[ExprType]
477479
val arg = arg0 match {
478480
case Typed(arg1, tpt) if tpt.tpe.isRepeatedParam && arg1.tpe.derivesFrom(defn.ArrayClass) =>
@@ -501,23 +503,25 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
501503
else ValDef(boundSym, newArg)
502504
}.withSpan(boundSym.span)
503505
inlining.println(i"parameter binding: $binding, $argIsBottom")
504-
bindingsBuf += binding
506+
buf += binding
505507
binding
506508
}
507509

508-
/** Populate `paramBinding` and `bindingsBuf` by matching parameters with
510+
/** Populate `paramBinding` and `buf` by matching parameters with
509511
* corresponding arguments. `bindingbuf` will be further extended later by
510512
* proxies to this-references. Issue an error if some arguments are missing.
511513
*/
512514
private def computeParamBindings(
513-
tp: Type, targs: List[Tree], argss: List[List[Tree]], formalss: List[List[Type]]): Boolean =
515+
tp: Type, targs: List[Tree],
516+
argss: List[List[Tree]], formalss: List[List[Type]],
517+
buf: DefBuffer): Boolean =
514518
tp match
515519
case tp: PolyType =>
516520
tp.paramNames.lazyZip(targs).foreach { (name, arg) =>
517521
paramSpan(name) = arg.span
518522
paramBinding(name) = arg.tpe.stripTypeVar
519523
}
520-
computeParamBindings(tp.resultType, targs.drop(tp.paramNames.length), argss, formalss)
524+
computeParamBindings(tp.resultType, targs.drop(tp.paramNames.length), argss, formalss, buf)
521525
case tp: MethodType =>
522526
if argss.isEmpty then
523527
report.error(i"missing arguments for inline method $inlinedMethod", call.srcPos)
@@ -529,9 +533,9 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
529533
case _: SingletonType if isIdempotentPath(arg) =>
530534
arg.tpe
531535
case _ =>
532-
paramBindingDef(name, formal, arg, bindingsBuf).symbol.termRef
536+
paramBindingDef(name, formal, arg, buf).symbol.termRef
533537
}
534-
computeParamBindings(tp.resultType, targs, argss.tail, formalss.tail)
538+
computeParamBindings(tp.resultType, targs, argss.tail, formalss.tail, buf)
535539
case _ =>
536540
assert(targs.isEmpty)
537541
assert(argss.isEmpty)
@@ -810,7 +814,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
810814
def inlined(sourcePos: SrcPos): Tree = {
811815

812816
// Special handling of `requireConst` and `codeOf`
813-
rawCallValueArgss match
817+
callValueArgss match
814818
case (arg :: Nil) :: Nil =>
815819
if inlinedMethod == defn.Compiletime_requireConst then
816820
arg match
@@ -860,24 +864,35 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
860864
case TypeApply(fn, _) => paramTypess(fn, acc)
861865
case _ => acc
862866

863-
val callValueArgss = rawCallValueArgss.nestedMapConserve(mapOpaquesInValueArg)
867+
val paramBindings =
868+
val mappedCallValueArgss = callValueArgss.nestedMapConserve(mapOpaquesInValueArg)
869+
if mappedCallValueArgss ne callValueArgss then
870+
inlining.println(i"mapped value args = ${mappedCallValueArgss.flatten}%, %")
864871

865-
if callValueArgss ne rawCallValueArgss then
866-
inlining.println(i"mapped value args = ${callValueArgss.flatten}%, %")
872+
val paramBindingsBuf = new DefBuffer
873+
// Compute bindings for all parameters, appending them to bindingsBuf
874+
if !computeParamBindings(
875+
inlinedMethod.info, callTypeArgs,
876+
mappedCallValueArgss, paramTypess(call, Nil),
877+
paramBindingsBuf)
878+
then
879+
return call
867880

868-
// Compute bindings for all parameters, appending them to bindingsBuf
869-
if !computeParamBindings(inlinedMethod.info, callTypeArgs, callValueArgss, paramTypess(call, Nil)) then
870-
return call
881+
paramBindingsBuf.toList
882+
end paramBindings
871883

872884
// make sure prefix is executed if it is impure
873-
if (!isIdempotentExpr(inlineCallPrefix)) registerType(inlinedMethod.owner.thisType)
885+
if !isIdempotentExpr(inlineCallPrefix) then registerType(inlinedMethod.owner.thisType)
874886

875887
// Register types of all leaves of inlined body so that the `paramProxy` and `thisProxy` maps are defined.
876888
rhsToInline.foreachSubTree(registerLeaf)
877889

878890
// Compute bindings for all this-proxies, appending them to bindingsBuf
879891
computeThisBindings()
880892

893+
// Parameter bindings come after this bindings, reflecting order of evaluation
894+
bindingsBuf ++= paramBindings
895+
881896
val inlineTyper = new InlineTyper(ctx.reporter.errorCount)
882897

883898
val inlineCtx = inlineContext(call).fresh.setTyper(inlineTyper).setNewScope
@@ -1190,7 +1205,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
11901205
private object InlineableArg {
11911206
lazy val paramProxies = paramProxy.values.toSet
11921207
def unapply(tree: Trees.Ident[?])(using Context): Option[Tree] = {
1193-
def search(buf: mutable.ListBuffer[ValOrDefDef]) = buf.find(_.name == tree.name)
1208+
def search(buf: DefBuffer) = buf.find(_.name == tree.name)
11941209
if (paramProxies.contains(tree.typeOpt))
11951210
search(bindingsBuf) match {
11961211
case Some(bind: ValOrDefDef) if bind.symbol.is(Inline) =>
@@ -1229,7 +1244,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
12291244
cpy.Inlined(cl)(call, bindings, recur(expr))
12301245
case _ => ddef.tpe.widen match
12311246
case mt: MethodType if ddef.paramss.head.length == args.length =>
1232-
val bindingsBuf = new mutable.ListBuffer[ValOrDefDef]
1247+
val bindingsBuf = new DefBuffer
12331248
val argSyms = mt.paramNames.lazyZip(mt.paramInfos).lazyZip(args).map { (name, paramtp, arg) =>
12341249
arg.tpe.dealias match {
12351250
case ref @ TermRef(NoPrefix, _) => ref.symbol

tests/run/i13747.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
var res = ""
2+
trait Bar:
3+
def +(that: Bar): Bar = new Plus(this, that)
4+
transparent inline def -(that: Bar): Bar = new Minus(this, that)
5+
6+
class LHS extends Bar {res += "LHS "}
7+
class RHS extends Bar {res += "RHS "}
8+
9+
class Plus(lhs: Bar, rhs: Bar) extends Bar {res += "op"}
10+
class Minus(lhs: Bar, rhs: Bar) extends Bar {res += "op"}
11+
12+
@main def Test =
13+
val pls = new LHS + new RHS
14+
val plsRes = res
15+
res = ""
16+
val min = new LHS - new RHS
17+
assert(plsRes == res)

0 commit comments

Comments
 (0)