Skip to content

Commit f835e19

Browse files
committed
wip
1 parent bd278cf commit f835e19

File tree

3 files changed

+85
-25
lines changed

3 files changed

+85
-25
lines changed

tests/run-with-compiler/shonan-hmm.check

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,27 @@ List(25, 30, 20, 43, 44)
230230
vout.update(4, v.apply(2).*(3).+(v.apply(4).*(7)))
231231
})
232232
}
233+
234+
235+
236+
((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
237+
if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else ()
238+
if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else ()
239+
vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5)))
240+
vout.update(1, v.apply(2).*(10))
241+
vout.update(2, v.apply(1).*(10))
242+
vout.update(3, {
243+
var sum: scala.Int = 0
244+
scala.Predef.intWrapper(0).until(5).foreach[scala.Unit](((i: scala.Int) => sum = sum.+(v.apply(i).*({
245+
val array: scala.Array[scala.Int] = new scala.Array[scala.Int](5)
246+
array.update(0, 0)
247+
array.update(1, 0)
248+
array.update(2, 2)
249+
array.update(3, 3)
250+
array.update(4, 5)
251+
array
252+
}.apply(i)))))
253+
(sum: scala.Int)
254+
})
255+
vout.update(4, v.apply(2).*(3).+(v.apply(4).*(7)))
256+
})

tests/run-with-compiler/shonan-hmm/MVmult.scala

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -55,56 +55,84 @@ object MVmult {
5555
}
5656

5757
def mvmult_ac(a: Array[Array[Int]]): Expr[(Array[Int], Array[Int]) => Unit] = {
58-
val n = a.length
59-
val m = a(0).length
6058
import Lifters._
6159
'{
6260
val arr = ~a.toExpr
6361
~{
64-
val a2: Vec[PV[Int], Vec[PV[Int], PV[Int]]] = Vec(Sta(n), i => Vec(Sta(m), j => (i, j) match {
65-
case (Sta(i), Sta(j)) => Sta(a(i)(j))
66-
case (Sta(i), Dyn(j)) => Dyn('(arr(~i.toExpr)(~j)))
67-
case (i, j) => Dyn( '{ arr(~(Dyns.dyni(i)))(~(Dyns.dyni(j))) })
68-
}))
69-
mvmult_abs0(new RingIntPExpr, new VecRStaDyn(new RingIntPExpr))(a.length, a(0).length, a2)
62+
val (n, m, a2) = amat1(a, '(arr))
63+
mvmult_abs0(new RingIntPExpr, new VecRStaDyn(new RingIntPExpr))(n, m, a2)
7064
}
7165
}
7266
}
7367

7468
def mvmult_opt(a: Array[Array[Int]]): Expr[(Array[Int], Array[Int]) => Unit] = {
75-
val n = a.length
76-
val m = a(0).length
7769
import Lifters._
7870
'{
7971
val arr = ~a.toExpr
8072
~{
81-
val a2: Vec[PV[Int], Vec[PV[Int], PV[Int]]] = Vec(Sta(n), i => Vec(Sta(m), j => (i, j) match {
82-
case (Sta(i), Sta(j)) => Sta(a(i)(j))
83-
case (Sta(i), Dyn(j)) => Dyn('(arr(~i.toExpr)(~j)))
84-
case (i, j) => Dyn( '{ arr(~(Dyns.dyni(i)))(~(Dyns.dyni(j))) })
85-
}))
86-
mvmult_abs0(new RingIntOPExpr, new VecRStaDyn(new RingIntPExpr))(a.length, a(0).length, a2)
73+
val (n, m, a2) = amat1(a, '(arr))
74+
mvmult_abs0(new RingIntOPExpr, new VecRStaDyn(new RingIntPExpr))(n, m, a2)
8775
}
8876
}
8977
}
9078

9179
def mvmult_roll(a: Array[Array[Int]]): Expr[(Array[Int], Array[Int]) => Unit] = {
92-
val n = a.length
93-
val m = a(0).length
9480
import Lifters._
9581
'{
9682
val arr = ~a.toExpr
9783
~{
98-
val a2: Vec[PV[Int], Vec[PV[Int], PV[Int]]] = Vec(Sta(n), i => Vec(Sta(m), j => (i, j) match {
99-
case (Sta(i), Sta(j)) => Sta(a(i)(j))
100-
case (Sta(i), Dyn(j)) => Dyn('(arr(~i.toExpr)(~j)))
101-
case (i, j) => Dyn( '{ arr(~(Dyns.dyni(i)))(~(Dyns.dyni(j))) })
102-
}))
103-
mvmult_abs0(new RingIntOPExpr, new VecRStaOptDynInt(new RingIntPExpr))(a.length, a(0).length, a2)
84+
val (n, m, a2) = amat1(a, '(arr))
85+
mvmult_abs0(new RingIntOPExpr, new VecRStaOptDynInt(new RingIntPExpr))(n, m, a2)
10486
}
10587
}
10688
}
10789

90+
def mvmult_let1(a: Array[Array[Int]]): Expr[(Array[Int], Array[Int]) => Unit] = {
91+
val (n, m, a2) = amatCopy(a, copy_row1)
92+
mvmult_abs0(new RingIntOPExpr, new VecRStaOptDynInt(new RingIntPExpr))(n, m, a2)
93+
}
94+
95+
def mvmult_let(a: Array[Array[Int]]): Expr[(Array[Int], Array[Int]) => Unit] = {
96+
val (n, m, a2) = amatCopy(a, copy_row_let)
97+
mvmult_abs0(new RingIntOPExpr, new VecRStaOptDynInt(new RingIntPExpr))(n, m, a2)
98+
}
99+
100+
def amat1(a: Array[Array[Int]], aa: Expr[Array[Array[Int]]]): (Int, Int, Vec[PV[Int], Vec[PV[Int], PV[Int]]]) = {
101+
val n = a.length
102+
val m = a(0).length
103+
val vec: Vec[PV[Int], Vec[PV[Int], PV[Int]]] = Vec(Sta(n), i => Vec(Sta(m), j => (i, j) match {
104+
case (Sta(i), Sta(j)) => Sta(a(i)(j))
105+
case (Sta(i), Dyn(j)) => Dyn('((~aa)(~i.toExpr)(~j)))
106+
case (i, j) => Dyn('{ (~aa)(~(Dyns.dyni(i)))(~(Dyns.dyni(j))) })
107+
}))
108+
(n, m, vec)
109+
}
110+
111+
def amatCopy(a: Array[Array[Int]], copyRow: Array[Int] => (Expr[Int] => Expr[Int])): (Int, Int, Vec[PV[Int], Vec[PV[Int], PV[Int]]]) = {
112+
val n = a.length
113+
val m = a(0).length
114+
val vec: Vec[PV[Int], Vec[PV[Int], PV[Int]]] = Vec(Sta(n), i => Vec(Sta(m), j => (i, j) match {
115+
case (Sta(i), Sta(j)) => Sta(a(i)(j))
116+
case (Sta(i), Dyn(j)) =>
117+
val defrec = copyRow(a(i))
118+
Dyn(defrec(j))
119+
case (i, j) => ???
120+
}))
121+
(n, m, vec)
122+
}
123+
124+
def copy_row1: Array[Int] => (Expr[Int] => Expr[Int]) = v => {
125+
import Lifters._
126+
val arr = v.toExpr
127+
i => '{ (~arr).apply(~i) }
128+
}
129+
130+
def copy_row_let: Array[Int] => (Expr[Int] => Expr[Int]) = v => {
131+
import Lifters._
132+
val arr: Expr[Array[Int]] = ??? // FIXME used genlet v.toExpr
133+
i => '{ (~arr).apply(~i) }
134+
}
135+
108136
private def mvmult_abs0(ring: Ring[PV[Int]], vecOp: VecROp[PV[Int], PV[Int], Expr[Unit]])(n: Int, m: Int, a: Vec[PV[Int], Vec[PV[Int], PV[Int]]]): Expr[(Array[Int], Array[Int]) => Unit] = {
109137
'{
110138
(vout, v) => {
@@ -120,5 +148,4 @@ object MVmult {
120148
}
121149
}
122150

123-
124151
}

tests/run-with-compiler/shonan-hmm/Test.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,16 @@ object Test {
7575
println()
7676

7777
println(MVmult.mvmult_roll(a).show)
78+
println()
79+
println()
80+
println()
81+
82+
println(MVmult.mvmult_let1(a).show)
83+
// println()
84+
// println()
85+
// println()
7886

87+
// println(MVmult.mvmult_let(a).show)
7988
}
8089
}
8190

0 commit comments

Comments
 (0)