|
1 | 1 | package mathParser
|
2 | 2 | package algebra
|
3 | 3 |
|
| 4 | +import mathParser.number.{NumberBinaryOperator, NumberUnitaryOperator} |
| 5 | +import mathParser.number.NumberBinaryOperator.* |
| 6 | +import mathParser.number.NumberUnitaryOperator.* |
| 7 | +import mathParser.number.NumberSyntax.* |
| 8 | +import mathParser.AbstractSyntaxTree.* |
4 | 9 | import spire.algebra.{Field, NRoot, Trig}
|
5 | 10 |
|
6 | 11 | import scala.util.Try
|
7 |
| -import mathParser.AbstractSyntaxTree.* |
| 12 | +import mathParser.{AbstractSyntaxTree, Evaluate} |
8 | 13 |
|
9 | 14 | object SpireLanguage {
|
10 |
| - |
11 |
| - import syntax._ |
12 |
| - |
13 |
| - def apply[A: Field: NRoot: Trig]: SpireLanguage[A, Nothing] = |
| 15 | + def apply[A: Field: NRoot: Trig]: Language[NumberUnitaryOperator, NumberBinaryOperator, A, Nothing] = |
14 | 16 | Language.emptyLanguage
|
15 | 17 | .withConstants[A](List("e" -> Trig[A].e, "pi" -> Trig[A].pi))
|
16 |
| - .withBinaryOperators[SpireBinaryOperator](prefix = List.empty, infix = List(Plus, Minus, Times, Divided, Power).map(op => (op.name, op))) |
| 18 | + .withBinaryOperators[NumberBinaryOperator](prefix = List.empty, infix = List(Plus, Minus, Times, Divided, Power).map(op => (op.name, op))) |
17 | 19 | .withUnitaryOperators(List(Neg, Sin, Cos, Tan, Asin, Acos, Atan, Sinh, Cosh, Tanh, Exp, Log).map(op => (op.name, op)))
|
18 | 20 |
|
19 |
| - def spireLiteralParser[A: Field]: LiteralParser[A] = s => Try(Field[A].fromDouble(s.toDouble)).toOption |
20 |
| - |
21 |
| - def spireEvaluate[A: Field: NRoot: Trig, V]: Evaluate[SpireUnitaryOperator, SpireBinaryOperator, A, V] = |
22 |
| - new Evaluate[SpireUnitaryOperator, SpireBinaryOperator, A, V] { |
23 |
| - def executeUnitary(uo: SpireUnitaryOperator, s: A): A = uo match { |
24 |
| - case Neg => Field[A].negate(s) |
25 |
| - case Sin => Trig[A].sin(s) |
26 |
| - case Cos => Trig[A].cos(s) |
27 |
| - case Tan => Trig[A].tan(s) |
28 |
| - case Asin => Trig[A].asin(s) |
29 |
| - case Acos => Trig[A].acos(s) |
30 |
| - case Atan => Trig[A].atan(s) |
31 |
| - case Sinh => Trig[A].sinh(s) |
32 |
| - case Cosh => Trig[A].cosh(s) |
33 |
| - case Tanh => Trig[A].tanh(s) |
34 |
| - case Exp => Trig[A].exp(s) |
35 |
| - case Log => Trig[A].log(s) |
36 |
| - } |
37 |
| - |
38 |
| - def executeBinaryOperator(bo: SpireBinaryOperator, left: A, right: A): A = bo match { |
39 |
| - case Plus => Field[A].plus(left, right) |
40 |
| - case Minus => Field[A].minus(left, right) |
41 |
| - case Times => Field[A].times(left, right) |
42 |
| - case Divided => Field[A].div(left, right) |
43 |
| - case Power => NRoot[A].fpow(left, right) |
44 |
| - } |
45 |
| - } |
46 |
| - |
47 |
| - def spireOptimizer[A: Field: NRoot: Trig, V]: Optimizer[SpireUnitaryOperator, SpireBinaryOperator, A, V] = |
48 |
| - new Optimizer[SpireUnitaryOperator, SpireBinaryOperator, A, V] { |
49 |
| - override def rules: List[PartialFunction[SpireNode[A, V], SpireNode[A, V]]] = List( |
50 |
| - Optimize.replaceConstantsRule(using spireEvaluate), { |
51 |
| - case UnitaryNode(Neg, UnitaryNode(Neg, child)) => child |
52 |
| - case BinaryNode(Plus, left, ConstantNode(0d)) => left |
53 |
| - case BinaryNode(Plus, ConstantNode(0d), right) => right |
54 |
| - case BinaryNode(Times, ConstantNode(0d), _) => zero |
55 |
| - case BinaryNode(Times, _, ConstantNode(0d)) => zero |
56 |
| - case BinaryNode(Times, left, ConstantNode(1d)) => left |
57 |
| - case BinaryNode(Times, ConstantNode(1d), right) => right |
58 |
| - case BinaryNode(Power, left, ConstantNode(1d)) => left |
59 |
| - case BinaryNode(Power, _, ConstantNode(0d)) => one[A, V] |
60 |
| - case BinaryNode(Power, ConstantNode(1d), _) => one[A, V] |
61 |
| - case BinaryNode(Power, ConstantNode(0d), _) => zero[A, V] |
62 |
| - case UnitaryNode(Log, UnitaryNode(Exp, child)) => child |
63 |
| - case BinaryNode(Plus, left, UnitaryNode(Neg, child)) => left - child |
64 |
| - case BinaryNode(Minus, left, UnitaryNode(Neg, child)) => left + child |
65 |
| - case BinaryNode(Minus, left, right) if left == right => zero |
66 |
| - case BinaryNode(Divided, left, right) if left == right => one |
67 |
| - } |
68 |
| - ) |
69 |
| - } |
70 |
| - |
71 |
| - def spireDerive[A: Field: Trig: NRoot, V]: Derive[SpireUnitaryOperator, SpireBinaryOperator, A, V] = |
72 |
| - new Derive[SpireUnitaryOperator, SpireBinaryOperator, A, V] { |
73 |
| - def derive(term: SpireNode[A, V])(variable: V): SpireNode[A, V] = { |
74 |
| - def derive(term: SpireNode[A, V]): SpireNode[A, V] = term match { |
75 |
| - case VariableNode(`variable`) => one |
76 |
| - case VariableNode(_) | ConstantNode(_) => zero |
77 |
| - case UnitaryNode(op, f) => |
78 |
| - op match { |
79 |
| - case Neg => neg(derive(f)) |
80 |
| - case Sin => derive(f) * cos(f) |
81 |
| - case Cos => neg(derive(f) * sin(f)) |
82 |
| - case Tan => derive(f) / (cos(f) * cos(f)) |
83 |
| - case Asin => derive(f) / sqrt(one - (f * f)) |
84 |
| - case Acos => neg(derive(f)) / sqrt(one - (f * f)) |
85 |
| - case Atan => derive(f) / (one + (f * f)) |
86 |
| - case Sinh => derive(f) * cosh(f) |
87 |
| - case Cosh => derive(f) * sinh(f) |
88 |
| - case Tanh => derive(f) / (cosh(f) * cosh(f)) |
89 |
| - case Exp => exp(f) * derive(f) |
90 |
| - case Log => derive(f) / f |
91 |
| - } |
92 |
| - case BinaryNode(op, f, g) => |
93 |
| - op match { |
94 |
| - case Plus => derive(f) + derive(g) |
95 |
| - case Minus => derive(f) - derive(g) |
96 |
| - case Times => (derive(f) * g) + (derive(g) * f) |
97 |
| - case Divided => ((f * derive(g)) - (g * derive(f))) / (g * g) |
98 |
| - case Power => (f ^ (g - one)) * ((g * derive(f)) + (f * log(f) * derive(g))) |
99 |
| - } |
100 |
| - } |
| 21 | + given[A] (using field: Field[A]): mathParser.number.Number[A] = mathParser.number.Number.contraMap(field.fromDouble) |
101 | 22 |
|
102 |
| - derive(term) |
103 |
| - } |
104 |
| - } |
| 23 | + given spireLiteralParser[A: Field]: LiteralParser[A] = s => s.toDoubleOption.map(Field[A].fromDouble) |
105 | 24 |
|
106 |
| - object syntax { |
107 |
| - def neg[A: Field: Trig: NRoot, V](t: SpireNode[A, V]): SpireNode[A, V] = UnitaryNode(Neg, t) |
108 |
| - def sin[A: Field: Trig: NRoot, V](t: SpireNode[A, V]): SpireNode[A, V] = UnitaryNode(Sin, t) |
109 |
| - def cos[A: Field: Trig: NRoot, V](t: SpireNode[A, V]): SpireNode[A, V] = UnitaryNode(Cos, t) |
110 |
| - def tan[A: Field: Trig: NRoot, V](t: SpireNode[A, V]): SpireNode[A, V] = UnitaryNode(Tan, t) |
111 |
| - def asin[A: Field: Trig: NRoot, V](t: SpireNode[A, V]): SpireNode[A, V] = UnitaryNode(Asin, t) |
112 |
| - def acos[A: Field: Trig: NRoot, V](t: SpireNode[A, V]): SpireNode[A, V] = UnitaryNode(Acos, t) |
113 |
| - def atan[A: Field: Trig: NRoot, V](t: SpireNode[A, V]): SpireNode[A, V] = UnitaryNode(Atan, t) |
114 |
| - def sinh[A: Field: Trig: NRoot, V](t: SpireNode[A, V]): SpireNode[A, V] = UnitaryNode(Sinh, t) |
115 |
| - def cosh[A: Field: Trig: NRoot, V](t: SpireNode[A, V]): SpireNode[A, V] = UnitaryNode(Cosh, t) |
116 |
| - def tanh[A: Field: Trig: NRoot, V](t: SpireNode[A, V]): SpireNode[A, V] = UnitaryNode(Tanh, t) |
117 |
| - def exp[A: Field: Trig: NRoot, V](t: SpireNode[A, V]): SpireNode[A, V] = UnitaryNode(Exp, t) |
118 |
| - def log[A: Field: Trig: NRoot, V](t: SpireNode[A, V]): SpireNode[A, V] = UnitaryNode(Log, t) |
| 25 | + given [A: Field: NRoot: Trig, V]: Evaluate[NumberUnitaryOperator, NumberBinaryOperator, A, V] = SpireEvaluate() |
119 | 26 |
|
120 |
| - def sqrt[A: Field: Trig: NRoot, V](t: SpireNode[A, V]): SpireNode[A, V] = BinaryNode(Power, t, ConstantNode(Field[A].fromDouble(0.5))) |
| 27 | + given [A: Field: NRoot: Trig, V]: Optimizer[NumberUnitaryOperator, NumberBinaryOperator, A, V] = |
| 28 | + mathParser.number.NumberOptimizer() |
121 | 29 |
|
122 |
| - extension [A: Field: Trig: NRoot, V](t1: SpireNode[A, V]) { |
123 |
| - def +(t2: SpireNode[A, V]): SpireNode[A, V] = BinaryNode(Plus, t1, t2) |
124 |
| - def -(t2: SpireNode[A, V]): SpireNode[A, V] = BinaryNode(Minus, t1, t2) |
125 |
| - def *(t2: SpireNode[A, V]): SpireNode[A, V] = BinaryNode(Times, t1, t2) |
126 |
| - def /(t2: SpireNode[A, V]): SpireNode[A, V] = BinaryNode(Divided, t1, t2) |
127 |
| - def ^(t2: SpireNode[A, V]): SpireNode[A, V] = BinaryNode(Power, t1, t2) |
128 |
| - } |
| 30 | + given [A:Field:NRoot:Trig, V]: Derive[NumberUnitaryOperator, NumberBinaryOperator, A, V] = |
| 31 | + mathParser.number.NumberDerive() |
129 | 32 |
|
130 |
| - def zero[A: Field: Trig: NRoot, V]: SpireNode[A, V] = ConstantNode(Field[A].zero) |
131 |
| - def one[A: Field: Trig: NRoot, V]: SpireNode[A, V] = ConstantNode(Field[A].one) |
132 |
| - def two[A: Field: Trig: NRoot, V]: SpireNode[A, V] = ConstantNode(Field[A].fromInt(2)) |
133 |
| - } |
134 | 33 | }
|
0 commit comments