37
37
import abc
38
38
39
39
from copy import copy
40
- from functools import partial , singledispatch
41
40
from typing import Callable , Dict , List , Optional , Tuple , Union
42
41
43
42
import pytensor .tensor as at
69
68
from pymc .logprob .utils import walk_model
70
69
71
70
72
- @singledispatch
73
- def _default_transformed_rv (
74
- op : Op ,
75
- node : Node ,
76
- ) -> Optional [Apply ]:
77
- """Create a node for a transformed log-probability of a `MeasurableVariable`.
78
-
79
- This function dispatches on the type of `op`. If you want to implement
80
- new transforms for a `MeasurableVariable`, register a function on this
81
- dispatcher.
82
-
83
- """
84
- return None
85
-
86
-
87
71
class TransformedVariable (Op ):
88
72
"""A no-op that identifies a transform and its un-transformed input."""
89
73
@@ -136,13 +120,6 @@ def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable:
136
120
return at .log (at .abs (at .nlinalg .det (at .atleast_2d (jacobian (phi_inv , [value ])[0 ]))))
137
121
138
122
139
- class DefaultTransformSentinel :
140
- pass
141
-
142
-
143
- DEFAULT_TRANSFORM = DefaultTransformSentinel ()
144
-
145
-
146
123
@node_rewriter (tracks = None )
147
124
def transform_values (fgraph : FunctionGraph , node : Node ) -> Optional [List [Node ]]:
148
125
"""Apply transforms to value variables.
@@ -176,17 +153,12 @@ def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
176
153
177
154
if transform is None :
178
155
return None
179
- elif transform is DEFAULT_TRANSFORM :
180
- trans_node = _default_transformed_rv (node .op , node )
181
- if trans_node is None :
182
- return None
183
- transform = trans_node .op .transform
184
- else :
185
- new_op = _create_transformed_rv_op (node .op , transform )
186
- # Create a new `Apply` node and outputs
187
- trans_node = node .clone ()
188
- trans_node .op = new_op
189
- trans_node .outputs [rv_var_out_idx ].name = node .outputs [rv_var_out_idx ].name
156
+
157
+ new_op = _create_transformed_rv_op (node .op , transform )
158
+ # Create a new `Apply` node and outputs
159
+ trans_node = node .clone ()
160
+ trans_node .op = new_op
161
+ trans_node .outputs [rv_var_out_idx ].name = node .outputs [rv_var_out_idx ].name
190
162
191
163
# We now assume that the old value variable represents the *transformed space*.
192
164
# This means that we need to replace all instance of the old value variable
@@ -216,24 +188,22 @@ def on_attach(self, fgraph):
216
188
217
189
218
190
class TransformValuesRewrite (GraphRewriter ):
219
- r"""Transforms value variables according to a map and/or per-`RandomVariable` defaults ."""
191
+ r"""Transforms value variables according to a map."""
220
192
221
- default_transform_rewrite = in2out (transform_values , ignore_newtrees = True )
193
+ transform_rewrite = in2out (transform_values , ignore_newtrees = True )
222
194
223
195
def __init__ (
224
196
self ,
225
- values_to_transforms : Dict [
226
- TensorVariable , Union [RVTransform , DefaultTransformSentinel , None ]
227
- ],
197
+ values_to_transforms : Dict [TensorVariable , Union [RVTransform , None ]],
228
198
):
229
199
"""
230
200
Parameters
231
201
==========
232
202
values_to_transforms
233
203
Mapping between value variables and their transformations. Each
234
- value variable can be assigned one of `RVTransform`,
235
- ``DEFAULT_TRANSFORM``, or ``None``. If a transform is not specified
236
- for a specific value variable it will not be transformed.
204
+ value variable can be assigned one of `RVTransform`, or ``None``.
205
+ If a transform is not specified for a specific value variable it will
206
+ not be transformed.
237
207
238
208
"""
239
209
@@ -244,7 +214,7 @@ def add_requirements(self, fgraph):
244
214
fgraph .attach_feature (values_transforms_feature )
245
215
246
216
def apply (self , fgraph : FunctionGraph ):
247
- return self .default_transform_rewrite .rewrite (fgraph )
217
+ return self .transform_rewrite .rewrite (fgraph )
248
218
249
219
250
220
class MeasurableTransform (MeasurableElemwise ):
@@ -583,7 +553,6 @@ def _create_transformed_rv_op(
583
553
rv_op : Op ,
584
554
transform : RVTransform ,
585
555
* ,
586
- default : bool = False ,
587
556
cls_dict_extra : Optional [Dict ] = None ,
588
557
) -> Op :
589
558
"""Create a new transformed variable instance given a base `RandomVariable` `Op`.
@@ -600,8 +569,6 @@ def _create_transformed_rv_op(
600
569
The `RandomVariable` for which we want to construct a `TransformedRV`.
601
570
transform
602
571
The `RVTransform` for `rv_op`.
603
- default
604
- If ``False`` do not make `transform` the default transform for `rv_op`.
605
572
cls_dict_extra
606
573
Additional class members to add to the constructed `TransformedRV`.
607
574
@@ -642,85 +609,7 @@ def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
642
609
643
610
return logprob
644
611
645
- transform_op = rv_op_type if default else new_op_type
646
-
647
- @_default_transformed_rv .register (transform_op )
648
- def class_transformed_rv (op , node ):
649
- new_op = new_op_type ()
650
- res = new_op .make_node (* node .inputs )
651
- res .outputs [1 ].name = node .outputs [1 ].name
652
- return res
653
-
654
612
new_op = copy (rv_op )
655
613
new_op .__class__ = new_op_type
656
614
657
615
return new_op
658
-
659
-
660
- create_default_transformed_rv_op = partial (_create_transformed_rv_op , default = True )
661
-
662
-
663
- TransformedUniformRV = create_default_transformed_rv_op (
664
- at .random .uniform ,
665
- # inputs[3] = lower; inputs[4] = upper
666
- IntervalTransform (lambda * inputs : (inputs [3 ], inputs [4 ])),
667
- )
668
- TransformedParetoRV = create_default_transformed_rv_op (
669
- at .random .pareto ,
670
- # inputs[3] = alpha
671
- IntervalTransform (lambda * inputs : (inputs [3 ], None )),
672
- )
673
- TransformedTriangularRV = create_default_transformed_rv_op (
674
- at .random .triangular ,
675
- # inputs[3] = lower; inputs[5] = upper
676
- IntervalTransform (lambda * inputs : (inputs [3 ], inputs [5 ])),
677
- )
678
- TransformedHalfNormalRV = create_default_transformed_rv_op (
679
- at .random .halfnormal ,
680
- # inputs[3] = loc
681
- IntervalTransform (lambda * inputs : (inputs [3 ], None )),
682
- )
683
- TransformedWaldRV = create_default_transformed_rv_op (
684
- at .random .wald ,
685
- LogTransform (),
686
- )
687
- TransformedExponentialRV = create_default_transformed_rv_op (
688
- at .random .exponential ,
689
- LogTransform (),
690
- )
691
- TransformedLognormalRV = create_default_transformed_rv_op (
692
- at .random .lognormal ,
693
- LogTransform (),
694
- )
695
- TransformedHalfCauchyRV = create_default_transformed_rv_op (
696
- at .random .halfcauchy ,
697
- LogTransform (),
698
- )
699
- TransformedGammaRV = create_default_transformed_rv_op (
700
- at .random .gamma ,
701
- LogTransform (),
702
- )
703
- TransformedInvGammaRV = create_default_transformed_rv_op (
704
- at .random .invgamma ,
705
- LogTransform (),
706
- )
707
- TransformedChiSquareRV = create_default_transformed_rv_op (
708
- at .random .chisquare ,
709
- LogTransform (),
710
- )
711
- TransformedWeibullRV = create_default_transformed_rv_op (
712
- at .random .weibull ,
713
- LogTransform (),
714
- )
715
- TransformedBetaRV = create_default_transformed_rv_op (
716
- at .random .beta ,
717
- LogOddsTransform (),
718
- )
719
- TransformedVonMisesRV = create_default_transformed_rv_op (
720
- at .random .vonmises ,
721
- CircularTransform (),
722
- )
723
- TransformedDirichletRV = create_default_transformed_rv_op (
724
- at .random .dirichlet ,
725
- SimplexTransform (),
726
- )
0 commit comments