42
42
import numpy as np
43
43
import pytensor .tensor as pt
44
44
45
+ from pytensor import scan
45
46
from pytensor .gradient import DisconnectedType , jacobian
46
47
from pytensor .graph .basic import Apply , Node , Variable
47
48
from pytensor .graph .features import AlreadyThere , Feature
48
49
from pytensor .graph .fg import FunctionGraph
49
50
from pytensor .graph .op import Op
50
51
from pytensor .graph .replace import clone_replace
51
52
from pytensor .graph .rewriting .basic import GraphRewriter , in2out , node_rewriter
52
- from pytensor .scalar import Abs , Add , Exp , Log , Mul , Pow , Sqr , Sqrt
53
+ from pytensor .scalar import (
54
+ Abs ,
55
+ Add ,
56
+ Cosh ,
57
+ Erf ,
58
+ Erfc ,
59
+ Erfcx ,
60
+ Exp ,
61
+ Log ,
62
+ Mul ,
63
+ Pow ,
64
+ Sinh ,
65
+ Sqr ,
66
+ Sqrt ,
67
+ Tanh ,
68
+ )
53
69
from pytensor .scan .op import Scan
54
70
from pytensor .tensor .exceptions import NotScalarConstantError
55
71
from pytensor .tensor .math import (
56
72
abs ,
57
73
add ,
74
+ cosh ,
75
+ erf ,
76
+ erfc ,
77
+ erfcx ,
58
78
exp ,
59
79
log ,
60
80
mul ,
61
81
neg ,
62
82
pow ,
63
83
reciprocal ,
84
+ sinh ,
64
85
sqr ,
65
86
sqrt ,
66
87
sub ,
88
+ tanh ,
67
89
true_div ,
68
90
)
69
91
from pytensor .tensor .rewriting .basic import (
@@ -122,6 +144,8 @@ def remove_TransformedVariables(fgraph, node):
122
144
123
145
124
146
class RVTransform (abc .ABC ):
147
+ ndim_supp = None
148
+
125
149
@abc .abstractmethod
126
150
def forward (self , value : TensorVariable , * inputs : Variable ) -> TensorVariable :
127
151
"""Apply the transformation."""
@@ -135,12 +159,16 @@ def backward(
135
159
136
160
def log_jac_det (self , value : TensorVariable , * inputs ) -> TensorVariable :
137
161
"""Construct the log of the absolute value of the Jacobian determinant."""
138
- # jac = pt.reshape(
139
- # gradient(pt.sum(self.backward(value, *inputs)), [value]), value.shape
140
- # )
141
- # return pt.log(pt.abs(jac))
142
- phi_inv = self .backward (value , * inputs )
143
- return pt .log (pt .abs (pt .nlinalg .det (pt .atleast_2d (jacobian (phi_inv , [value ])[0 ]))))
162
+ if self .ndim_supp not in (0 , 1 ):
163
+ raise NotImplementedError (
164
+ f"RVTransform default log_jac_det only implemented for ndim_supp in (0, 1), got { self .ndim_supp = } "
165
+ )
166
+ if self .ndim_supp == 0 :
167
+ jac = pt .reshape (pt .grad (pt .sum (self .backward (value , * inputs )), [value ]), value .shape )
168
+ return pt .log (pt .abs (jac ))
169
+ else :
170
+ phi_inv = self .backward (value , * inputs )
171
+ return pt .log (pt .abs (pt .nlinalg .det (pt .atleast_2d (jacobian (phi_inv , [value ])[0 ]))))
144
172
145
173
146
174
@node_rewriter (tracks = None )
@@ -340,7 +368,7 @@ def apply(self, fgraph: FunctionGraph):
340
368
class MeasurableTransform (MeasurableElemwise ):
341
369
"""A placeholder used to specify a log-likelihood for a transformed measurable variable"""
342
370
343
- valid_scalar_types = (Exp , Log , Add , Mul , Pow , Abs )
371
+ valid_scalar_types = (Exp , Log , Add , Mul , Pow , Abs , Sinh , Cosh , Tanh , Erf , Erfc , Erfcx )
344
372
345
373
# Cannot use `transform` as name because it would clash with the property added by
346
374
# the `TransformValuesRewrite`
@@ -540,7 +568,7 @@ def measurable_sub_to_neg(fgraph, node):
540
568
return [pt .add (minuend , pt .neg (subtrahend ))]
541
569
542
570
543
- @node_rewriter ([exp , log , add , mul , pow , abs ])
571
+ @node_rewriter ([exp , log , add , mul , pow , abs , sinh , cosh , tanh , erf , erfc , erfcx ])
544
572
def find_measurable_transforms (fgraph : FunctionGraph , node : Node ) -> Optional [List [Node ]]:
545
573
"""Find measurable transformations from Elemwise operators."""
546
574
@@ -585,13 +613,20 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
585
613
measurable_input_idx = 0
586
614
transform_inputs : Tuple [TensorVariable , ...] = (measurable_input ,)
587
615
transform : RVTransform
588
- if isinstance (scalar_op , Exp ):
589
- transform = ExpTransform ()
590
- elif isinstance (scalar_op , Log ):
591
- transform = LogTransform ()
592
- elif isinstance (scalar_op , Abs ):
593
- transform = AbsTransform ()
594
- elif isinstance (scalar_op , Pow ):
616
+
617
+ transform_dict = {
618
+ Exp : ExpTransform (),
619
+ Log : LogTransform (),
620
+ Abs : AbsTransform (),
621
+ Sinh : SinhTransform (),
622
+ Cosh : CoshTransform (),
623
+ Tanh : TanhTransform (),
624
+ Erf : ErfTransform (),
625
+ Erfc : ErfcTransform (),
626
+ Erfcx : ErfcxTransform (),
627
+ }
628
+ transform = transform_dict .get (type (scalar_op ), None )
629
+ if isinstance (scalar_op , Pow ):
595
630
# We only allow for the base to be measurable
596
631
if measurable_input_idx != 0 :
597
632
return None
@@ -608,7 +643,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
608
643
transform = LocTransform (
609
644
transform_args_fn = lambda * inputs : inputs [- 1 ],
610
645
)
611
- else :
646
+ elif transform is None :
612
647
transform_inputs = (measurable_input , pt .mul (* other_inputs ))
613
648
transform = ScaleTransform (
614
649
transform_args_fn = lambda * inputs : inputs [- 1 ],
@@ -671,6 +706,87 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
671
706
)
672
707
673
708
709
class SinhTransform(RVTransform):
    """Measurable transform for the hyperbolic sine function."""

    name = "sinh"
    ndim_supp = 0

    def forward(self, value, *inputs):
        """Map ``value`` through ``sinh``."""
        return pt.sinh(value)

    def backward(self, value, *inputs):
        """Invert the transform; ``arcsinh`` is the exact inverse of ``sinh``."""
        return pt.arcsinh(value)
718
+
719
+
720
class CoshTransform(RVTransform):
    """Measurable transform for the hyperbolic cosine function."""

    name = "cosh"
    ndim_supp = 0

    def forward(self, value, *inputs):
        """Map ``value`` through ``cosh``."""
        return pt.cosh(value)

    def backward(self, value, *inputs):
        # NOTE(review): cosh is not injective (cosh(-x) == cosh(x)) and arccosh
        # is only defined for value >= 1; this returns the non-negative branch —
        # confirm downstream logic accounts for the folded branch.
        return pt.arccosh(value)
729
+
730
+
731
class TanhTransform(RVTransform):
    """Measurable transform for the hyperbolic tangent function."""

    name = "tanh"
    ndim_supp = 0

    def forward(self, value, *inputs):
        """Map ``value`` through ``tanh``."""
        return pt.tanh(value)

    def backward(self, value, *inputs):
        """Invert the transform; ``arctanh`` is the exact inverse of ``tanh``."""
        return pt.arctanh(value)
740
+
741
+
742
class ErfTransform(RVTransform):
    """Measurable transform for the error function."""

    name = "erf"
    ndim_supp = 0

    def forward(self, value, *inputs):
        """Map ``value`` through ``erf``."""
        return pt.erf(value)

    def backward(self, value, *inputs):
        """Invert the transform via the inverse error function."""
        return pt.erfinv(value)
751
+
752
+
753
class ErfcTransform(RVTransform):
    """Measurable transform for the complementary error function."""

    name = "erfc"
    ndim_supp = 0

    def forward(self, value, *inputs):
        """Map ``value`` through ``erfc``."""
        return pt.erfc(value)

    def backward(self, value, *inputs):
        """Invert the transform via the inverse complementary error function."""
        return pt.erfcinv(value)
762
+
763
+
764
class ErfcxTransform(RVTransform):
    """Measurable transform for the scaled complementary error function.

    ``erfcx`` has no closed-form inverse, so ``backward`` inverts it
    numerically with a fixed number of Newton-Raphson steps.
    """

    name = "erfcx"
    ndim_supp = 0

    def forward(self, value, *inputs):
        """Map ``value`` through ``erfcx``."""
        return pt.erfcx(value)

    def backward(self, value, *inputs):
        """Numerically invert ``erfcx`` via Newton-Raphson.

        Adapted from https://tinyurl.com/4mxfd3cz: start from an asymptotic
        approximation of the inverse and refine it with 10 Newton steps.
        """
        # Asymptotic initial guess: erfcx(z) ~ 1/(z*sqrt(pi)) for large z
        # (value <= 1), and z ~ -sqrt(log(value)) for value > 1.
        x = pt.switch(value <= 1, 1.0 / (value * pt.sqrt(np.pi)), -pt.sqrt(pt.log(value)))

        def newton_step(prior_result, target):
            # scan passes the prior output first and non_sequences after, so
            # `prior_result` is the running iterate z_k and `target` is the
            # value being inverted. Newton step with
            # erfcx'(z) = 2*z*erfcx(z) - 2/sqrt(pi).
            return prior_result - (pt.erfcx(prior_result) - target) / (
                2 * prior_result * pt.erfcx(prior_result) - 2 / pt.sqrt(np.pi)
            )

        # Fixes two defects in the previous version: the iteration started from
        # ones_like(x), silently discarding the computed initial guess, and the
        # step function's parameter names were swapped relative to scan's
        # (prior_output, non_sequence) argument order, so Newton iterated on the
        # constant target instead of the running iterate.
        result, _ = scan(
            fn=newton_step,
            outputs_info=x,
            non_sequences=value,
            n_steps=10,
        )
        return result[-1]
788
+
789
+
674
790
class LocTransform (RVTransform ):
675
791
name = "loc"
676
792
0 commit comments