Skip to content

Commit d1a184c

Browse files
author
Aggelos Biboudis
authored
Merge pull request #4736 from dotty-staging/quoted-unrolled-loop
Add tests example of a partially unrolled loop
2 parents 292c51c + 8d209a2 commit d1a184c

File tree

5 files changed

+308
-0
lines changed

5 files changed

+308
-0
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
((arr: scala.Array[scala.Int], f: scala.Function1[scala.Int, scala.Unit]) => {
2+
val size: scala.Int = arr.length
3+
var i: scala.Int = 0
4+
while (i.<(size)) {
5+
val element: scala.Int = arr.apply(i)
6+
f.apply(element)
7+
i = i.+(1)
8+
}
9+
})
10+
11+
((arr: scala.Array[scala.Predef.String], f: scala.Function1[scala.Predef.String, scala.Unit]) => {
12+
val size: scala.Int = arr.length
13+
var i: scala.Int = 0
14+
while (i.<(size)) {
15+
val element: java.lang.String = arr.apply(i)
16+
f.apply(element)
17+
i = i.+(1)
18+
}
19+
})
20+
21+
((arr: scala.Array[scala.Predef.String], f: scala.Function1[scala.Predef.String, scala.Unit]) => {
22+
val size: scala.Int = arr.length
23+
var i: scala.Int = 0
24+
while (i.<(size)) {
25+
val element: java.lang.String = arr.apply(i)
26+
f.apply(element)
27+
i = i.+(1)
28+
}
29+
})
30+
31+
((arr: scala.Array[scala.Int]) => {
32+
val size: scala.Int = arr.length
33+
var i: scala.Int = 0
34+
while (i.<(size)) {
35+
val element: scala.Int = arr.apply(i)
36+
37+
((i: scala.Int) => java.lang.System.out.println(i)).apply(element)
38+
i = i.+(1)
39+
}
40+
})
41+
42+
((arr: scala.Array[scala.Int], f: scala.Function1[scala.Int, scala.Unit]) => {
43+
val size: scala.Int = arr.length
44+
var i: scala.Int = 0
45+
if (size.%(3).!=(0)) throw new scala.Exception("...") else ()
46+
while (i.<(size)) {
47+
f.apply(arr.apply(i))
48+
f.apply(arr.apply(i.+(1)))
49+
f.apply(arr.apply(i.+(2)))
50+
i = i.+(3)
51+
}
52+
})
53+
54+
((arr: scala.Array[scala.Int], f: scala.Function1[scala.Int, scala.Unit]) => {
55+
val size: scala.Int = arr.length
56+
var i: scala.Int = 0
57+
if (size.%(4).!=(0)) throw new scala.Exception("...") else ()
58+
while (i.<(size)) {
59+
f.apply(arr.apply(i.+(0)))
60+
f.apply(arr.apply(i.+(1)))
61+
f.apply(arr.apply(i.+(2)))
62+
f.apply(arr.apply(i.+(3)))
63+
i = i.+(4)
64+
}
65+
})
66+
67+
{
68+
val array: scala.Array[scala.Int] = new scala.Array[scala.Int](4)
69+
array.update(0, 1)
70+
array.update(1, 2)
71+
array.update(2, 3)
72+
array.update(3, 4)
73+
(array: scala.Array[scala.Int])
74+
}
75+
76+
{
77+
val arr1: scala.Array[scala.Int] = {
78+
val array: scala.Array[scala.Int] = new scala.Array[scala.Int](4)
79+
array.update(0, 1)
80+
array.update(1, 3)
81+
array.update(2, 4)
82+
array.update(3, 5)
83+
(array: scala.Array[scala.Int])
84+
}
85+
val size: scala.Int = arr1.length
86+
var i: scala.Int = 0
87+
while (i.<(size)) {
88+
val element: scala.Int = arr1.apply(i)
89+
90+
((x: scala.Int) => scala.Predef.println(x)).apply(element)
91+
i = i.+(1)
92+
}
93+
}
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import scala.annotation.tailrec
2+
import scala.quoted._
3+
4+
object Test {
5+
implicit val toolbox: scala.quoted.Toolbox = dotty.tools.dotc.quoted.Toolbox.make
6+
7+
def main(args: Array[String]): Unit = {
8+
val code1 = '{ (arr: Array[Int], f: Int => Unit) => ~foreach1('(arr), '(f)) }
9+
println(code1.show)
10+
println()
11+
12+
val code1Tpe = '{ (arr: Array[String], f: String => Unit) => ~foreach1Tpe1('(arr), '(f)) }
13+
println(code1Tpe.show)
14+
println()
15+
16+
val code1Tpe2 = '{ (arr: Array[String], f: String => Unit) => ~foreach1Tpe1('(arr), '(f)) }
17+
println(code1Tpe2.show)
18+
println()
19+
20+
val code2 = '{ (arr: Array[Int]) => ~foreach1('(arr), '(i => System.out.println(i))) }
21+
println(code2.show)
22+
println()
23+
24+
val code3 = '{ (arr: Array[Int], f: Int => Unit) => ~foreach3('(arr), '(f)) }
25+
println(code3.show)
26+
println()
27+
28+
val code4 = '{ (arr: Array[Int], f: Int => Unit) => ~foreach4('(arr), '(f), 4) }
29+
println(code4.show)
30+
println()
31+
32+
val liftedArray = Array(1, 2, 3, 4).toExpr
33+
println(liftedArray.show)
34+
println()
35+
36+
37+
def printAll(arr: Array[Int]) = '{
38+
val arr1 = ~arr.toExpr
39+
~foreach1('(arr1), '(x => println(x)))
40+
}
41+
42+
println(printAll(Array(1, 3, 4, 5)).show)
43+
44+
}
45+
46+
def foreach1(arrRef: Expr[Array[Int]], f: Expr[Int => Unit]): Expr[Unit] = '{
47+
val size = (~arrRef).length
48+
var i = 0
49+
while (i < size) {
50+
val element: Int = (~arrRef)(i)
51+
(~f)(element)
52+
i += 1
53+
}
54+
}
55+
56+
def foreach1Tpe1[T](arrRef: Expr[Array[T]], f: Expr[T => Unit])(implicit t: Type[T]): Expr[Unit] = '{
57+
val size = (~arrRef).length
58+
var i = 0
59+
while (i < size) {
60+
val element: ~t = (~arrRef)(i)
61+
(~f)(element)
62+
i += 1
63+
}
64+
}
65+
66+
def foreach1Tpe2[T: Type](arrRef: Expr[Array[T]], f: Expr[T => Unit]): Expr[Unit] = '{
67+
val size = (~arrRef).length
68+
var i = 0
69+
while (i < size) {
70+
val element: T = (~arrRef)(i)
71+
(~f)(element)
72+
i += 1
73+
}
74+
}
75+
76+
def foreach2(arrRef: Expr[Array[Int]], f: Expr[Int => Unit]): Expr[Unit] = '{
77+
val size = (~arrRef).length
78+
var i = 0
79+
while (i < size) {
80+
val element = (~arrRef)(i)
81+
~f('(element)) // Use AppliedFuntion
82+
i += 1
83+
}
84+
}
85+
86+
def foreach3(arrRef: Expr[Array[Int]], f: Expr[Int => Unit]): Expr[Unit] = '{
87+
val size = (~arrRef).length
88+
var i = 0
89+
if (size % 3 != 0) throw new Exception("...")// for simplicity of the implementation
90+
while (i < size) {
91+
(~f)((~arrRef)(i))
92+
(~f)((~arrRef)(i + 1))
93+
(~f)((~arrRef)(i + 2))
94+
i += 3
95+
}
96+
}
97+
98+
def foreach3_2(arrRef: Expr[Array[Int]], f: Expr[Int => Unit]): Expr[Unit] = '{
99+
val size = (~arrRef).length
100+
var i = 0
101+
if (size % 3 != 0) throw new Exception("...")// for simplicity of the implementation
102+
while (i < size) {
103+
(~f)((~arrRef)(i))
104+
(~f)((~arrRef)(i + 1))
105+
(~f)((~arrRef)(i + 2))
106+
i += 3
107+
}
108+
}
109+
110+
def foreach4(arrRef: Expr[Array[Int]], f: Expr[Int => Unit], unrollSize: Int): Expr[Unit] = '{
111+
val size = (~arrRef).length
112+
var i = 0
113+
if (size % ~unrollSize.toExpr != 0) throw new Exception("...") // for simplicity of the implementation
114+
while (i < size) {
115+
~foreachInRange(0, unrollSize)(j => '{ (~f)((~arrRef)(i + ~j.toExpr)) })
116+
i += ~unrollSize.toExpr
117+
}
118+
}
119+
120+
implicit object ArrayIntIsLiftable extends Liftable[Array[Int]] {
121+
override def toExpr(x: Array[Int]): Expr[Array[Int]] = '{
122+
val array = new Array[Int](~x.length.toExpr)
123+
~foreachInRange(0, x.length)(i => '{ array(~i.toExpr) = ~x(i).toExpr})
124+
array
125+
}
126+
}
127+
128+
def foreachInRange(start: Int, end: Int)(f: Int => Expr[Unit]): Expr[Unit] = {
129+
@tailrec def unroll(i: Int, acc: Expr[Unit]): Expr[Unit] =
130+
if (i < end) unroll(i + 1, '{ ~acc; ~f(i) }) else acc
131+
if (start < end) unroll(start + 1, f(start)) else '()
132+
}
133+
134+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
<log> start loop
2+
0
3+
6
4+
12
5+
<log> start loop
6+
18
7+
24
8+
30
9+
<log> start loop
10+
36
11+
42
12+
48
13+
<log> start loop
14+
54
15+
60
16+
66
17+
<log> start loop
18+
72
19+
78
20+
84
21+
<log> start loop
22+
90
23+
96
24+
102
25+
<log> start loop
26+
108
27+
114
28+
120
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import scala.annotation.tailrec
2+
import scala.quoted._
3+
4+
object Macro {
5+
6+
inline def unrolledForeach(inline unrollSize: Int, seq: Array[Int])(f: => Int => Unit): Unit = // or f: Int => Unit
7+
~unrolledForeachImpl(unrollSize, '(seq), '(f))
8+
9+
private def unrolledForeachImpl(unrollSize: Int, seq: Expr[Array[Int]], f: Expr[Int => Unit]): Expr[Unit] = '{
10+
val size = (~seq).length
11+
assert(size % (~unrollSize.toExpr) == 0) // for simplicity of the implementation
12+
var i = 0
13+
while (i < size) {
14+
println("<log> start loop")
15+
~{
16+
@tailrec def loop(j: Int, acc: Expr[Unit]): Expr[Unit] =
17+
if (j >= 0) loop(j - 1, '{ ~f('((~seq)(i + ~j.toExpr))); ~acc })
18+
else acc
19+
loop(unrollSize - 1, '())
20+
}
21+
i += ~unrollSize.toExpr
22+
}
23+
}
24+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
2+
import scala.quoted._
3+
4+
object Test {
5+
def main(args: Array[String]): Unit = {
6+
val arr = Array.tabulate[Int](21)(x => 3 * x)
7+
Macro.unrolledForeach(3, arr) { (x: Int) =>
8+
System.out.println(2 * x)
9+
}
10+
11+
/* unrooled code:
12+
13+
val size: Int = arr.length()
14+
assert(size % 3 == 0)
15+
var i: Int = 0
16+
while (i < size) {
17+
println("<log> start loop")
18+
val x$1: Int = arr(i)
19+
System.out.println(2 * x$1)
20+
val x$2: Int = arr(i + 1)
21+
System.out.println(2 * x$2)
22+
val x$3: Int = arr(i + 2)
23+
System.out.println(2 * x$3)
24+
i = i + 3
25+
}
26+
*/
27+
}
28+
29+
}

0 commit comments

Comments
 (0)