diff --git a/compiler/src/dotty/tools/dotc/core/quoted/PickledQuotes.scala b/compiler/src/dotty/tools/dotc/core/quoted/PickledQuotes.scala index e8ed21588799..bedcd8c05724 100644 --- a/compiler/src/dotty/tools/dotc/core/quoted/PickledQuotes.scala +++ b/compiler/src/dotty/tools/dotc/core/quoted/PickledQuotes.scala @@ -8,10 +8,11 @@ import dotty.tools.dotc.core.Contexts._ import dotty.tools.dotc.core.Decorators._ import dotty.tools.dotc.core.StdNames._ import dotty.tools.dotc.core.NameKinds +import dotty.tools.dotc.core.Mode import dotty.tools.dotc.core.Symbols._ import dotty.tools.dotc.core.Types.Type import dotty.tools.dotc.core.tasty.TreePickler.Hole -import dotty.tools.dotc.core.tasty.{TastyPickler, TastyPrinter, TastyString} +import dotty.tools.dotc.core.tasty.{PositionPickler, TastyPickler, TastyPrinter, TastyString} import dotty.tools.dotc.core.tasty.TreeUnpickler.UnpickleMode import scala.quoted.Types._ @@ -66,13 +67,13 @@ object PickledQuotes { /** Unpickle the tree contained in the TastyExpr */ private def unpickleExpr(expr: TastyExpr[_])(implicit ctx: Context): Tree = { val tastyBytes = TastyString.unpickle(expr.tasty) - unpickle(tastyBytes, expr.args, isType = false) + unpickle(tastyBytes, expr.args, isType = false)(ctx.addMode(Mode.ReadPositions)) } /** Unpickle the tree contained in the TastyType */ private def unpickleType(ttpe: TastyType[_])(implicit ctx: Context): Tree = { val tastyBytes = TastyString.unpickle(ttpe.tasty) - unpickle(tastyBytes, ttpe.args, isType = true) + unpickle(tastyBytes, ttpe.args, isType = true)(ctx.addMode(Mode.ReadPositions)) } // TASTY picklingtests/pos/quoteTest.scala @@ -85,6 +86,8 @@ object PickledQuotes { treePkl.compactify() pickler.addrOfTree = treePkl.buf.addrOfTree pickler.addrOfSym = treePkl.addrOfSym + if (tree.pos.exists) + new PositionPickler(pickler, treePkl.buf.addrOfTree).picklePositions(tree :: Nil) if (pickling ne noPrinter) println(i"**** pickling quote of \n${tree.show}") diff --git a/tests/pos/i4734/Macro_1.scala b/tests/pos/i4734/Macro_1.scala new file mode 100644 index 000000000000..2d4155aed9a5 --- /dev/null +++ b/tests/pos/i4734/Macro_1.scala @@ -0,0 +1,12 @@ +import scala.annotation.tailrec +import scala.quoted._ + +object Macros { + transparent def unrolledForeach(f: Int => Int): Int = + ~unrolledForeachImpl('(f)) + + def unrolledForeachImpl(f: Expr[Int => Int]): Expr[Int] = '{ + val size: Int = 5 + (~f)(3) + } +} diff --git a/tests/pos/i4734/Test_2.scala b/tests/pos/i4734/Test_2.scala new file mode 100644 index 000000000000..4ee69acc363f --- /dev/null +++ b/tests/pos/i4734/Test_2.scala @@ -0,0 +1,8 @@ +import scala.quoted._ +import Macros._ + +object Test { + def main(args: Array[String]): Unit = { + unrolledForeach((x: Int) => 2) + } +} diff --git a/tests/run/i4734.check b/tests/run/i4734.check new file mode 100644 index 000000000000..2824511ad5fd --- /dev/null +++ b/tests/run/i4734.check @@ -0,0 +1,21 @@ +0 +2 +4 +6 +8 +10 +12 +14 +16 +18 +20 +22 +24 +26 +28 +30 +32 +34 +36 +38 +40 diff --git a/tests/run/i4734/Macro_1.scala b/tests/run/i4734/Macro_1.scala new file mode 100644 index 000000000000..1171966b6590 --- /dev/null +++ b/tests/run/i4734/Macro_1.scala @@ -0,0 +1,33 @@ +import scala.annotation.tailrec +import scala.quoted._ + +object Macros { + transparent def unrolledForeach(seq: IndexedSeq[Int], f: => Int => Unit, transparent unrollSize: Int): Unit = // or f: Int => Unit + ~unrolledForeachImpl('(seq), '(f), unrollSize) + + def unrolledForeachImpl(seq: Expr[IndexedSeq[Int]], f: Expr[Int => Unit], unrollSize: Int): Expr[Unit] = '{ + val size = (~seq).length + assert(size % (~unrollSize.toExpr) == 0) // for simplicity of the implementation + var i = 0 + while (i < size) { + ~{ + for (j <- new UnrolledRange(0, unrollSize)) '{ + val index = i + ~j.toExpr + val element = (~seq)(index) + ~f('(element)) // or `(~f)(element)` if `f` should not be inlined + } + } + i += ~unrollSize.toExpr + } + + } + + class UnrolledRange(start: Int, end: Int) { + def foreach(f: Int => Expr[Unit]): Expr[Unit] = { + @tailrec def loop(i: Int, acc: Expr[Unit]): Expr[Unit] = + if (i >= 0) loop(i - 1, '{ ~f(i); ~acc }) + else acc + loop(end - 1, '()) + } + } +} diff --git a/tests/run/i4734/Test_2.scala b/tests/run/i4734/Test_2.scala new file mode 100644 index 000000000000..72aa3f715ee8 --- /dev/null +++ b/tests/run/i4734/Test_2.scala @@ -0,0 +1,9 @@ +import scala.quoted._ +import Macros._ + +object Test { + def main(args: Array[String]): Unit = { + val seq = IndexedSeq.tabulate[Int](21)(x => x) + unrolledForeach(seq, (x: Int) => println(2*x), 3) + } +}