|
| 1 | +import scala.quoted._ |
| 2 | + |
| 3 | +trait Ring[T] { |
| 4 | + val zero: T |
| 5 | + val one: T |
| 6 | + val add: (x: T, y: T) => T |
| 7 | + val sub: (x: T, y: T) => T |
| 8 | + val mul: (x: T, y: T) => T |
| 9 | +} |
| 10 | + |
| 11 | +class RingInt extends Ring[Int] { |
| 12 | + val zero = 0 |
| 13 | + val one = 1 |
| 14 | + val add = (x, y) => x + y |
| 15 | + val sub = (x, y) => x - y |
| 16 | + val mul = (x, y) => x * y |
| 17 | +} |
| 18 | + |
| 19 | +class RingIntExpr extends Ring[Expr[Int]] { |
| 20 | + val zero = '(0) |
| 21 | + val one = '(1) |
| 22 | + val add = (x, y) => '(~x + ~y) |
| 23 | + val sub = (x, y) => '(~x - ~y) |
| 24 | + val mul = (x, y) => '(~x * ~y) |
| 25 | +} |
| 26 | + |
| 27 | +class RingComplex[U](u: Ring[U]) extends Ring[Complex[U]] { |
| 28 | + val zero = Complex(u.zero, u.zero) |
| 29 | + val one = Complex(u.one, u.zero) |
| 30 | + val add = (x, y) => Complex(u.add(x.re, y.re), u.add(x.im, y.im)) |
| 31 | + val sub = (x, y) => Complex(u.sub(x.re, y.re), u.sub(x.im, y.im)) |
| 32 | + val mul = (x, y) => Complex(u.sub(u.mul(x.re, y.re), u.mul(x.im, y.im)), u.add(u.mul(x.re, y.im), u.mul(x.im, y.re))) |
| 33 | +} |
| 34 | + |
| 35 | +sealed trait PV[T] { |
| 36 | + def expr(implicit l: Liftable[T]): Expr[T] |
| 37 | +} |
| 38 | +case class Sta[T](x: T) extends PV[T] { |
| 39 | + def expr(implicit l: Liftable[T]): Expr[T] = x.toExpr |
| 40 | +} |
| 41 | +case class Dyn[T](x: Expr[T]) extends PV[T] { |
| 42 | + def expr(implicit l: Liftable[T]): Expr[T] = x |
| 43 | +} |
| 44 | + |
| 45 | +case class RingPV[U: Liftable](staRing: Ring[U], dynRing: Ring[Expr[U]]) extends Ring[PV[U]] { |
| 46 | + type T = PV[U] |
| 47 | + |
| 48 | + import staRing._ |
| 49 | + import dynRing._ |
| 50 | + |
| 51 | + val zero: T = Sta(staRing.zero) |
| 52 | + val one: T = Sta(staRing.one) |
| 53 | + val add = (x: T, y: T) => (x, y) match { |
| 54 | + case (Sta(staRing.zero), x) => x |
| 55 | + case (x, Sta(staRing.zero)) => x |
| 56 | + case (Sta(x), Sta(y)) => Sta(staRing.add(x, y)) |
| 57 | + case (x, y) => Dyn(dynRing.add(x.expr, y.expr)) |
| 58 | + } |
| 59 | + val sub = (x: T, y: T) => (x, y) match { |
| 60 | + case (Sta(staRing.zero), x) => x |
| 61 | + case (x, Sta(staRing.zero)) => x |
| 62 | + case (Sta(x), Sta(y)) => Sta(staRing.sub(x, y)) |
| 63 | + case (x, y) => Dyn(dynRing.sub(x.expr, y.expr)) |
| 64 | + } |
| 65 | + val mul = (x: T, y: T) => (x, y) match { |
| 66 | + case (Sta(staRing.zero), _) => Sta(staRing.zero) |
| 67 | + case (_, Sta(staRing.zero)) => Sta(staRing.zero) |
| 68 | + case (Sta(staRing.one), x) => x |
| 69 | + case (x, Sta(staRing.one)) => x |
| 70 | + case (Sta(x), Sta(y)) => Sta(staRing.mul(x, y)) |
| 71 | + case (x, y) => Dyn(dynRing.mul(x.expr, y.expr)) |
| 72 | + } |
| 73 | +} |
| 74 | + |
| 75 | +case class Complex[T](re: T, im: T) |
| 76 | + |
| 77 | +case class Vec[Idx, T](size: Idx, get: Idx => T) { |
| 78 | + def map[U](f: T => U): Vec[Idx, U] = Vec(size, i => f(get(i))) |
| 79 | + def zipWith[U, V](other: Vec[Idx, U], f: (T, U) => V): Vec[Idx, V] = Vec(size, i => f(get(i), other.get(i))) |
| 80 | +} |
| 81 | + |
| 82 | + |
| 83 | +trait VecOps[Idx, T] { |
| 84 | + val reduce: ((T, T) => T, T, Vec[Idx, T]) => T |
| 85 | +} |
| 86 | + |
| 87 | +class StaticVecOps[T] extends VecOps[Int, T] { |
| 88 | + val reduce: ((T, T) => T, T, Vec[Int, T]) => T = (plus, zero, vec) => { |
| 89 | + var sum = zero |
| 90 | + for (i <- 0 until vec.size) |
| 91 | + sum = plus(sum, vec.get(i)) |
| 92 | + sum |
| 93 | + } |
| 94 | +} |
| 95 | + |
| 96 | +class StaticVecOptOps[T] extends VecOps[Int, T] { |
| 97 | + val reduce: ((T, T) => T, T, Vec[Int, T]) => T = (plus, zero, vec) => { |
| 98 | + var sum = zero |
| 99 | + for (i <- 0 until vec.size) |
| 100 | + sum = plus(sum, vec.get(i)) |
| 101 | + sum |
| 102 | + } |
| 103 | +} |
| 104 | + |
| 105 | +class ExprVecOps[T: Type] extends VecOps[Expr[Int], Expr[T]] { |
| 106 | + val reduce: ((Expr[T], Expr[T]) => Expr[T], Expr[T], Vec[Expr[Int], Expr[T]]) => Expr[T] = (plus, zero, vec) => '{ |
| 107 | + var sum = ~zero |
| 108 | + for (i <- 0 until ~vec.size) |
| 109 | + sum = ~{ plus('(sum), vec.get('(i))) } |
| 110 | + sum |
| 111 | + } |
| 112 | +} |
| 113 | + |
| 114 | +class Blas1[Idx, T](r: Ring[T], ops: VecOps[Idx, T]) { |
| 115 | + def dot(v1: Vec[Idx, T], v2: Vec[Idx, T]): T = ops.reduce(r.add, r.zero, v1.zipWith(v2, r.mul)) |
| 116 | +} |
| 117 | + |
| 118 | +object Test { |
| 119 | + |
| 120 | + implicit val toolbox: scala.quoted.Toolbox = dotty.tools.dotc.quoted.Toolbox.make |
| 121 | + |
| 122 | + def main(args: Array[String]): Unit = { |
| 123 | + val arr1 = Array(1, 2, 4, 8, 16) |
| 124 | + val arr2 = Array(1, 0, 1, 0, 1) |
| 125 | + |
| 126 | + val vec1 = new Vec(arr1.size, i => arr1(i)) |
| 127 | + val vec2 = new Vec(arr2.size, i => arr2(i)) |
| 128 | + val blasInt = new Blas1(new RingInt, new StaticVecOps) |
| 129 | + println(blasInt.dot(vec1, vec2)) |
| 130 | + |
| 131 | + val vec3 = new Vec(arr1.size, i => Complex(2, arr2(i))) |
| 132 | + val vec4 = new Vec(arr2.size, i => Complex(1, arr2(i))) |
| 133 | + val blasComplexInt = new Blas1(new RingComplex(new RingInt), new StaticVecOps) |
| 134 | + println(blasComplexInt.dot(vec3, vec4)) |
| 135 | + |
| 136 | + val vec5 = new Vec(5, i => arr1(i).toExpr) |
| 137 | + val vec6 = new Vec(5, i => arr2(i).toExpr) |
| 138 | + val blasStaticIntExpr = new Blas1(new RingIntExpr, new StaticVecOps) |
| 139 | + println(blasStaticIntExpr.dot(vec5, vec6).show) |
| 140 | + |
| 141 | + |
| 142 | + |
| 143 | + val code = '{ |
| 144 | + val arr3 = Array(1, 2, 4, 8, 16) |
| 145 | + val arr4 = Array(1, 0, 1, 0, 1) |
| 146 | + ~{ |
| 147 | + val vec7 = new Vec('(arr3.size), i => '(arr3(~i))) |
| 148 | + val vec8 = new Vec('(arr4.size), i => '(arr4(~i))) |
| 149 | + val blasExprIntExpr = new Blas1(new RingIntExpr, new ExprVecOps) |
| 150 | + blasExprIntExpr.dot(vec7, vec8) |
| 151 | + } |
| 152 | + |
| 153 | + } |
| 154 | + println(code.show) |
| 155 | + |
| 156 | + |
| 157 | + { |
| 158 | + val vec5 = new Vec[Int, PV[Int]](5, i => Dyn(arr1(i).toExpr)) |
| 159 | + val vec6 = new Vec[Int, PV[Int]](5, i => Sta(arr2(i))) |
| 160 | + val blasStaticIntExpr = new Blas1(new RingPV[Int](new RingInt, new RingIntExpr), new StaticVecOps) |
| 161 | + blasStaticIntExpr.dot(vec5, vec6).expr.show |
| 162 | + } |
| 163 | + |
| 164 | + { |
| 165 | + val code = '{ |
| 166 | + val arr3 = Array(1, 2, 4, 8, 16) |
| 167 | + ~{ |
| 168 | + val vec7 = new Vec[Int, PV[Int]](5, i => Dyn('(arr3(~i.toExpr)))) |
| 169 | + val vec8 = new Vec[Int, PV[Int]](5, i => Sta(arr2(i))) |
| 170 | + val blasExprIntExpr = new Blas1(new RingPV[Int](new RingInt, new RingIntExpr), new StaticVecOps) |
| 171 | + blasExprIntExpr.dot(vec7, vec8).expr |
| 172 | + } |
| 173 | + |
| 174 | + } |
| 175 | + println(code.show) |
| 176 | + } |
| 177 | + } |
| 178 | + |
| 179 | +} |
0 commit comments