Skip to content

Commit 9da729c

Browse files
committed
wip
1 parent f2b3837 commit 9da729c

File tree

4 files changed

+66
-9
lines changed

4 files changed

+66
-9
lines changed

tests/run-with-compiler/shonan-hmm.check

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,31 @@ List(25, 30, 20, 43, 44)
254254
})
255255
vout.update(4, v.apply(2).*(3).+(v.apply(4).*(7)))
256256
})
257+
258+
259+
260+
{
261+
val row: scala.Array[scala.Int] = {
262+
val array: scala.Array[scala.Int] = new scala.Array[scala.Int](5)
263+
array.update(0, 0)
264+
array.update(1, 0)
265+
array.update(2, 2)
266+
array.update(3, 3)
267+
array.update(4, 5)
268+
array
269+
}
270+
271+
((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
272+
if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else ()
273+
if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else ()
274+
vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5)))
275+
vout.update(1, v.apply(2).*(10))
276+
vout.update(2, v.apply(1).*(10))
277+
vout.update(3, {
278+
var sum: scala.Int = 0
279+
scala.Predef.intWrapper(0).until(5).foreach[scala.Unit](((i: scala.Int) => sum = sum.+(v.apply(i).*(row.apply(i)))))
280+
(sum: scala.Int)
281+
})
282+
vout.update(4, v.apply(2).*(3).+(v.apply(4).*(7)))
283+
})
284+
}

tests/run-with-compiler/shonan-hmm/MVmult.scala

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,25 @@ object MVmult {
9393
}
9494

9595
def mvmult_let(a: Array[Array[Int]]): Expr[(Array[Int], Array[Int]) => Unit] = {
96-
val (n, m, a2) = amatCopy(a, copy_row_let)
97-
mvmult_abs0(new RingIntOPExpr, new VecRStaOptDynInt(new RingIntPExpr))(n, m, a2)
96+
initRows(a) { rows =>
97+
val (n, m, a2) = amat2(a, rows)
98+
mvmult_abs0(new RingIntOPExpr, new VecRStaOptDynInt(new RingIntPExpr))(n, m, a2)
99+
}
100+
}
101+
102+
def initRows[T](a: Array[Array[Int]])(cont: Array[Expr[Array[Int]]] => Expr[T]): Expr[T] = {
103+
import Lifters._
104+
def loop(i: Int, acc: List[Expr[Array[Int]]]): Expr[T] = {
105+
if (i >= a.length) cont(acc.toArray.reverse)
106+
else if (a(i).count(_ != 0) < VecRStaOptDynInt.threshold) {
107+
val default: Expr[Array[Int]] = '(null.asInstanceOf[Array[Int]]) // never accessed
108+
loop(i + 1, default :: acc)
109+
} else '{
110+
val row = ~a(i).toExpr
111+
~{ loop(i + 1, '(row) :: acc) }
112+
}
113+
}
114+
loop(0, Nil)
98115
}
99116

100117
def amat1(a: Array[Array[Int]], aa: Expr[Array[Array[Int]]]): (Int, Int, Vec[PV[Int], Vec[PV[Int], PV[Int]]]) = {
@@ -108,6 +125,16 @@ object MVmult {
108125
(n, m, vec)
109126
}
110127

128+
def amat2(a: Array[Array[Int]], refs: Array[Expr[Array[Int]]]): (Int, Int, Vec[PV[Int], Vec[PV[Int], PV[Int]]]) = {
129+
val n = a.length
130+
val m = a(0).length
131+
val vec: Vec[PV[Int], Vec[PV[Int], PV[Int]]] = Vec(Sta(n), i => Vec(Sta(m), j => (i, j) match {
132+
case (Sta(i), Sta(j)) => Sta(a(i)(j))
133+
case (Sta(i), Dyn(j)) => Dyn('((~refs(i))(~j)))
134+
}))
135+
(n, m, vec)
136+
}
137+
111138
def amatCopy(a: Array[Array[Int]], copyRow: Array[Int] => (Expr[Int] => Expr[Int])): (Int, Int, Vec[PV[Int], Vec[PV[Int], PV[Int]]]) = {
112139
val n = a.length
113140
val m = a(0).length

tests/run-with-compiler/shonan-hmm/Test.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,11 @@ object Test {
8080
println()
8181

8282
println(MVmult.mvmult_let1(a).show)
83-
// println()
84-
// println()
85-
// println()
83+
println()
84+
println()
85+
println()
8686

87-
// println(MVmult.mvmult_let(a).show)
87+
println(MVmult.mvmult_let(a).show)
8888
}
8989
}
9090

tests/run-with-compiler/shonan-hmm/VecROp.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,16 @@ class VecRStaDyn[T : Type : Liftable](r: Ring[PV[T]]) extends VecROp[PV[Int], PV
6666
override def toString(): String = s"VecRStaDim($r)"
6767
}
6868

69+
object VecRStaOptDynInt {
70+
val threshold = 3
71+
}
72+
6973
class VecRStaOptDynInt(r: Ring[PV[Int]]) extends VecRStaDyn(r) {
7074
val M: VecROp[PV[Int], PV[Int], Expr[Unit]] = new VecRStaDyn(r)
7175

72-
private val threshold = 3
73-
7476
override def reduce: ((PV[Int], PV[Int]) => PV[Int], PV[Int], Vec[PV[Int], PV[Int]]) => PV[Int] = (plus, zero, vec) => vec match {
7577
case Vec(Sta(n), vecf) =>
76-
if (count_non_zeros(n, vecf) < threshold) M.reduce(plus, zero, vec)
78+
if (count_non_zeros(n, vecf) < VecRStaOptDynInt.threshold) M.reduce(plus, zero, vec)
7779
else M.reduce(plus, zero, Vec(Dyn(n.toExpr), vecf))
7880
case _ => M.reduce(plus, zero, vec)
7981
}

0 commit comments

Comments
 (0)