47
47
from pytensor .graph .fg import FunctionGraph
48
48
from pytensor .graph .op import Op
49
49
from pytensor .graph .rewriting .basic import GraphRewriter , in2out , node_rewriter
50
- from pytensor .scalar import Add , Exp , Log , Mul
50
+ from pytensor .scalar import Add , Exp , Log , Mul , Reciprocal
51
51
from pytensor .scan .op import Scan
52
- from pytensor .tensor .math import add , exp , log , mul
52
+ from pytensor .tensor .exceptions import NotScalarConstantError
53
+ from pytensor .tensor .math import add , exp , log , mul , reciprocal , true_div
53
54
from pytensor .tensor .rewriting .basic import (
54
55
register_specialize ,
55
56
register_stabilize ,
@@ -318,7 +319,7 @@ def apply(self, fgraph: FunctionGraph):
318
319
class MeasurableTransform (MeasurableElemwise ):
319
320
"""A placeholder used to specify a log-likelihood for a transformed measurable variable"""
320
321
321
- valid_scalar_types = (Exp , Log , Add , Mul )
322
+ valid_scalar_types = (Exp , Log , Add , Mul , Reciprocal )
322
323
323
324
# Cannot use `transform` as name because it would clash with the property added by
324
325
# the `TransformValuesRewrite`
@@ -354,7 +355,36 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
354
355
return input_logprob + jacobian
355
356
356
357
357
- @node_rewriter ([exp , log , add , mul ])
358
+ @node_rewriter ([true_div ])
359
+ def measurable_div_to_reciprocal_product (fgraph , node ):
360
+ """Convert divisions involving `MeasurableVariable`s to product with reciprocal."""
361
+
362
+ measurable_vars = [
363
+ var for var in node .inputs if (var .owner and isinstance (var .owner .op , MeasurableVariable ))
364
+ ]
365
+ if not measurable_vars :
366
+ return None # pragma: no cover
367
+
368
+ rv_map_feature : Optional [PreserveRVMappings ] = getattr (fgraph , "preserve_rv_mappings" , None )
369
+ if rv_map_feature is None :
370
+ return None # pragma: no cover
371
+
372
+ # Only apply this rewrite if there is one unvalued MeasurableVariable involved
373
+ if all (measurable_var in rv_map_feature .rv_values for measurable_var in measurable_vars ):
374
+ return None # pragma: no cover
375
+
376
+ numerator , denominator = node .inputs
377
+
378
+ # Check if numerator is 1
379
+ try :
380
+ if at .get_scalar_constant_value (numerator ) == 1 :
381
+ return [at .reciprocal (denominator )]
382
+ except NotScalarConstantError :
383
+ pass
384
+ return [at .mul (numerator , at .reciprocal (denominator ))]
385
+
386
+
387
+ @node_rewriter ([exp , log , add , mul , reciprocal ])
358
388
def find_measurable_transforms (fgraph : FunctionGraph , node : Node ) -> Optional [List [Node ]]:
359
389
"""Find measurable transformations from Elemwise operators."""
360
390
@@ -414,6 +444,8 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
414
444
transform = ExpTransform ()
415
445
elif isinstance (scalar_op , Log ):
416
446
transform = LogTransform ()
447
+ elif isinstance (scalar_op , Reciprocal ):
448
+ transform = ReciprocalTransform ()
417
449
elif isinstance (scalar_op , Add ):
418
450
transform_inputs = (measurable_input , at .add (* other_inputs ))
419
451
transform = LocTransform (
@@ -436,6 +468,14 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
436
468
return [transform_out ]
437
469
438
470
471
+ measurable_ir_rewrites_db .register (
472
+ "measurable_div_to_reciprocal_product" ,
473
+ measurable_div_to_reciprocal_product ,
474
+ "basic" ,
475
+ "transform" ,
476
+ )
477
+
478
+
439
479
measurable_ir_rewrites_db .register (
440
480
"find_measurable_transforms" ,
441
481
find_measurable_transforms ,
@@ -507,6 +547,19 @@ def log_jac_det(self, value, *inputs):
507
547
return - at .log (value )
508
548
509
549
550
+ class ReciprocalTransform (RVTransform ):
551
+ name = "reciprocal"
552
+
553
+ def forward (self , value , * inputs ):
554
+ return at .reciprocal (value )
555
+
556
+ def backward (self , value , * inputs ):
557
+ return at .reciprocal (value )
558
+
559
+ def log_jac_det (self , value , * inputs ):
560
+ return - 2 * at .log (value )
561
+
562
+
510
563
class IntervalTransform (RVTransform ):
511
564
name = "interval"
512
565
0 commit comments