Skip to content

Commit 98606aa

Browse files
committed
Add more rings
1 parent 09353b7 commit 98606aa

File tree

2 files changed

+174
-93
lines changed

2 files changed

+174
-93
lines changed
Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,54 @@
1-
21
2-
Complex(7,9)
3-
0.+(1.*(1)).+(2.*(0)).+(4.*(1)).+(8.*(0)).+(16.*(1))
4-
{
5-
val arr3: scala.Array[scala.Int] = scala.Array.apply(1, 2, 4, 8, 16)
6-
val arr4: scala.Array[scala.Int] = scala.Array.apply(1, 0, 1, 0, 1)
1+
10
2+
3+
Complex(3,2)
4+
5+
0.+(0.*(1)).+(1.*(0)).+(2.*(1)).+(4.*(0)).+(8.*(1))
6+
10
7+
8+
((arr1: scala.Array[scala.Int], arr2: scala.Array[scala.Int]) => {
9+
dotty.DottyPredef.assert(arr1.length.==(arr2.length))
710
var sum: scala.Int = 0
8-
scala.Predef.intWrapper(0).until(scala.Predef.intArrayOps(arr3).size).foreach[scala.Unit](((i: scala.Int) => sum = sum.+(arr3.apply(i).*(arr4.apply(i)))))
11+
var i: scala.Int = 0
12+
while (i.<(scala.Predef.intArrayOps(arr1).size)) {
13+
sum = sum.+(arr1.apply(i).*(arr2.apply(i)))
14+
i = i.+(1)
15+
}
916
(sum: scala.Int)
10-
}
11-
{
12-
val arr3: scala.Array[scala.Int] = scala.Array.apply(1, 2, 4, 8, 16)
13-
arr3.apply(0).+(arr3.apply(2)).+(arr3.apply(4))
14-
}
17+
})
18+
10
19+
20+
0.+(2).+(8)
21+
10
22+
23+
((arr: scala.Array[scala.Int]) => {
24+
dotty.DottyPredef.assert(arr.length.==(5))
25+
arr.apply(0).+(arr.apply(2)).+(arr.apply(4))
26+
})
27+
10
28+
29+
((arr: scala.Array[Complex[scala.Int]]) => {
30+
dotty.DottyPredef.assert(arr.length.==(4))
31+
Complex.apply[scala.Int]({
32+
Complex.apply[scala.Int](arr.apply(0).re.*({
33+
Complex.apply[scala.Int](0, 1)
34+
}.re).-(arr.apply(0).im.*({
35+
Complex.apply[scala.Int](0, 1)
36+
}.im)), arr.apply(0).re.*({
37+
Complex.apply[scala.Int](0, 1)
38+
}.im).+(arr.apply(0).im.*({
39+
Complex.apply[scala.Int](0, 1)
40+
}.re)))
41+
}.re.+(arr.apply(3).re), {
42+
Complex.apply[scala.Int](arr.apply(0).re.*({
43+
Complex.apply[scala.Int](0, 1)
44+
}.re).-(arr.apply(0).im.*({
45+
Complex.apply[scala.Int](0, 1)
46+
}.im)), arr.apply(0).re.*({
47+
Complex.apply[scala.Int](0, 1)
48+
}.im).+(arr.apply(0).im.*({
49+
Complex.apply[scala.Int](0, 1)
50+
}.re)))
51+
}.im.+(arr.apply(3).im))
52+
})
53+
Complex(3,2)
54+

tests/run-with-compiler/shonan-hmm-simple.scala

Lines changed: 122 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,34 @@ trait Ring[T] {
1010

1111
class RingInt extends Ring[Int] {
1212
val zero = 0
13-
val one = 1
14-
val add = (x, y) => x + y
13+
val one = 1
14+
val add = (x, y) => x + y
1515
val sub = (x, y) => x - y
1616
val mul = (x, y) => x * y
1717
}
1818

1919
class RingIntExpr extends Ring[Expr[Int]] {
2020
val zero = '(0)
21-
val one = '(1)
22-
val add = (x, y) => '(~x + ~y)
21+
val one = '(1)
22+
val add = (x, y) => '(~x + ~y)
2323
val sub = (x, y) => '(~x - ~y)
2424
val mul = (x, y) => '(~x * ~y)
2525
}
2626

2727
class RingComplex[U](u: Ring[U]) extends Ring[Complex[U]] {
2828
val zero = Complex(u.zero, u.zero)
2929
val one = Complex(u.one, u.zero)
30-
val add = (x, y) => Complex(u.add(x.re, y.re), u.add(x.im, y.im))
31-
val sub = (x, y) => Complex(u.sub(x.re, y.re), u.sub(x.im, y.im))
32-
val mul = (x, y) => Complex(u.sub(u.mul(x.re, y.re), u.mul(x.im, y.im)), u.add(u.mul(x.re, y.im), u.mul(x.im, y.re)))
30+
val add = (x, y) => Complex(u.add(x.re, y.re), u.add(x.im, y.im))
31+
val sub = (x, y) => Complex(u.sub(x.re, y.re), u.sub(x.im, y.im))
32+
val mul = (x, y) => Complex(u.sub(u.mul(x.re, y.re), u.mul(x.im, y.im)), u.add(u.mul(x.re, y.im), u.mul(x.im, y.re)))
33+
}
34+
35+
class RingComplexExpr[U: Type](u: Ring[Expr[U]]) extends Ring[Expr[Complex[U]]] {
36+
val zero = '(Complex(~u.zero, ~u.zero))
37+
val one = '(Complex(~u.one, ~u.zero))
38+
val add = (x, y) => '(Complex(~u.add('((~x).re), '((~y).re)), ~u.add('((~x).im), '((~y).im))))
39+
val sub = (x, y) => '(Complex(~u.sub('((~x).re), '((~y).re)), ~u.sub('((~x).im), '((~y).im))))
40+
val mul = (x, y) => '(Complex(~u.sub(u.mul('((~x).re), '((~y).re)), u.mul('((~x).im), '((~y).im))), ~u.add(u.mul('((~x).re), '((~y).im)), u.mul('((~x).im), '((~y).re)))))
3341
}
3442

3543
sealed trait PV[T] {
@@ -42,43 +50,48 @@ case class Dyn[T](x: Expr[T]) extends PV[T] {
4250
def expr(implicit l: Liftable[T]): Expr[T] = x
4351
}
4452

45-
case class RingPV[U: Liftable](staRing: Ring[U], dynRing: Ring[Expr[U]]) extends Ring[PV[U]] {
46-
type T = PV[U]
47-
48-
import staRing._
49-
import dynRing._
50-
51-
val zero: T = Sta(staRing.zero)
52-
val one: T = Sta(staRing.one)
53-
val add = (x: T, y: T) => (x, y) match {
54-
case (Sta(staRing.zero), x) => x
55-
case (x, Sta(staRing.zero)) => x
56-
case (Sta(x), Sta(y)) => Sta(staRing.add(x, y))
57-
case (x, y) => Dyn(dynRing.add(x.expr, y.expr))
53+
class RingPV[U: Liftable](u: Ring[U], eu: Ring[Expr[U]]) extends Ring[PV[U]] {
54+
val zero: PV[U] = Sta(u.zero)
55+
val one: PV[U] = Sta(u.one)
56+
val add = (x: PV[U], y: PV[U]) => (x, y) match {
57+
case (Sta(u.zero), x) => x
58+
case (x, Sta(u.zero)) => x
59+
case (Sta(x), Sta(y)) => Sta(u.add(x, y))
60+
case (x, y) => Dyn(eu.add(x.expr, y.expr))
5861
}
59-
val sub = (x: T, y: T) => (x, y) match {
60-
case (Sta(staRing.zero), x) => x
61-
case (x, Sta(staRing.zero)) => x
62-
case (Sta(x), Sta(y)) => Sta(staRing.sub(x, y))
63-
case (x, y) => Dyn(dynRing.sub(x.expr, y.expr))
62+
val sub = (x: PV[U], y: PV[U]) => (x, y) match {
63+
case (Sta(u.zero), x) => x
64+
case (x, Sta(u.zero)) => x
65+
case (Sta(x), Sta(y)) => Sta(u.sub(x, y))
66+
case (x, y) => Dyn(eu.sub(x.expr, y.expr))
6467
}
65-
val mul = (x: T, y: T) => (x, y) match {
66-
case (Sta(staRing.zero), _) => Sta(staRing.zero)
67-
case (_, Sta(staRing.zero)) => Sta(staRing.zero)
68-
case (Sta(staRing.one), x) => x
69-
case (x, Sta(staRing.one)) => x
70-
case (Sta(x), Sta(y)) => Sta(staRing.mul(x, y))
71-
case (x, y) => Dyn(dynRing.mul(x.expr, y.expr))
68+
val mul = (x: PV[U], y: PV[U]) => (x, y) match {
69+
case (Sta(u.zero), _) => Sta(u.zero)
70+
case (_, Sta(u.zero)) => Sta(u.zero)
71+
case (Sta(u.one), x) => x
72+
case (x, Sta(u.one)) => x
73+
case (Sta(x), Sta(y)) => Sta(u.mul(x, y))
74+
case (x, y) => Dyn(eu.mul(x.expr, y.expr))
7275
}
7376
}
7477

78+
7579
case class Complex[T](re: T, im: T)
7680

81+
object Complex {
82+
implicit def isLiftable[T: Type: Liftable]: Liftable[Complex[T]] = new Liftable[Complex[T]] {
83+
def toExpr(comp: Complex[T]): Expr[Complex[T]] = '(Complex(~comp.re.toExpr, ~comp.im.toExpr))
84+
}
85+
}
86+
7787
case class Vec[Idx, T](size: Idx, get: Idx => T) {
7888
def map[U](f: T => U): Vec[Idx, U] = Vec(size, i => f(get(i)))
7989
def zipWith[U, V](other: Vec[Idx, U], f: (T, U) => V): Vec[Idx, V] = Vec(size, i => f(get(i), other.get(i)))
8090
}
8191

92+
object Vec {
93+
def from[T](elems: T*): Vec[Int, T] = new Vec(elems.size, i => elems(i))
94+
}
8295

8396
trait VecOps[Idx, T] {
8497
val reduce: ((T, T) => T, T, Vec[Idx, T]) => T
@@ -93,20 +106,14 @@ class StaticVecOps[T] extends VecOps[Int, T] {
93106
}
94107
}
95108

96-
class StaticVecOptOps[T] extends VecOps[Int, T] {
97-
val reduce: ((T, T) => T, T, Vec[Int, T]) => T = (plus, zero, vec) => {
98-
var sum = zero
99-
for (i <- 0 until vec.size)
100-
sum = plus(sum, vec.get(i))
101-
sum
102-
}
103-
}
104-
105109
class ExprVecOps[T: Type] extends VecOps[Expr[Int], Expr[T]] {
106110
val reduce: ((Expr[T], Expr[T]) => Expr[T], Expr[T], Vec[Expr[Int], Expr[T]]) => Expr[T] = (plus, zero, vec) => '{
107111
var sum = ~zero
108-
for (i <- 0 until ~vec.size)
112+
var i = 0
113+
while (i < ~vec.size) {
109114
sum = ~{ plus('(sum), vec.get('(i))) }
115+
i += 1
116+
}
110117
sum
111118
}
112119
}
@@ -120,60 +127,94 @@ object Test {
120127
implicit val toolbox: scala.quoted.Toolbox = dotty.tools.dotc.quoted.Toolbox.make
121128

122129
def main(args: Array[String]): Unit = {
123-
val arr1 = Array(1, 2, 4, 8, 16)
130+
val arr1 = Array(0, 1, 2, 4, 8)
124131
val arr2 = Array(1, 0, 1, 0, 1)
132+
val cmpxArr1 = Array(Complex(1, 0), Complex(2, 3), Complex(0, 2), Complex(3, 1))
133+
val cmpxArr2 = Array(Complex(0, 1), Complex(0, 0), Complex(0, 0), Complex(1, 0))
125134

126135
val vec1 = new Vec(arr1.size, i => arr1(i))
127136
val vec2 = new Vec(arr2.size, i => arr2(i))
137+
val cmpxVec1 = new Vec(cmpxArr1.size, i => cmpxArr1(i))
138+
val cmpxVec2 = new Vec(cmpxArr2.size, i => cmpxArr2(i))
139+
128140
val blasInt = new Blas1(new RingInt, new StaticVecOps)
129-
println(blasInt.dot(vec1, vec2))
141+
val res1 = blasInt.dot(vec1, vec2)
142+
println(res1)
143+
println()
130144

131-
val vec3 = new Vec(arr1.size, i => Complex(2, arr2(i)))
132-
val vec4 = new Vec(arr2.size, i => Complex(1, arr2(i)))
133145
val blasComplexInt = new Blas1(new RingComplex(new RingInt), new StaticVecOps)
134-
println(blasComplexInt.dot(vec3, vec4))
146+
val res2 = blasComplexInt.dot(
147+
cmpxVec1,
148+
cmpxVec2
149+
)
150+
println(res2)
151+
println()
135152

136-
val vec5 = new Vec(5, i => arr1(i).toExpr)
137-
val vec6 = new Vec(5, i => arr2(i).toExpr)
138153
val blasStaticIntExpr = new Blas1(new RingIntExpr, new StaticVecOps)
139-
println(blasStaticIntExpr.dot(vec5, vec6).show)
140-
141-
142-
143-
val code = '{
144-
val arr3 = Array(1, 2, 4, 8, 16)
145-
val arr4 = Array(1, 0, 1, 0, 1)
146-
~{
147-
val vec7 = new Vec('(arr3.size), i => '(arr3(~i)))
148-
val vec8 = new Vec('(arr4.size), i => '(arr4(~i)))
149-
val blasExprIntExpr = new Blas1(new RingIntExpr, new ExprVecOps)
150-
blasExprIntExpr.dot(vec7, vec8)
151-
}
152-
154+
val resCode1 = blasStaticIntExpr.dot(
155+
vec1.map(_.toExpr),
156+
vec2.map(_.toExpr)
157+
)
158+
println(resCode1.show)
159+
println(resCode1.run)
160+
println()
161+
162+
val blasExprIntExpr = new Blas1(new RingIntExpr, new ExprVecOps)
163+
val resCode2: Expr[(Array[Int], Array[Int]) => Int] = '{
164+
(arr1, arr2) =>
165+
assert(arr1.length == arr2.length)
166+
~{
167+
blasExprIntExpr.dot(
168+
new Vec('(arr1.size), i => '(arr1(~i))),
169+
new Vec('(arr2.size), i => '(arr2(~i)))
170+
)
171+
}
153172
}
154-
println(code.show)
155-
173+
println(resCode2.show)
174+
println(resCode2.run.apply(arr1, arr2))
175+
println()
176+
177+
val blasStaticIntPVExpr = new Blas1(new RingPV[Int](new RingInt, new RingIntExpr), new StaticVecOps)
178+
val resCode3 = blasStaticIntPVExpr.dot(
179+
vec1.map(i => Dyn(i.toExpr)),
180+
vec2.map(i => Sta(i))
181+
).expr
182+
println(resCode3.show)
183+
println(resCode3.run)
184+
println()
185+
186+
val blasExprIntPVExpr = new Blas1(new RingPV[Int](new RingInt, new RingIntExpr), new StaticVecOps)
187+
val resCode4: Expr[Array[Int] => Int] = '{
188+
arr =>
189+
assert(arr.length == ~vec2.size.toExpr)
190+
~{
191+
blasExprIntPVExpr.dot(
192+
new Vec(vec2.size, i => Dyn('(arr(~i.toExpr)))),
193+
vec2.map(i => Sta(i))
194+
).expr
195+
}
156196

157-
{
158-
val vec5 = new Vec[Int, PV[Int]](5, i => Dyn(arr1(i).toExpr))
159-
val vec6 = new Vec[Int, PV[Int]](5, i => Sta(arr2(i)))
160-
val blasStaticIntExpr = new Blas1(new RingPV[Int](new RingInt, new RingIntExpr), new StaticVecOps)
161-
blasStaticIntExpr.dot(vec5, vec6).expr.show
162197
}
163-
164-
{
165-
val code = '{
166-
val arr3 = Array(1, 2, 4, 8, 16)
198+
println(resCode4.show)
199+
println(resCode4.run.apply(arr1))
200+
println()
201+
202+
import Complex.isLiftable
203+
val blasExprComplexIntPVExpr = new Blas1(new RingPV(new RingComplex(new RingInt), new RingComplexExpr(new RingIntExpr)), new StaticVecOps)
204+
val resCode5: Expr[Array[Complex[Int]] => Complex[Int]] = '{
205+
arr =>
206+
assert(arr.length == ~cmpxVec2.size.toExpr)
167207
~{
168-
val vec7 = new Vec[Int, PV[Int]](5, i => Dyn('(arr3(~i.toExpr))))
169-
val vec8 = new Vec[Int, PV[Int]](5, i => Sta(arr2(i)))
170-
val blasExprIntExpr = new Blas1(new RingPV[Int](new RingInt, new RingIntExpr), new StaticVecOps)
171-
blasExprIntExpr.dot(vec7, vec8).expr
208+
blasExprComplexIntPVExpr.dot(
209+
new Vec(cmpxVec2.size, i => Dyn('(arr(~i.toExpr)))),
210+
cmpxVec2.map(i => Sta(i))
211+
).expr
172212
}
173-
174-
}
175-
println(code.show)
176213
}
214+
println(resCode5.show)
215+
println(resCode5.run.apply(cmpxArr1))
216+
println()
217+
177218
}
178219

179220
}

0 commit comments

Comments
 (0)