Skip to content

Commit e6f6b07

Browse files
committed
Implement Shonan HMM
1 parent 42698ae commit e6f6b07

File tree

13 files changed

+483
-2
lines changed

13 files changed

+483
-2
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/TastyImpl.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,12 +387,12 @@ class TastyImpl(val rootContext: Contexts.Context) extends scala.tasty.Tasty { s
387387
private def normalizedLoops(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = tree match {
388388
case block: tpd.Block if block.stats.size > 1 =>
389389
def normalizeInnerLoops(stats: List[tpd.Tree]): List[tpd.Tree] = stats match {
390-
case (x: tpd.DefDef) :: y :: xs if y.symbol.is(Flags.Label) =>
390+
case (x: tpd.DefDef) :: y :: xs if y.symbol.is(Flags.Label) || y.isInstanceOf[tpd.Closure] =>
391391
tpd.Block(x :: Nil, y) :: normalizeInnerLoops(xs)
392392
case x :: xs => x :: normalizeInnerLoops(xs)
393393
case Nil => Nil
394394
}
395-
if (block.expr.symbol.is(Flags.Label)) {
395+
if (block.expr.symbol.is(Flags.Label) || block.expr.isInstanceOf[tpd.Closure]) {
396396
val stats1 = normalizeInnerLoops(block.stats.init)
397397
val normalLoop = tpd.Block(block.stats.last :: Nil, block.expr)
398398
tpd.Block(stats1, normalLoop)

library/src/scala/tasty/util/ShowSourceCode.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ class ShowSourceCode[T <: Tasty with Singleton](tasty0: T) extends Show[T](tasty
373373
expr match {
374374
case Term.Lambda(_, _) =>
375375
// Decompile lambda from { def annon$(...) = ...; closure(annon$, ...)}
376+
assert(stats.size == 1)
376377
val DefDef(_, _, args :: Nil, _, Some(rhs)) :: Nil = stats
377378
inParens {
378379
printArgsDefs(args)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
dafd
2+
Complex(0,10)
3+
Complex(1.*(4).-(2.*(2)), 1.*(2).+(2.*(4)))
4+
List(Complex(2,0), Complex(-4,4), Complex(-2,6))
5+
((vout: scala.Array[Complex[scala.Int]], v1: scala.Array[Complex[scala.Int]], v2: scala.Array[Complex[scala.Int]]) => {
6+
val n: scala.Int = vout.length
7+
scala.Predef.intWrapper(0).until(n).foreach[scala.Unit](((i: scala.Int) => vout.update(i, Complex.apply[scala.Int](v1.apply(i).re.*(v2.apply(i).re).-(v1.apply(i).im.*(v2.apply(i).im)), v1.apply(i).re.*(v2.apply(i).im).+(v1.apply(i).im.*(v2.apply(i).re))))))
8+
})
9+
List(25, 30, 20, 43, 44)
10+
11+
12+
13+
((vout: scala.Array[scala.Int], a: scala.Array[scala.Array[scala.Int]], v: scala.Array[scala.Int]) => {
14+
val n: scala.Int = vout.length
15+
val m: scala.Int = v.length
16+
scala.Predef.intWrapper(0).until(n).foreach[scala.Unit](((i: scala.Int) => vout.update(i, {
17+
var sum: scala.Int = 0
18+
scala.Predef.intWrapper(0).until(m).foreach[scala.Unit](((i: scala.Int) => sum = sum.+(v.apply(i).*(a.apply(i).apply(i)))))
19+
(sum: scala.Int)
20+
})))
21+
})
22+
23+
24+
25+
((vout: scala.Array[scala.Int], a: scala.Array[scala.Array[scala.Int]], v: scala.Array[scala.Int]) => {
26+
dotty.DottyPredef.assert(3.==(vout.length).&&(2.==(v.length)))
27+
vout.update(0, 0.+(v.apply(0).*(a.apply(0).apply(0))).+(v.apply(1).*(a.apply(0).apply(1))))
28+
vout.update(1, 0.+(v.apply(0).*(a.apply(1).apply(0))).+(v.apply(1).*(a.apply(1).apply(1))))
29+
vout.update(2, 0.+(v.apply(0).*(a.apply(2).apply(0))).+(v.apply(1).*(a.apply(2).apply(1))))
30+
})
31+
32+
33+
34+
{
35+
val arr: scala.Array[scala.Array[scala.Int]] = scala.Array.apply[scala.Array[scala.Int]](scala.Array.apply(5, 0, 0, 5, 0), scala.Array.apply(0, 0, 10, 0, 0), scala.Array.apply(0, 10, 0, 0, 0), scala.Array.apply(0, 0, 2, 3, 5), scala.Array.apply(0, 0, 3, 0, 7))(scala.reflect.ClassTag.apply[scala.Int](java.lang.Integer.TYPE).wrap)
36+
37+
((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
38+
dotty.DottyPredef.assert(5.==(vout.length).&&(5.==(v.length)))
39+
vout.update(0, 0.+(v.apply(0).*(5)).+(v.apply(1).*(0)).+(v.apply(2).*(0)).+(v.apply(3).*(5)).+(v.apply(4).*(0)))
40+
vout.update(1, 0.+(v.apply(0).*(0)).+(v.apply(1).*(0)).+(v.apply(2).*(10)).+(v.apply(3).*(0)).+(v.apply(4).*(0)))
41+
vout.update(2, 0.+(v.apply(0).*(0)).+(v.apply(1).*(10)).+(v.apply(2).*(0)).+(v.apply(3).*(0)).+(v.apply(4).*(0)))
42+
vout.update(3, 0.+(v.apply(0).*(0)).+(v.apply(1).*(0)).+(v.apply(2).*(2)).+(v.apply(3).*(3)).+(v.apply(4).*(5)))
43+
vout.update(4, 0.+(v.apply(0).*(0)).+(v.apply(1).*(0)).+(v.apply(2).*(3)).+(v.apply(3).*(0)).+(v.apply(4).*(7)))
44+
})
45+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
2+
import scala.quoted._
3+
4+
class Blas1[Idx, T, Unt](tring: Ring[T], vec: VecOp[Idx, Unt]) {
5+
import tring._
6+
import vec._
7+
8+
implicit class Blas1VecOps(v1: Vec[Idx, T]) {
9+
def `*.`(v2: Vec[Idx, T]): Vec[Idx, T] = v1.zipWith(v2, mul)
10+
}
11+
12+
implicit class Blas1OVecOps(vout: OVec[Idx, T, Unt]) {
13+
def :=(vin: Vec[Idx, T]): Unt = iter(vout.vecAssign(vin))
14+
}
15+
override def toString(): String = s"Blas1($tring, $vec)"
16+
}
17+
18+
class Blas2[Idx, T, Unt](tring: Ring[T], vec: VecROp[Idx, T, Unt]) extends Blas1[Idx, T, Unt](tring, vec) {
19+
import tring._
20+
import vec._
21+
22+
implicit class Blas2VecOps(v1: Vec[Idx, T]) {
23+
def dot(v2: Vec[Idx, T]): T = reduce(add, zero, v1 `*.` v2)
24+
}
25+
26+
implicit class Blas2MatOps(a: Vec[Idx, Vec[Idx, T]]) {
27+
def *(v: Vec[Idx, T]): Vec[Idx, T] = a.vecMap(x => v dot x)
28+
}
29+
override def toString(): String = s"Blas2($tring, $vec)"
30+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
2+
import scala.quoted._
3+
4+
case class Complex[T](re: T, im: T)
5+
6+
object Complex {
7+
implicit def complexIsLiftable[T: Type: Liftable]: Liftable[Complex[T]] = new Liftable {
8+
def toExpr(c: Complex[T]): Expr[Complex[T]] = '{ Complex(~c.re.toExpr, ~c.im.toExpr) }
9+
}
10+
11+
def of_complex_expr(x: Expr[Complex[Int]]): Complex[Expr[Int]] = Complex('((~x).re), '((~x).im))
12+
def of_expr_complex(x: Complex[Expr[Int]]): Expr[Complex[Int]] = '(Complex(~x.re, ~x.im))
13+
14+
15+
}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
2+
import dotty.tools.dotc.quoted.Toolbox._
3+
import scala.quoted._
4+
5+
class MVmult[Idx, T, Unt](tring: Ring[T], vec: VecROp[Idx, T, Unt]) {
6+
private[this] val blas2 = new Blas2(tring, vec)
7+
import blas2._
8+
def mvmult(vout: OVec[Idx, T, Unt], a: Vec[Idx, Vec[Idx, T]], v: Vec[Idx, T]): Unt = vout := a * v
9+
override def toString(): String = s"MVmult($tring, $vec)"
10+
}
11+
12+
object MVmult {
13+
def mvmult_p(vout: Array[Int], a: Array[Array[Int]], v: Array[Int]): Unit = {
14+
val n = vout.length
15+
val m = v.length
16+
17+
val vout_ = OVec(n, (i, x: Int) => vout(i) = x)
18+
val a_ = Vec (n, i => Vec(m, j => a(i)(j)))
19+
val v_ = Vec (n, i => v(i))
20+
21+
val MV = new MVmult[Int, Int, Unit](RingInt, new StaticVecR(RingInt))
22+
MV.mvmult(vout_, a_, v_)
23+
}
24+
25+
def mvmult_c: Expr[(Array[Int], Array[Array[Int]], Array[Int]) => Unit] = '{
26+
(vout, a, v) => {
27+
val n = vout.length
28+
val m = v.length
29+
~{
30+
val vout_ = OVec('(n), (i, x: Expr[Int]) => '(vout(~i) = ~x))
31+
val a_ = Vec('(n), (i: Expr[Int]) => Vec('(m), (j: Expr[Int]) => '{ a(~i)(~j) } ))
32+
val v_ = Vec('(m), (i: Expr[Int]) => '(v(~i)))
33+
34+
val MV = new MVmult[Expr[Int], Expr[Int], Expr[Unit]](RingIntExpr, new VecRDyn(RingIntExpr))
35+
MV.mvmult(vout_, a_, v_)
36+
}
37+
}
38+
}
39+
40+
def mvmult_mc(n: Int, m: Int): Expr[(Array[Int], Array[Array[Int]], Array[Int]) => Unit] = {
41+
val MV = new MVmult[Int, Expr[Int], Expr[Unit]](RingIntExpr, new VecRStaDim(RingIntExpr))
42+
'{
43+
(vout, a, v) => {
44+
assert (~n.toExpr == vout.length && ~m.toExpr == v.length)
45+
~{
46+
val vout_ = OVec(n, (i, x: Expr[Int]) => '(vout(~i.toExpr) = ~x))
47+
val a_ = Vec(n, i => Vec(m, j => '{ a(~i.toExpr)(~j.toExpr) } ))
48+
val v_ = Vec(m, i => '(v(~i.toExpr)))
49+
50+
MV.mvmult(vout_, a_, v_)
51+
}
52+
}
53+
}
54+
}
55+
56+
def mvmult_ac(a: Array[Array[Int]]): Expr[(Array[Int], Array[Int]) => Unit] = {
57+
val n = a.length
58+
val m = a(0).length
59+
60+
// Array lifters
61+
62+
63+
'{
64+
val arr = Array( // FIXMR lift a
65+
Array( 5, 0, 0, 5, 0),
66+
Array( 0, 0, 10, 0, 0),
67+
Array( 0, 10, 0, 0, 0),
68+
Array( 0, 0, 2, 3, 5),
69+
Array( 0, 0, 3, 0, 7)
70+
)
71+
(vout, v) => {
72+
assert (~n.toExpr == vout.length && ~m.toExpr == v.length)
73+
~{
74+
val vout_ : OVec[PV[Int], Expr[Int], Expr[Unit]] = OVec(Sta(n), (i, x) => '(vout(~Dyns.dyni(i)) = ~x))
75+
val a2: Vec[PV[Int], Vec[PV[Int], Expr[Int]]] = Vec(Sta(n), i => Vec(Sta(m), j => Dyns.dyn((i, j) match {
76+
case (Sta(i), Sta(j)) => Sta(a(i)(j))
77+
case (Sta(i), Dyn(j)) => Dyn('(arr(~i.toExpr)(~j)))
78+
case (i, j) => Dyn('{ arr(~(Dyns.dyni(i)))(~(Dyns.dyni(j))) })
79+
})))
80+
val v_ : Vec[PV[Int], Expr[Int]] = Vec(Sta(m), i => '(v(~Dyns.dyni(i))))
81+
val MV = new MVmult[PV[Int], Expr[Int], Expr[Unit]](RingIntExpr, new VecRStaDyn(RingIntExpr))
82+
MV.mvmult(vout_, a2, v_)
83+
}
84+
}
85+
}
86+
}
87+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
2+
import scala.quoted._
3+
4+
sealed trait PV[T]
5+
6+
case class Sta[T](x: T) extends PV[T]
7+
8+
case class Dyn[T](x: Expr[T]) extends PV[T]
9+
10+
object Dyns {
11+
def dyn[T: Liftable](pv: PV[T]): Expr[T] = pv match {
12+
case Sta(x) => x.toExpr
13+
case Dyn(x) => x
14+
}
15+
val dyni: PV[Int] => Expr[Int] = dyn[Int]
16+
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
2+
import scala.quoted._
3+
4+
trait Ring[T] {
5+
val zero: T
6+
val one: T
7+
val add: (x: T, y: T) => T
8+
val sub: (x: T, y: T) => T
9+
val mul: (x: T, y: T) => T
10+
11+
implicit class Ops(x: T) {
12+
def +(y: T): T = add(x, y)
13+
def -(y: T): T = sub(x, y)
14+
def *(y: T): T = mul(x, y)
15+
}
16+
}
17+
18+
object RingInt extends Ring[Int] {
19+
val zero = 0
20+
val one = 0
21+
val add = (x, y) => x + y
22+
val sub = (x, y) => x - y
23+
val mul = (x, y) => x * y
24+
override def toString(): String = "RingInt"
25+
}
26+
27+
object RingIntExpr extends Ring[Expr[Int]] {
28+
val zero = '(0)
29+
val one = '(1)
30+
val add = (x, y) => '(~x + ~y)
31+
val sub = (x, y) => '(~x - ~y)
32+
val mul = (x, y) => '(~x * ~y)
33+
override def toString(): String = "RingIntExpr"
34+
}
35+
36+
case class RingComplex[U](u: Ring[U]) extends Ring[Complex[U]] {
37+
import u._
38+
val zero = Complex(u.zero, u.zero)
39+
val one = Complex(u.one, u.zero)
40+
val add = (x, y) => Complex(x.re + y.re, x.im + y.im)
41+
val sub = (x, y) => Complex(x.re + y.re, x.im + y.im)
42+
val mul = (x, y) => Complex(x.re * y.re - x.im * y.im, x.re * y.im + x.im * y.re)
43+
override def toString(): String = s"RingComplex($u)"
44+
}
45+
46+
case class RingPV[U: Liftable](staRing: Ring[U], dynRing: Ring[Expr[U]]) extends Ring[PV[U]] {
47+
type T = PV[U]
48+
49+
val dyn = Dyns.dyn[U]
50+
import staRing._
51+
import dynRing._
52+
53+
val zero: T = Sta(staRing.zero)
54+
val one: T = Sta(staRing.one)
55+
val add = (x: T, y: T) => (x, y) match {
56+
case (Sta(x), Sta(y)) => Sta(x + y)
57+
case (x, y) => Dyn(dyn(x) + dyn(y))
58+
}
59+
val sub = (x: T, y: T) => (x, y) match {
60+
case (Sta(x), Sta(y)) => Sta(x - y)
61+
case (x, y) => Dyn(dyn(x) - dyn(y))
62+
}
63+
val mul = (x: T, y: T) => (x, y) match {
64+
case (Sta(x), Sta(y)) => Sta(x * y)
65+
case (x, y) => Dyn(dyn(x) * dyn(y))
66+
}
67+
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
2+
import dotty.tools.dotc.quoted.Toolbox._
3+
import scala.quoted._
4+
5+
// DYNAMIC
6+
7+
object Test {
8+
9+
def main(args: Array[String]): Unit = {
10+
println("dafd")
11+
12+
{
13+
val intComplex = new RingComplex(RingInt)
14+
import intComplex._
15+
16+
println(Complex(1, 2) * Complex(4, 2))
17+
}
18+
19+
{
20+
val intExprComplex = new RingComplex(RingIntExpr)
21+
import intExprComplex._
22+
23+
val res = Complex('(1), '(2)) * Complex('(4), '(2))
24+
println(s"Complex(${res.re.show}, ${res.im.show})")
25+
}
26+
27+
// {
28+
// val intExprComplex = implicitly[Ring[Expr[Complex[Int]]]]
29+
// import intExprComplex._
30+
31+
// val res = '(Complex(1, 2)) * '(Complex(4, 2))
32+
// println(res.show)
33+
// }
34+
35+
val arr1 = Array(Complex(1, 0), Complex(0, 4), Complex(2, 2))
36+
val arr2 = Array(Complex(2, 0), Complex(1, 1), Complex(1, 2))
37+
val out = Array(Complex(0, 0), Complex(0, 0), Complex(0, 0))
38+
Vmults.vmult(out, arr1, arr2)
39+
println(out.toList)
40+
41+
println(Vmults.vmultCA.show)
42+
43+
val a = Array(
44+
Array( 5, 0, 0, 5, 0),
45+
Array( 0, 0, 10, 0, 0),
46+
Array( 0, 10, 0, 0, 0),
47+
Array( 0, 0, 2, 3, 5),
48+
Array( 0, 0, 3, 0, 7)
49+
)
50+
51+
val v1 = Array(1, 2, 3, 4, 5)
52+
val v1out = Array(0, 0, 0, 0, 0)
53+
MVmult.mvmult_p(v1out, a, v1)
54+
println(v1out.toList)
55+
println()
56+
println()
57+
println()
58+
59+
println(MVmult.mvmult_c.show)
60+
println()
61+
println()
62+
println()
63+
64+
println(MVmult.mvmult_mc(3, 2).show)
65+
println()
66+
println()
67+
println()
68+
69+
println(MVmult.mvmult_ac(a).show)
70+
71+
}
72+
}
73+
74+
75+
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import scala.quoted._
2+
3+
case class Vec[Idx, T](size: Idx, get: Idx => T) {
4+
def apply(idx: Idx): T = get(idx)
5+
6+
def vecMap[U](f: T => U): Vec[Idx, U] = Vec(size, i => f(get(i)))
7+
8+
def zipWith[U, V](vec2: Vec[Idx, U], f: (T, U) => V): Vec[Idx, V] =
9+
Vec(size, i => f(get(i), vec2(i)))
10+
}
11+
12+
case class OVec[Idx, T, Unt](size: Idx, update: (Idx, T) => Unt) {
13+
def vecAssign(vecIn: Vec[Idx, T]): Vec[Idx, Unt] =
14+
Vec(vecIn.size, i => update(i, vecIn(i)))
15+
}
16+
17+
object Vec {
18+
def fromArray[T](a: Array[T]): (Vec[Int, T], OVec[Int, T, Unit]) =
19+
(Vec(a.size, i => a(i)), OVec(a.size, (i, v) => a(i) = v))
20+
}

0 commit comments

Comments
 (0)