Skip to content

Commit 09353b7

Browse files
committed
Simpler version of the abstractions
1 parent 9373bdf commit 09353b7

File tree

2 files changed

+193
-0
lines changed

2 files changed

+193
-0
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
21
2+
Complex(7,9)
3+
0.+(1.*(1)).+(2.*(0)).+(4.*(1)).+(8.*(0)).+(16.*(1))
4+
{
5+
val arr3: scala.Array[scala.Int] = scala.Array.apply(1, 2, 4, 8, 16)
6+
val arr4: scala.Array[scala.Int] = scala.Array.apply(1, 0, 1, 0, 1)
7+
var sum: scala.Int = 0
8+
scala.Predef.intWrapper(0).until(scala.Predef.intArrayOps(arr3).size).foreach[scala.Unit](((i: scala.Int) => sum = sum.+(arr3.apply(i).*(arr4.apply(i)))))
9+
(sum: scala.Int)
10+
}
11+
{
12+
val arr3: scala.Array[scala.Int] = scala.Array.apply(1, 2, 4, 8, 16)
13+
arr3.apply(0).+(arr3.apply(2)).+(arr3.apply(4))
14+
}
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
import scala.quoted._
2+
3+
trait Ring[T] {
4+
val zero: T
5+
val one: T
6+
val add: (x: T, y: T) => T
7+
val sub: (x: T, y: T) => T
8+
val mul: (x: T, y: T) => T
9+
}
10+
11+
class RingInt extends Ring[Int] {
12+
val zero = 0
13+
val one = 1
14+
val add = (x, y) => x + y
15+
val sub = (x, y) => x - y
16+
val mul = (x, y) => x * y
17+
}
18+
19+
class RingIntExpr extends Ring[Expr[Int]] {
20+
val zero = '(0)
21+
val one = '(1)
22+
val add = (x, y) => '(~x + ~y)
23+
val sub = (x, y) => '(~x - ~y)
24+
val mul = (x, y) => '(~x * ~y)
25+
}
26+
27+
class RingComplex[U](u: Ring[U]) extends Ring[Complex[U]] {
28+
val zero = Complex(u.zero, u.zero)
29+
val one = Complex(u.one, u.zero)
30+
val add = (x, y) => Complex(u.add(x.re, y.re), u.add(x.im, y.im))
31+
val sub = (x, y) => Complex(u.sub(x.re, y.re), u.sub(x.im, y.im))
32+
val mul = (x, y) => Complex(u.sub(u.mul(x.re, y.re), u.mul(x.im, y.im)), u.add(u.mul(x.re, y.im), u.mul(x.im, y.re)))
33+
}
34+
35+
sealed trait PV[T] {
36+
def expr(implicit l: Liftable[T]): Expr[T]
37+
}
38+
case class Sta[T](x: T) extends PV[T] {
39+
def expr(implicit l: Liftable[T]): Expr[T] = x.toExpr
40+
}
41+
case class Dyn[T](x: Expr[T]) extends PV[T] {
42+
def expr(implicit l: Liftable[T]): Expr[T] = x
43+
}
44+
45+
case class RingPV[U: Liftable](staRing: Ring[U], dynRing: Ring[Expr[U]]) extends Ring[PV[U]] {
46+
type T = PV[U]
47+
48+
import staRing._
49+
import dynRing._
50+
51+
val zero: T = Sta(staRing.zero)
52+
val one: T = Sta(staRing.one)
53+
val add = (x: T, y: T) => (x, y) match {
54+
case (Sta(staRing.zero), x) => x
55+
case (x, Sta(staRing.zero)) => x
56+
case (Sta(x), Sta(y)) => Sta(staRing.add(x, y))
57+
case (x, y) => Dyn(dynRing.add(x.expr, y.expr))
58+
}
59+
val sub = (x: T, y: T) => (x, y) match {
60+
case (Sta(staRing.zero), x) => x
61+
case (x, Sta(staRing.zero)) => x
62+
case (Sta(x), Sta(y)) => Sta(staRing.sub(x, y))
63+
case (x, y) => Dyn(dynRing.sub(x.expr, y.expr))
64+
}
65+
val mul = (x: T, y: T) => (x, y) match {
66+
case (Sta(staRing.zero), _) => Sta(staRing.zero)
67+
case (_, Sta(staRing.zero)) => Sta(staRing.zero)
68+
case (Sta(staRing.one), x) => x
69+
case (x, Sta(staRing.one)) => x
70+
case (Sta(x), Sta(y)) => Sta(staRing.mul(x, y))
71+
case (x, y) => Dyn(dynRing.mul(x.expr, y.expr))
72+
}
73+
}
74+
75+
case class Complex[T](re: T, im: T)
76+
77+
case class Vec[Idx, T](size: Idx, get: Idx => T) {
78+
def map[U](f: T => U): Vec[Idx, U] = Vec(size, i => f(get(i)))
79+
def zipWith[U, V](other: Vec[Idx, U], f: (T, U) => V): Vec[Idx, V] = Vec(size, i => f(get(i), other.get(i)))
80+
}
81+
82+
83+
trait VecOps[Idx, T] {
84+
val reduce: ((T, T) => T, T, Vec[Idx, T]) => T
85+
}
86+
87+
class StaticVecOps[T] extends VecOps[Int, T] {
88+
val reduce: ((T, T) => T, T, Vec[Int, T]) => T = (plus, zero, vec) => {
89+
var sum = zero
90+
for (i <- 0 until vec.size)
91+
sum = plus(sum, vec.get(i))
92+
sum
93+
}
94+
}
95+
96+
class StaticVecOptOps[T] extends VecOps[Int, T] {
97+
val reduce: ((T, T) => T, T, Vec[Int, T]) => T = (plus, zero, vec) => {
98+
var sum = zero
99+
for (i <- 0 until vec.size)
100+
sum = plus(sum, vec.get(i))
101+
sum
102+
}
103+
}
104+
105+
class ExprVecOps[T: Type] extends VecOps[Expr[Int], Expr[T]] {
106+
val reduce: ((Expr[T], Expr[T]) => Expr[T], Expr[T], Vec[Expr[Int], Expr[T]]) => Expr[T] = (plus, zero, vec) => '{
107+
var sum = ~zero
108+
for (i <- 0 until ~vec.size)
109+
sum = ~{ plus('(sum), vec.get('(i))) }
110+
sum
111+
}
112+
}
113+
114+
class Blas1[Idx, T](r: Ring[T], ops: VecOps[Idx, T]) {
115+
def dot(v1: Vec[Idx, T], v2: Vec[Idx, T]): T = ops.reduce(r.add, r.zero, v1.zipWith(v2, r.mul))
116+
}
117+
118+
object Test {
119+
120+
implicit val toolbox: scala.quoted.Toolbox = dotty.tools.dotc.quoted.Toolbox.make
121+
122+
def main(args: Array[String]): Unit = {
123+
val arr1 = Array(1, 2, 4, 8, 16)
124+
val arr2 = Array(1, 0, 1, 0, 1)
125+
126+
val vec1 = new Vec(arr1.size, i => arr1(i))
127+
val vec2 = new Vec(arr2.size, i => arr2(i))
128+
val blasInt = new Blas1(new RingInt, new StaticVecOps)
129+
println(blasInt.dot(vec1, vec2))
130+
131+
val vec3 = new Vec(arr1.size, i => Complex(2, arr2(i)))
132+
val vec4 = new Vec(arr2.size, i => Complex(1, arr2(i)))
133+
val blasComplexInt = new Blas1(new RingComplex(new RingInt), new StaticVecOps)
134+
println(blasComplexInt.dot(vec3, vec4))
135+
136+
val vec5 = new Vec(5, i => arr1(i).toExpr)
137+
val vec6 = new Vec(5, i => arr2(i).toExpr)
138+
val blasStaticIntExpr = new Blas1(new RingIntExpr, new StaticVecOps)
139+
println(blasStaticIntExpr.dot(vec5, vec6).show)
140+
141+
142+
143+
val code = '{
144+
val arr3 = Array(1, 2, 4, 8, 16)
145+
val arr4 = Array(1, 0, 1, 0, 1)
146+
~{
147+
val vec7 = new Vec('(arr3.size), i => '(arr3(~i)))
148+
val vec8 = new Vec('(arr4.size), i => '(arr4(~i)))
149+
val blasExprIntExpr = new Blas1(new RingIntExpr, new ExprVecOps)
150+
blasExprIntExpr.dot(vec7, vec8)
151+
}
152+
153+
}
154+
println(code.show)
155+
156+
157+
{
158+
val vec5 = new Vec[Int, PV[Int]](5, i => Dyn(arr1(i).toExpr))
159+
val vec6 = new Vec[Int, PV[Int]](5, i => Sta(arr2(i)))
160+
val blasStaticIntExpr = new Blas1(new RingPV[Int](new RingInt, new RingIntExpr), new StaticVecOps)
161+
blasStaticIntExpr.dot(vec5, vec6).expr.show
162+
}
163+
164+
{
165+
val code = '{
166+
val arr3 = Array(1, 2, 4, 8, 16)
167+
~{
168+
val vec7 = new Vec[Int, PV[Int]](5, i => Dyn('(arr3(~i.toExpr))))
169+
val vec8 = new Vec[Int, PV[Int]](5, i => Sta(arr2(i)))
170+
val blasExprIntExpr = new Blas1(new RingPV[Int](new RingInt, new RingIntExpr), new StaticVecOps)
171+
blasExprIntExpr.dot(vec7, vec8).expr
172+
}
173+
174+
}
175+
println(code.show)
176+
}
177+
}
178+
179+
}

0 commit comments

Comments
 (0)