Skip to content

Commit a995e3a

Browse files
committed
Fix complex PV blass
1 parent 98606aa commit a995e3a

File tree

2 files changed

+10
-39
lines changed

2 files changed

+10
-39
lines changed
Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
10
22

3-
Complex(3,2)
3+
Complex(4,3)
44

55
0.+(0.*(1)).+(1.*(0)).+(2.*(1)).+(4.*(0)).+(8.*(1))
66
10
@@ -28,27 +28,7 @@ Complex(3,2)
2828

2929
((arr: scala.Array[Complex[scala.Int]]) => {
3030
dotty.DottyPredef.assert(arr.length.==(4))
31-
Complex.apply[scala.Int]({
32-
Complex.apply[scala.Int](arr.apply(0).re.*({
33-
Complex.apply[scala.Int](0, 1)
34-
}.re).-(arr.apply(0).im.*({
35-
Complex.apply[scala.Int](0, 1)
36-
}.im)), arr.apply(0).re.*({
37-
Complex.apply[scala.Int](0, 1)
38-
}.im).+(arr.apply(0).im.*({
39-
Complex.apply[scala.Int](0, 1)
40-
}.re)))
41-
}.re.+(arr.apply(3).re), {
42-
Complex.apply[scala.Int](arr.apply(0).re.*({
43-
Complex.apply[scala.Int](0, 1)
44-
}.re).-(arr.apply(0).im.*({
45-
Complex.apply[scala.Int](0, 1)
46-
}.im)), arr.apply(0).re.*({
47-
Complex.apply[scala.Int](0, 1)
48-
}.im).+(arr.apply(0).im.*({
49-
Complex.apply[scala.Int](0, 1)
50-
}.re)))
51-
}.im.+(arr.apply(3).im))
31+
Complex.apply[scala.Int](0.-(arr.apply(0).im).+(0.-(arr.apply(2).im)).+(arr.apply(3).re.*(2)), arr.apply(0).re.+(arr.apply(2).re).+(arr.apply(3).im.*(2)))
5232
})
53-
Complex(3,2)
33+
Complex(4,3)
5434

tests/run-with-compiler/shonan-hmm-simple.scala

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,6 @@ class RingComplex[U](u: Ring[U]) extends Ring[Complex[U]] {
3232
val mul = (x, y) => Complex(u.sub(u.mul(x.re, y.re), u.mul(x.im, y.im)), u.add(u.mul(x.re, y.im), u.mul(x.im, y.re)))
3333
}
3434

35-
class RingComplexExpr[U: Type](u: Ring[Expr[U]]) extends Ring[Expr[Complex[U]]] {
36-
val zero = '(Complex(~u.zero, ~u.zero))
37-
val one = '(Complex(~u.one, ~u.zero))
38-
val add = (x, y) => '(Complex(~u.add('((~x).re), '((~y).re)), ~u.add('((~x).im), '((~y).im))))
39-
val sub = (x, y) => '(Complex(~u.sub('((~x).re), '((~y).re)), ~u.sub('((~x).im), '((~y).im))))
40-
val mul = (x, y) => '(Complex(~u.sub(u.mul('((~x).re), '((~y).re)), u.mul('((~x).im), '((~y).im))), ~u.add(u.mul('((~x).re), '((~y).im)), u.mul('((~x).im), '((~y).re)))))
41-
}
42-
4335
sealed trait PV[T] {
4436
def expr(implicit l: Liftable[T]): Expr[T]
4537
}
@@ -60,7 +52,6 @@ class RingPV[U: Liftable](u: Ring[U], eu: Ring[Expr[U]]) extends Ring[PV[U]] {
6052
case (x, y) => Dyn(eu.add(x.expr, y.expr))
6153
}
6254
val sub = (x: PV[U], y: PV[U]) => (x, y) match {
63-
case (Sta(u.zero), x) => x
6455
case (x, Sta(u.zero)) => x
6556
case (Sta(x), Sta(y)) => Sta(u.sub(x, y))
6657
case (x, y) => Dyn(eu.sub(x.expr, y.expr))
@@ -130,7 +121,7 @@ object Test {
130121
val arr1 = Array(0, 1, 2, 4, 8)
131122
val arr2 = Array(1, 0, 1, 0, 1)
132123
val cmpxArr1 = Array(Complex(1, 0), Complex(2, 3), Complex(0, 2), Complex(3, 1))
133-
val cmpxArr2 = Array(Complex(0, 1), Complex(0, 0), Complex(0, 0), Complex(1, 0))
124+
val cmpxArr2 = Array(Complex(0, 1), Complex(0, 0), Complex(0, 1), Complex(2, 0))
134125

135126
val vec1 = new Vec(arr1.size, i => arr1(i))
136127
val vec2 = new Vec(arr2.size, i => arr2(i))
@@ -200,21 +191,21 @@ object Test {
200191
println()
201192

202193
import Complex.isLiftable
203-
val blasExprComplexIntPVExpr = new Blas1(new RingPV(new RingComplex(new RingInt), new RingComplexExpr(new RingIntExpr)), new StaticVecOps)
194+
val blasExprComplexPVInt = new Blas1[Int, Complex[PV[Int]]](new RingComplex(new RingPV[Int](new RingInt, new RingIntExpr)), new StaticVecOps)
204195
val resCode5: Expr[Array[Complex[Int]] => Complex[Int]] = '{
205196
arr =>
206197
assert(arr.length == ~cmpxVec2.size.toExpr)
207198
~{
208-
blasExprComplexIntPVExpr.dot(
209-
new Vec(cmpxVec2.size, i => Dyn('(arr(~i.toExpr)))),
210-
cmpxVec2.map(i => Sta(i))
211-
).expr
199+
val cpx = blasExprComplexPVInt.dot(
200+
new Vec(cmpxVec2.size, i => Complex(Dyn('(arr(~i.toExpr).re)), Dyn('(arr(~i.toExpr).im)))),
201+
new Vec(cmpxVec2.size, i => Complex(Sta(cmpxVec2.get(i).re), Sta(cmpxVec2.get(i).im)))
202+
)
203+
'(Complex(~cpx.re.expr, ~cpx.im.expr))
212204
}
213205
}
214206
println(resCode5.show)
215207
println(resCode5.run.apply(cmpxArr1))
216208
println()
217-
218209
}
219210

220211
}

0 commit comments

Comments
 (0)