diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index d17bfd0f7564..b650a0088de4 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -743,8 +743,6 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] => Some(meth) case Block(Nil, expr) => unapply(expr) - case Inlined(_, bindings, expr) if bindings.forall(isPureBinding) => - unapply(expr) case _ => None } diff --git a/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala b/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala index 460d0a61c252..b85454b8ba35 100644 --- a/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala +++ b/compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala @@ -158,35 +158,46 @@ class InlineReducer(inliner: Inliner)(using Context): * * where `def` is used for call-by-name parameters. However, we shortcut any NoPrefix * refs among the ei's directly without creating an intermediate binding. + * + * This variant of beta-reduction preserves the integrity of `Inlined` tree nodes. */ def betaReduce(tree: Tree)(using Context): Tree = tree match { - case Apply(Select(cl @ closureDef(ddef), nme.apply), args) if defn.isFunctionType(cl.tpe) => - // closureDef also returns a result for closures wrapped in Inlined nodes. - // These need to be preserved. - def recur(cl: Tree): Tree = cl match - case Inlined(call, bindings, expr) => - cpy.Inlined(cl)(call, bindings, recur(expr)) - case _ => ddef.tpe.widen match - case mt: MethodType if ddef.paramss.head.length == args.length => - val bindingsBuf = new DefBuffer - val argSyms = mt.paramNames.lazyZip(mt.paramInfos).lazyZip(args).map { (name, paramtp, arg) => - arg.tpe.dealias match { - case ref @ TermRef(NoPrefix, _) => ref.symbol - case _ => - paramBindingDef(name, paramtp, arg, bindingsBuf)( - using ctx.withSource(cl.source) - ).symbol + case Apply(Select(cl, nme.apply), args) if defn.isFunctionType(cl.tpe) => + val bindingsBuf = new DefBuffer + def recur(cl: Tree): Option[Tree] = cl match + case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol => + ddef.tpe.widen match + case mt: MethodType if ddef.paramss.head.length == args.length => + val argSyms = mt.paramNames.lazyZip(mt.paramInfos).lazyZip(args).map { (name, paramtp, arg) => + arg.tpe.dealias match { + case ref @ TermRef(NoPrefix, _) => ref.symbol + case _ => + paramBindingDef(name, paramtp, arg, bindingsBuf)( + using ctx.withSource(cl.source) + ).symbol + } } - } - val expander = new TreeTypeMap( - oldOwners = ddef.symbol :: Nil, - newOwners = ctx.owner :: Nil, - substFrom = ddef.paramss.head.map(_.symbol), - substTo = argSyms) - Block(bindingsBuf.toList, expander.transform(ddef.rhs)).withSpan(tree.span) - case _ => tree - recur(cl) - case _ => tree + val expander = new TreeTypeMap( + oldOwners = ddef.symbol :: Nil, + newOwners = ctx.owner :: Nil, + substFrom = ddef.paramss.head.map(_.symbol), + substTo = argSyms) + Some(expander.transform(ddef.rhs)) + case _ => None + case Block(stats, expr) if stats.forall(isPureBinding) => + recur(expr).map(cpy.Block(cl)(stats, _)) + case Inlined(call, bindings, expr) if bindings.forall(isPureBinding) => + recur(expr).map(cpy.Inlined(cl)(call, bindings, _)) + case Typed(expr, tpt) => + recur(expr) + case _ => None + recur(cl) match + case Some(reduced) => + Block(bindingsBuf.toList, reduced).withSpan(tree.span) + case None => + tree + case _ => + tree } /** The result type of reducing a match. It consists optionally of a list of bindings @@ -281,7 +292,7 @@ class InlineReducer(inliner: Inliner)(using Context): // Test case is pos-macros/i15971 val tptBinds = getBinds(Set.empty[TypeSymbol], tpt) val binds: Set[TypeSymbol] = pat match { - case UnApply(TypeApply(_, tpts), _, _) => + case UnApply(TypeApply(_, tpts), _, _) => getBinds(Set.empty[TypeSymbol], tpts) ++ tptBinds case _ => tptBinds } diff --git a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala index ea9009de1d9e..a492e8785afc 100644 --- a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala @@ -600,10 +600,12 @@ class InlineBytecodeTests extends DottyBytecodeTest { val instructions = instructionsFromMethod(fun) val expected = // TODO room for constant folding List( - Op(ICONST_1), + Op(ICONST_2), VarOp(ISTORE, 1), + Op(ICONST_1), + VarOp(ISTORE, 2), Op(ICONST_2), - VarOp(ILOAD, 1), + VarOp(ILOAD, 2), Op(IADD), Op(ICONST_3), Op(IADD), diff --git a/tests/pos/i16374a.scala b/tests/pos/i16374a.scala new file mode 100644 index 000000000000..81ca17335393 --- /dev/null +++ b/tests/pos/i16374a.scala @@ -0,0 +1,7 @@ +def method(using String): String = ??? + +inline def inlineMethod(inline op: String => Unit)(using String): Unit = + println(op(method)) + +def test(using String) = + inlineMethod(c => print(c)) diff --git a/tests/pos/i16374b.scala b/tests/pos/i16374b.scala new file mode 100644 index 000000000000..3d68fbddb6e2 --- /dev/null +++ b/tests/pos/i16374b.scala @@ -0,0 +1,9 @@ +def method(using String): String = ??? + +inline def identity[T](inline x: T): T = x + +inline def inlineMethod(inline op: String => Unit)(using String): Unit = + println(identity(op)(method)) + +def test(using String) = + inlineMethod(c => print(c)) diff --git a/tests/pos/i16374c.scala b/tests/pos/i16374c.scala new file mode 100644 index 000000000000..33e11f324bfa --- /dev/null +++ b/tests/pos/i16374c.scala @@ -0,0 +1,7 @@ +def method(using String): String = ??? + +inline def inlineMethod(inline op: String => Unit)(using String): Unit = + println({ val a: Int = 1; op }.apply(method)) + +def test(using String) = + inlineMethod(c => print(c)) diff --git a/tests/pos/i16374d.scala b/tests/pos/i16374d.scala new file mode 100644 index 000000000000..5f0c8d715496 --- /dev/null +++ b/tests/pos/i16374d.scala @@ -0,0 +1,4 @@ +inline def inline1(inline f: Int => Int): Int => Int = i => f(1) +inline def inline2(inline f: Int => Int): Int = f(2) + 3 +def test: Int = inline2(inline1(2.+)) +