@@ -10,26 +10,34 @@ trait Ring[T] {
10
10
11
11
class RingInt extends Ring [Int ] {
12
12
val zero = 0
13
- val one = 1
14
- val add = (x, y) => x + y
13
+ val one = 1
14
+ val add = (x, y) => x + y
15
15
val sub = (x, y) => x - y
16
16
val mul = (x, y) => x * y
17
17
}
18
18
19
19
class RingIntExpr extends Ring [Expr [Int ]] {
20
20
val zero = '(0)
21
- val one = '(1)
22
- val add = (x, y) => '(~x + ~y)
21
+ val one = '(1)
22
+ val add = (x, y) => '(~x + ~y)
23
23
val sub = (x, y) => '(~x - ~y)
24
24
val mul = (x, y) => '(~x * ~y)
25
25
}
26
26
27
27
class RingComplex [U ](u : Ring [U ]) extends Ring [Complex [U ]] {
28
28
val zero = Complex (u.zero, u.zero)
29
29
val one = Complex (u.one, u.zero)
30
- val add = (x, y) => Complex (u.add(x.re, y.re), u.add(x.im, y.im))
31
- val sub = (x, y) => Complex (u.sub(x.re, y.re), u.sub(x.im, y.im))
32
- val mul = (x, y) => Complex (u.sub(u.mul(x.re, y.re), u.mul(x.im, y.im)), u.add(u.mul(x.re, y.im), u.mul(x.im, y.re)))
30
+ val add = (x, y) => Complex (u.add(x.re, y.re), u.add(x.im, y.im))
31
+ val sub = (x, y) => Complex (u.sub(x.re, y.re), u.sub(x.im, y.im))
32
+ val mul = (x, y) => Complex (u.sub(u.mul(x.re, y.re), u.mul(x.im, y.im)), u.add(u.mul(x.re, y.im), u.mul(x.im, y.re)))
33
+ }
34
+
35
+ class RingComplexExpr [U : Type ](u : Ring [Expr [U ]]) extends Ring [Expr [Complex [U ]]] {
36
+ val zero = '(Complex(~u.zero, ~u.zero))
37
+ val one = '(Complex(~u.one, ~u.zero))
38
+ val add = (x, y) => '(Complex(~u.add( ' ((~ x).re), '((~y).re)), ~u.add( ' ((~ x).im), '((~y).im))))
39
+ val sub = (x, y) => '(Complex(~u.sub( ' ((~ x).re), '((~y).re)), ~u.sub( ' ((~ x).im), '((~y).im))))
40
+ val mul = (x, y) => '(Complex(~u.sub(u.mul( ' ((~ x).re), '((~y).re)), u.mul( ' ((~ x).im), '((~y).im))), ~u.add(u.mul( ' ((~ x).re), '((~y).im)), u.mul( ' ((~ x).im), '((~y).re)))))
33
41
}
34
42
35
43
sealed trait PV [T ] {
@@ -42,43 +50,48 @@ case class Dyn[T](x: Expr[T]) extends PV[T] {
42
50
def expr (implicit l : Liftable [T ]): Expr [T ] = x
43
51
}
44
52
45
- case class RingPV [U : Liftable ](staRing : Ring [U ], dynRing : Ring [Expr [U ]]) extends Ring [PV [U ]] {
46
- type T = PV [U ]
47
-
48
- import staRing ._
49
- import dynRing ._
50
-
51
- val zero : T = Sta (staRing.zero)
52
- val one : T = Sta (staRing.one)
53
- val add = (x : T , y : T ) => (x, y) match {
54
- case (Sta (staRing.zero), x) => x
55
- case (x, Sta (staRing.zero)) => x
56
- case (Sta (x), Sta (y)) => Sta (staRing.add(x, y))
57
- case (x, y) => Dyn (dynRing.add(x.expr, y.expr))
53
+ class RingPV [U : Liftable ](u : Ring [U ], eu : Ring [Expr [U ]]) extends Ring [PV [U ]] {
54
+ val zero : PV [U ] = Sta (u.zero)
55
+ val one : PV [U ] = Sta (u.one)
56
+ val add = (x : PV [U ], y : PV [U ]) => (x, y) match {
57
+ case (Sta (u.zero), x) => x
58
+ case (x, Sta (u.zero)) => x
59
+ case (Sta (x), Sta (y)) => Sta (u.add(x, y))
60
+ case (x, y) => Dyn (eu.add(x.expr, y.expr))
58
61
}
59
- val sub = (x : T , y : T ) => (x, y) match {
60
- case (Sta (staRing .zero), x) => x
61
- case (x, Sta (staRing .zero)) => x
62
- case (Sta (x), Sta (y)) => Sta (staRing .sub(x, y))
63
- case (x, y) => Dyn (dynRing .sub(x.expr, y.expr))
62
+ val sub = (x : PV [ U ] , y : PV [ U ] ) => (x, y) match {
63
+ case (Sta (u .zero), x) => x
64
+ case (x, Sta (u .zero)) => x
65
+ case (Sta (x), Sta (y)) => Sta (u .sub(x, y))
66
+ case (x, y) => Dyn (eu .sub(x.expr, y.expr))
64
67
}
65
- val mul = (x : T , y : T ) => (x, y) match {
66
- case (Sta (staRing .zero), _) => Sta (staRing .zero)
67
- case (_, Sta (staRing .zero)) => Sta (staRing .zero)
68
- case (Sta (staRing .one), x) => x
69
- case (x, Sta (staRing .one)) => x
70
- case (Sta (x), Sta (y)) => Sta (staRing .mul(x, y))
71
- case (x, y) => Dyn (dynRing .mul(x.expr, y.expr))
68
+ val mul = (x : PV [ U ] , y : PV [ U ] ) => (x, y) match {
69
+ case (Sta (u .zero), _) => Sta (u .zero)
70
+ case (_, Sta (u .zero)) => Sta (u .zero)
71
+ case (Sta (u .one), x) => x
72
+ case (x, Sta (u .one)) => x
73
+ case (Sta (x), Sta (y)) => Sta (u .mul(x, y))
74
+ case (x, y) => Dyn (eu .mul(x.expr, y.expr))
72
75
}
73
76
}
74
77
78
+
75
79
case class Complex [T ](re : T , im : T )
76
80
81
+ object Complex {
82
+ implicit def isLiftable [T : Type : Liftable ]: Liftable [Complex [T ]] = new Liftable [Complex [T ]] {
83
+ def toExpr (comp : Complex [T ]): Expr [Complex [T ]] = '(Complex(~comp.re.toExpr, ~comp.im.toExpr))
84
+ }
85
+ }
86
+
77
87
case class Vec [Idx , T ](size : Idx , get : Idx => T ) {
78
88
def map [U ](f : T => U ): Vec [Idx , U ] = Vec (size, i => f(get(i)))
79
89
def zipWith [U , V ](other : Vec [Idx , U ], f : (T , U ) => V ): Vec [Idx , V ] = Vec (size, i => f(get(i), other.get(i)))
80
90
}
81
91
92
+ object Vec {
93
+ def from [T ](elems : T * ): Vec [Int , T ] = new Vec (elems.size, i => elems(i))
94
+ }
82
95
83
96
trait VecOps [Idx , T ] {
84
97
val reduce : ((T , T ) => T , T , Vec [Idx , T ]) => T
@@ -93,20 +106,14 @@ class StaticVecOps[T] extends VecOps[Int, T] {
93
106
}
94
107
}
95
108
96
- class StaticVecOptOps [T ] extends VecOps [Int , T ] {
97
- val reduce : ((T , T ) => T , T , Vec [Int , T ]) => T = (plus, zero, vec) => {
98
- var sum = zero
99
- for (i <- 0 until vec.size)
100
- sum = plus(sum, vec.get(i))
101
- sum
102
- }
103
- }
104
-
105
109
class ExprVecOps [T : Type ] extends VecOps [Expr [Int ], Expr [T ]] {
106
110
val reduce : ((Expr [T ], Expr [T ]) => Expr [T ], Expr [T ], Vec [Expr [Int ], Expr [T ]]) => Expr [T ] = (plus, zero, vec) => ' {
107
111
var sum = ~ zero
108
- for (i <- 0 until ~ vec.size)
112
+ var i = 0
113
+ while (i < ~ vec.size) {
109
114
sum = ~ { plus('(sum), vec.get( ' (i))) }
115
+ i += 1
116
+ }
110
117
sum
111
118
}
112
119
}
@@ -120,60 +127,94 @@ object Test {
120
127
implicit val toolbox : scala.quoted.Toolbox = dotty.tools.dotc.quoted.Toolbox .make
121
128
122
129
def main (args : Array [String ]): Unit = {
123
- val arr1 = Array (1 , 2 , 4 , 8 , 16 )
130
+ val arr1 = Array (0 , 1 , 2 , 4 , 8 )
124
131
val arr2 = Array (1 , 0 , 1 , 0 , 1 )
132
+ val cmpxArr1 = Array (Complex (1 , 0 ), Complex (2 , 3 ), Complex (0 , 2 ), Complex (3 , 1 ))
133
+ val cmpxArr2 = Array (Complex (0 , 1 ), Complex (0 , 0 ), Complex (0 , 0 ), Complex (1 , 0 ))
125
134
126
135
val vec1 = new Vec (arr1.size, i => arr1(i))
127
136
val vec2 = new Vec (arr2.size, i => arr2(i))
137
+ val cmpxVec1 = new Vec (cmpxArr1.size, i => cmpxArr1(i))
138
+ val cmpxVec2 = new Vec (cmpxArr2.size, i => cmpxArr2(i))
139
+
128
140
val blasInt = new Blas1 (new RingInt , new StaticVecOps )
129
- println(blasInt.dot(vec1, vec2))
141
+ val res1 = blasInt.dot(vec1, vec2)
142
+ println(res1)
143
+ println()
130
144
131
- val vec3 = new Vec (arr1.size, i => Complex (2 , arr2(i)))
132
- val vec4 = new Vec (arr2.size, i => Complex (1 , arr2(i)))
133
145
val blasComplexInt = new Blas1 (new RingComplex (new RingInt ), new StaticVecOps )
134
- println(blasComplexInt.dot(vec3, vec4))
146
+ val res2 = blasComplexInt.dot(
147
+ cmpxVec1,
148
+ cmpxVec2
149
+ )
150
+ println(res2)
151
+ println()
135
152
136
- val vec5 = new Vec (5 , i => arr1(i).toExpr)
137
- val vec6 = new Vec (5 , i => arr2(i).toExpr)
138
153
val blasStaticIntExpr = new Blas1 (new RingIntExpr , new StaticVecOps )
139
- println(blasStaticIntExpr.dot(vec5, vec6).show)
140
-
141
-
142
-
143
- val code = ' {
144
- val arr3 = Array (1 , 2 , 4 , 8 , 16 )
145
- val arr4 = Array (1 , 0 , 1 , 0 , 1 )
146
- ~ {
147
- val vec7 = new Vec ('(arr3.size), i => ' (arr3(~ i)))
148
- val vec8 = new Vec ('(arr4.size), i => ' (arr4(~ i)))
149
- val blasExprIntExpr = new Blas1 (new RingIntExpr , new ExprVecOps )
150
- blasExprIntExpr.dot(vec7, vec8)
151
- }
152
-
154
+ val resCode1 = blasStaticIntExpr.dot(
155
+ vec1.map(_.toExpr),
156
+ vec2.map(_.toExpr)
157
+ )
158
+ println(resCode1.show)
159
+ println(resCode1.run)
160
+ println()
161
+
162
+ val blasExprIntExpr = new Blas1 (new RingIntExpr , new ExprVecOps )
163
+ val resCode2 : Expr [(Array [Int ], Array [Int ]) => Int ] = ' {
164
+ (arr1, arr2) =>
165
+ assert(arr1.length == arr2.length)
166
+ ~ {
167
+ blasExprIntExpr.dot(
168
+ new Vec ('(arr1.size), i => ' (arr1(~ i))),
169
+ new Vec ('(arr2.size), i => ' (arr2(~ i)))
170
+ )
171
+ }
153
172
}
154
- println(code.show)
155
-
173
+ println(resCode2.show)
174
+ println(resCode2.run.apply(arr1, arr2))
175
+ println()
176
+
177
+ val blasStaticIntPVExpr = new Blas1 (new RingPV [Int ](new RingInt , new RingIntExpr ), new StaticVecOps )
178
+ val resCode3 = blasStaticIntPVExpr.dot(
179
+ vec1.map(i => Dyn (i.toExpr)),
180
+ vec2.map(i => Sta (i))
181
+ ).expr
182
+ println(resCode3.show)
183
+ println(resCode3.run)
184
+ println()
185
+
186
+ val blasExprIntPVExpr = new Blas1 (new RingPV [Int ](new RingInt , new RingIntExpr ), new StaticVecOps )
187
+ val resCode4 : Expr [Array [Int ] => Int ] = ' {
188
+ arr =>
189
+ assert(arr.length == ~ vec2.size.toExpr)
190
+ ~ {
191
+ blasExprIntPVExpr.dot(
192
+ new Vec (vec2.size, i => Dyn ('(arr(~i.toExpr)))),
193
+ vec2.map(i => Sta (i))
194
+ ).expr
195
+ }
156
196
157
- {
158
- val vec5 = new Vec [Int , PV [Int ]](5 , i => Dyn (arr1(i).toExpr))
159
- val vec6 = new Vec [Int , PV [Int ]](5 , i => Sta (arr2(i)))
160
- val blasStaticIntExpr = new Blas1 (new RingPV [Int ](new RingInt , new RingIntExpr ), new StaticVecOps )
161
- blasStaticIntExpr.dot(vec5, vec6).expr.show
162
197
}
163
-
164
- {
165
- val code = ' {
166
- val arr3 = Array (1 , 2 , 4 , 8 , 16 )
198
+ println(resCode4.show)
199
+ println(resCode4.run.apply(arr1))
200
+ println()
201
+
202
+ import Complex .isLiftable
203
+ val blasExprComplexIntPVExpr = new Blas1 (new RingPV (new RingComplex (new RingInt ), new RingComplexExpr (new RingIntExpr )), new StaticVecOps )
204
+ val resCode5 : Expr [Array [Complex [Int ]] => Complex [Int ]] = ' {
205
+ arr =>
206
+ assert(arr.length == ~ cmpxVec2.size.toExpr)
167
207
~ {
168
- val vec7 = new Vec [ Int , PV [ Int ]]( 5 , i => Dyn ( ' (arr3(~i.toExpr))))
169
- val vec8 = new Vec [ Int , PV [ Int ]]( 5 , i => Sta (arr2(i )))
170
- val blasExprIntExpr = new Blas1 ( new RingPV [ Int ]( new RingInt , new RingIntExpr ), new StaticVecOps )
171
- blasExprIntExpr.dot(vec7, vec8 ).expr
208
+ blasExprComplexIntPVExpr.dot(
209
+ new Vec (cmpxVec2.size , i => Dyn ( ' (arr(~i.toExpr )))),
210
+ cmpxVec2.map(i => Sta (i) )
211
+ ).expr
172
212
}
173
-
174
- }
175
- println(code.show)
176
213
}
214
+ println(resCode5.show)
215
+ println(resCode5.run.apply(cmpxArr1))
216
+ println()
217
+
177
218
}
178
219
179
220
}
0 commit comments