Skip to content

Commit 2211d5b

Browse files
committed
wip
1 parent e6f6b07 commit 2211d5b

File tree

7 files changed

+196
-46
lines changed

7 files changed

+196
-46
lines changed

tests/run-with-compiler/shonan-hmm.check

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
dafd
21
Complex(0,10)
32
Complex(1.*(4).-(2.*(2)), 1.*(2).+(2.*(4)))
43
List(Complex(2,0), Complex(-4,4), Complex(-2,6))
@@ -23,7 +22,8 @@ List(25, 30, 20, 43, 44)
2322

2423

2524
((vout: scala.Array[scala.Int], a: scala.Array[scala.Array[scala.Int]], v: scala.Array[scala.Int]) => {
26-
dotty.DottyPredef.assert(3.==(vout.length).&&(2.==(v.length)))
25+
if (3.!=(vout.length)) throw new scala.IndexOutOfBoundsException("3") else ()
26+
if (2.!=(v.length)) throw new scala.IndexOutOfBoundsException("2") else ()
2727
vout.update(0, 0.+(v.apply(0).*(a.apply(0).apply(0))).+(v.apply(1).*(a.apply(0).apply(1))))
2828
vout.update(1, 0.+(v.apply(0).*(a.apply(1).apply(0))).+(v.apply(1).*(a.apply(1).apply(1))))
2929
vout.update(2, 0.+(v.apply(0).*(a.apply(2).apply(0))).+(v.apply(1).*(a.apply(2).apply(1))))
@@ -32,10 +32,61 @@ List(25, 30, 20, 43, 44)
3232

3333

3434
{
35-
val arr: scala.Array[scala.Array[scala.Int]] = scala.Array.apply[scala.Array[scala.Int]](scala.Array.apply(5, 0, 0, 5, 0), scala.Array.apply(0, 0, 10, 0, 0), scala.Array.apply(0, 10, 0, 0, 0), scala.Array.apply(0, 0, 2, 3, 5), scala.Array.apply(0, 0, 3, 0, 7))(scala.reflect.ClassTag.apply[scala.Int](java.lang.Integer.TYPE).wrap)
35+
val arr: scala.Array[scala.Array[scala.Int]] = {
36+
val array: scala.Array[scala.Array[scala.Int]] = dotty.runtime.Arrays.newGenericArray[scala.Array[scala.Int]](5)({
37+
scala.reflect.ClassTag.apply[scala.Array[scala.Int]](scala.Predef.classOf[scala.Array[scala.Int]])
38+
})
39+
array.update(0, {
40+
val array: scala.Array[scala.Int] = new scala.Array[scala.Int](5)
41+
array.update(0, 5)
42+
array.update(1, 0)
43+
array.update(2, 0)
44+
array.update(3, 5)
45+
array.update(4, 0)
46+
array
47+
})
48+
array.update(1, {
49+
val array: scala.Array[scala.Int] = new scala.Array[scala.Int](5)
50+
array.update(0, 0)
51+
array.update(1, 0)
52+
array.update(2, 10)
53+
array.update(3, 0)
54+
array.update(4, 0)
55+
array
56+
})
57+
array.update(2, {
58+
val array: scala.Array[scala.Int] = new scala.Array[scala.Int](5)
59+
array.update(0, 0)
60+
array.update(1, 10)
61+
array.update(2, 0)
62+
array.update(3, 0)
63+
array.update(4, 0)
64+
array
65+
})
66+
array.update(3, {
67+
val array: scala.Array[scala.Int] = new scala.Array[scala.Int](5)
68+
array.update(0, 0)
69+
array.update(1, 0)
70+
array.update(2, 2)
71+
array.update(3, 3)
72+
array.update(4, 5)
73+
array
74+
})
75+
array.update(4, {
76+
val array: scala.Array[scala.Int] = new scala.Array[scala.Int](5)
77+
array.update(0, 0)
78+
array.update(1, 0)
79+
array.update(2, 3)
80+
array.update(3, 0)
81+
array.update(4, 7)
82+
array
83+
})
84+
array
85+
}
3686

3787
((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
38-
dotty.DottyPredef.assert(5.==(vout.length).&&(5.==(v.length)))
88+
if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else ()
89+
if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else ()
3990
vout.update(0, 0.+(v.apply(0).*(5)).+(v.apply(1).*(0)).+(v.apply(2).*(0)).+(v.apply(3).*(5)).+(v.apply(4).*(0)))
4091
vout.update(1, 0.+(v.apply(0).*(0)).+(v.apply(1).*(0)).+(v.apply(2).*(10)).+(v.apply(3).*(0)).+(v.apply(4).*(0)))
4192
vout.update(2, 0.+(v.apply(0).*(0)).+(v.apply(1).*(10)).+(v.apply(2).*(0)).+(v.apply(3).*(0)).+(v.apply(4).*(0)))
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
2+
import UnrolledExpr._
3+
4+
import scala.reflect.ClassTag
5+
import scala.quoted._
6+
7+
object Lifters {
8+
9+
implicit def ClassTagIsLiftable[T : Type](implicit ct: ClassTag[T]): Liftable[ClassTag[T]] =
10+
ct => '(ClassTag(~ct.runtimeClass.toExpr))
11+
12+
implicit def ArrayIsLiftable[T : Type: ClassTag](implicit l: Liftable[T]): Liftable[Array[T]] = arr => '{
13+
val array = new Array[T](~arr.length.toExpr)(~implicitly[ClassTag[T]].toExpr)
14+
~initArray(arr, '(array))
15+
}
16+
17+
implicit def IntArrayIsLiftable: Liftable[Array[Int]] = arr => '{
18+
val array = new Array[Int](~arr.length.toExpr)
19+
~initArray(arr, '(array))
20+
}
21+
22+
private def initArray[T : Liftable](arr: Array[T], array: Expr[Array[T]]): Expr[Array[T]] = {
23+
UnrolledExpr.block(
24+
arr.zipWithIndex.map {
25+
case (x, i) => '{ (~array)(~i.toExpr) = ~x.toExpr }
26+
}.toList,
27+
array)
28+
}
29+
30+
}

tests/run-with-compiler/shonan-hmm/MVmult.scala

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ object MVmult {
3131
val a_ = Vec('(n), (i: Expr[Int]) => Vec('(m), (j: Expr[Int]) => '{ a(~i)(~j) } ))
3232
val v_ = Vec('(m), (i: Expr[Int]) => '(v(~i)))
3333

34-
val MV = new MVmult[Expr[Int], Expr[Int], Expr[Unit]](RingIntExpr, new VecRDyn(RingIntExpr))
34+
val MV = new MVmult[Expr[Int], Expr[Int], Expr[Unit]](RingIntExpr, new VecRDyn)
3535
MV.mvmult(vout_, a_, v_)
3636
}
3737
}
@@ -41,7 +41,8 @@ object MVmult {
4141
val MV = new MVmult[Int, Expr[Int], Expr[Unit]](RingIntExpr, new VecRStaDim(RingIntExpr))
4242
'{
4343
(vout, a, v) => {
44-
assert (~n.toExpr == vout.length && ~m.toExpr == v.length)
44+
if (~n.toExpr != vout.length) throw new IndexOutOfBoundsException(~n.toString.toExpr)
45+
if (~m.toExpr != v.length) throw new IndexOutOfBoundsException(~m.toString.toExpr)
4546
~{
4647
val vout_ = OVec(n, (i, x: Expr[Int]) => '(vout(~i.toExpr) = ~x))
4748
val a_ = Vec(n, i => Vec(m, j => '{ a(~i.toExpr)(~j.toExpr) } ))
@@ -56,32 +57,50 @@ object MVmult {
5657
def mvmult_ac(a: Array[Array[Int]]): Expr[(Array[Int], Array[Int]) => Unit] = {
5758
val n = a.length
5859
val m = a(0).length
60+
import Lifters._
61+
'{
62+
val arr = ~a.toExpr
63+
~{
64+
val a2: Vec[PV[Int], Vec[PV[Int], PV[Int]]] = Vec(Sta(n), i => Vec(Sta(m), j => (i, j) match {
65+
case (Sta(i), Sta(j)) => Sta(a(i)(j))
66+
case (Sta(i), Dyn(j)) => Dyn('(arr(~i.toExpr)(~j)))
67+
case (i, j) => Dyn( '{ arr(~(Dyns.dyni(i)))(~(Dyns.dyni(j))) })
68+
}))
69+
mvmult_abs(a.length, a(0).length, a2)
70+
}
71+
}
72+
}
5973

60-
// Array lifters
61-
62-
74+
private def mvmult_abs(n: Int, m: Int, a: Vec[PV[Int], Vec[PV[Int], PV[Int]]]): Expr[(Array[Int], Array[Int]) => Unit] = {
6375
'{
64-
val arr = Array( // FIXMR lift a
65-
Array( 5, 0, 0, 5, 0),
66-
Array( 0, 0, 10, 0, 0),
67-
Array( 0, 10, 0, 0, 0),
68-
Array( 0, 0, 2, 3, 5),
69-
Array( 0, 0, 3, 0, 7)
70-
)
7176
(vout, v) => {
72-
assert (~n.toExpr == vout.length && ~m.toExpr == v.length)
77+
if (~n.toExpr != vout.length) throw new IndexOutOfBoundsException(~n.toString.toExpr)
78+
if (~m.toExpr != v.length) throw new IndexOutOfBoundsException(~m.toString.toExpr)
7379
~{
74-
val vout_ : OVec[PV[Int], Expr[Int], Expr[Unit]] = OVec(Sta(n), (i, x) => '(vout(~Dyns.dyni(i)) = ~x))
75-
val a2: Vec[PV[Int], Vec[PV[Int], Expr[Int]]] = Vec(Sta(n), i => Vec(Sta(m), j => Dyns.dyn((i, j) match {
76-
case (Sta(i), Sta(j)) => Sta(a(i)(j))
77-
case (Sta(i), Dyn(j)) => Dyn('(arr(~i.toExpr)(~j)))
78-
case (i, j) => Dyn('{ arr(~(Dyns.dyni(i)))(~(Dyns.dyni(j))) })
79-
})))
80-
val v_ : Vec[PV[Int], Expr[Int]] = Vec(Sta(m), i => '(v(~Dyns.dyni(i))))
81-
val MV = new MVmult[PV[Int], Expr[Int], Expr[Unit]](RingIntExpr, new VecRStaDyn(RingIntExpr))
82-
MV.mvmult(vout_, a2, v_)
80+
val vout_ : OVec[PV[Int], PV[Int], Expr[Unit]] = OVec(Sta(n), (i, x) => '(vout(~Dyns.dyni(i)) = ~Dyns.dyn(x)))
81+
val v_ : Vec[PV[Int], PV[Int]] = Vec(Sta(m), i => Dyn('(v(~Dyns.dyni(i)))))
82+
val MV = new MVmult[PV[Int], PV[Int], Expr[Unit]](new RingIntPExpr, new VecRStaDyn(new RingIntPExpr))
83+
MV.mvmult(vout_, a, v_)
8384
}
8485
}
8586
}
8687
}
88+
89+
90+
91+
// let mvmult_abs : _ →
92+
// amat → (float array → float array → unit) code =
93+
// fun mvmult → fun {n;m;a} →
94+
// .<fun vout v →
95+
// assert (n = Array.length vout && m = Array.length v);
96+
// .~(let vout = OVec (Sta n, fun i v → .<vout.(.~(dyni i)) ← .~(dynf v)>.) in
97+
// let v = Vec (Sta m, fun j → Dyn .<v.(.~(dyni j))>.) in
98+
// mvmult vout a v)
99+
// >.
100+
// val mvmult_abs :
101+
// ((int pv, float pv, unit code) Vector.ovec →
102+
// (int pv, (int pv, float pv) Vector.vec) Vector.vec →
103+
// (int pv, float pv) Vector.vec → unit code) →
104+
// amat → (float array → float array → unit) code = <fun>
105+
87106
}

tests/run-with-compiler/shonan-hmm/Ring.scala

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
import scala.quoted._
33

44
trait Ring[T] {
5-
val zero: T
5+
def zero: T
66
val one: T
7-
val add: (x: T, y: T) => T
8-
val sub: (x: T, y: T) => T
9-
val mul: (x: T, y: T) => T
7+
def add: (x: T, y: T) => T
8+
def sub: (x: T, y: T) => T
9+
def mul: (x: T, y: T) => T
1010

1111
implicit class Ops(x: T) {
1212
def +(y: T): T = add(x, y)
@@ -52,16 +52,38 @@ case class RingPV[U: Liftable](staRing: Ring[U], dynRing: Ring[Expr[U]]) extends
5252

5353
val zero: T = Sta(staRing.zero)
5454
val one: T = Sta(staRing.one)
55-
val add = (x: T, y: T) => (x, y) match {
55+
def add = (x: T, y: T) => (x, y) match {
5656
case (Sta(x), Sta(y)) => Sta(x + y)
5757
case (x, y) => Dyn(dyn(x) + dyn(y))
5858
}
59-
val sub = (x: T, y: T) => (x, y) match {
59+
def sub = (x: T, y: T) => (x, y) match {
6060
case (Sta(x), Sta(y)) => Sta(x - y)
6161
case (x, y) => Dyn(dyn(x) - dyn(y))
6262
}
63-
val mul = (x: T, y: T) => (x, y) match {
63+
def mul = (x: T, y: T) => (x, y) match {
6464
case (Sta(x), Sta(y)) => Sta(x * y)
6565
case (x, y) => Dyn(dyn(x) * dyn(y))
6666
}
6767
}
68+
69+
class RingIntPExpr extends RingPV(RingInt, RingIntExpr)
70+
71+
class RingIntOPCode extends RingIntPExpr {
72+
override def add = (x: PV[Int], y: PV[Int]) => (x, y) match {
73+
case (Sta(0), y) => y
74+
case (x, Sta(0)) => x
75+
case (x, y) => super.add(x, y)
76+
}
77+
override def sub = (x: T, y: T) => (x, y) match {
78+
case (Sta(0), y) => y
79+
case (x, Sta(0)) => x
80+
case (x, y) => super.sub(x, y)
81+
}
82+
override def mul = (x: T, y: T) => (x, y) match {
83+
case (Sta(0), y) => Sta(0)
84+
case (x, Sta(0)) => Sta(0)
85+
case (Sta(1), y) => y
86+
case (x, Sta(1)) => x
87+
case (x, y) => super.mul(x, y)
88+
}
89+
}

tests/run-with-compiler/shonan-hmm/Test.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ import scala.quoted._
77
object Test {
88

99
def main(args: Array[String]): Unit = {
10-
println("dafd")
11-
1210
{
1311
val intComplex = new RingComplex(RingInt)
1412
import intComplex._
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import scala.quoted._
2+
import Lifters._
3+
4+
object UnrolledExpr {
5+
6+
implicit class Unrolled[T: Liftable, It <: Iterable[T]](xs: It) {
7+
def unrolled: UnrolledExpr[T, It] = new UnrolledExpr(xs)
8+
}
9+
10+
// TODO support blocks in the compiler to avoid creating trees of blocks?
11+
def block[T](stats: Iterable[Expr[_]], expr: Expr[T]): Expr[T] = {
12+
def rec(stats: List[Expr[_]]): Expr[T] = stats match {
13+
case x :: xs => '{ ~x; ~rec(xs) }
14+
case Nil => expr
15+
}
16+
rec(stats.toList)
17+
}
18+
19+
}
20+
21+
class UnrolledExpr[T: Liftable, It <: Iterable[T]](xs: It) {
22+
import UnrolledExpr._
23+
24+
def foreach[U](f: T => Expr[U]): Expr[Unit] = block(xs.map(f), '())
25+
26+
def withFilter(f: T => Boolean): UnrolledExpr[T, Iterable[T]] = new UnrolledExpr(xs.filter(f))
27+
28+
def foldLeft[U](acc: Expr[U])(f: (Expr[U], T) => Expr[U]): Expr[U] =
29+
xs.foldLeft(acc)((acc, x) => f(acc, x))
30+
}

tests/run-with-compiler/shonan-hmm/VecROp.scala

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,21 @@ class StaticVecR[T](r: Ring[T]) extends VecSta with VecROp[Int, T, Unit] {
1616
override def toString(): String = s"StaticVecR($r)"
1717
}
1818

19-
class VecRDyn[T: Type](r: Ring[Expr[T]]) extends VecDyn with VecROp[Expr[Int], Expr[T], Expr[Unit]] {
19+
class VecRDyn[T: Type] extends VecDyn with VecROp[Expr[Int], Expr[T], Expr[Unit]] {
2020
def reduce: ((Expr[T], Expr[T]) => Expr[T], Expr[T], Vec[Expr[Int], Expr[T]]) => Expr[T] = {
2121
(plus, zero, vec) => '{
22-
var sum = ~r.zero
22+
var sum = ~zero
2323
for (i <- 0 until ~vec.size)
2424
sum = ~{ plus('(sum), vec('(i))) }
2525
sum
2626
}
2727
}
28-
override def toString(): String = s"VecRDyn($r)"
28+
override def toString(): String = s"VecRDyn"
2929
}
3030

31-
class VecRStaDim[T: Type](r: Ring[Expr[T]]) extends VecROp[Int, Expr[T], Expr[Unit]] {
32-
val M = new StaticVecR[Expr[T]](r)
33-
def reduce: ((Expr[T], Expr[T]) => Expr[T], Expr[T], Vec[Int, Expr[T]]) => Expr[T] = M.reduce
31+
class VecRStaDim[T: Type](r: Ring[T]) extends VecROp[Int, T, Expr[Unit]] {
32+
val M = new StaticVecR[T](r)
33+
def reduce: ((T, T) => T, T, Vec[Int, T]) => T = M.reduce
3434
val seq: (Expr[Unit], Expr[Unit]) => Expr[Unit] = (e1, e2) => '{ ~e1; ~e2 }
3535
// val iter: (arr: Vec[]) = reduce seq .<()>. arr
3636
def iter: Vec[Int, Expr[Unit]] => Expr[Unit] = arr => {
@@ -42,13 +42,13 @@ class VecRStaDim[T: Type](r: Ring[Expr[T]]) extends VecROp[Int, Expr[T], Expr[Un
4242
override def toString(): String = s"VecRStaDim($r)"
4343
}
4444

45-
class VecRStaDyn[T : Type : Liftable](r: Ring[Expr[T]]) extends VecROp[PV[Int], Expr[T], Expr[Unit]] {
46-
val VSta = new VecRStaDim(r)
47-
val VDyn = new VecRDyn(r)
45+
class VecRStaDyn[T : Type : Liftable](r: Ring[PV[T]]) extends VecROp[PV[Int], PV[T], Expr[Unit]] {
46+
val VSta: VecROp[Int, PV[T], Expr[Unit]] = new VecRStaDim(r)
47+
val VDyn = new VecRDyn
4848
val dyn = Dyns.dyn[T]
49-
def reduce: ((Expr[T], Expr[T]) => Expr[T], Expr[T], Vec[PV[Int], Expr[T]]) => Expr[T] = { (plus, zero, vec) => vec match {
49+
def reduce: ((PV[T], PV[T]) => PV[T], PV[T], Vec[PV[Int], PV[T]]) => PV[T] = { (plus, zero, vec) => vec match {
5050
case Vec(Sta(n), v) => VSta.reduce(plus, zero, Vec(n, i => v(Sta(i))))
51-
case Vec(Dyn(n), v) => VDyn.reduce((x, y) => plus(x, y), zero, Vec(n, i => v(Dyn(i))))
51+
case Vec(Dyn(n), v) => Dyn(VDyn.reduce((x, y) => dyn(plus(Dyn(x), Dyn(y))), dyn(zero), Vec(n, i => dyn(v(Dyn(i))))))
5252
}
5353
}
5454
def iter: Vec[PV[Int], Expr[Unit]] => Expr[Unit] = arr => {

0 commit comments

Comments
 (0)