Skip to content

Commit 95c589a

Browse files
committed
Implement reciprocal measurable transform
Adds rewrite that converts divisions with measurable variables to product with reciprocals, making the reciprocal measurable transform more widely applicable.
1 parent 9d8fc34 commit 95c589a

File tree

2 files changed

+83
-10
lines changed

2 files changed

+83
-10
lines changed

pymc/logprob/transforms.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@
4747
from pytensor.graph.fg import FunctionGraph
4848
from pytensor.graph.op import Op
4949
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
50-
from pytensor.scalar import Add, Exp, Log, Mul
50+
from pytensor.scalar import Add, Exp, Log, Mul, Reciprocal
5151
from pytensor.scan.op import Scan
52-
from pytensor.tensor.math import add, exp, log, mul
52+
from pytensor.tensor.exceptions import NotScalarConstantError
53+
from pytensor.tensor.math import add, exp, log, mul, reciprocal, true_div
5354
from pytensor.tensor.rewriting.basic import (
5455
register_specialize,
5556
register_stabilize,
@@ -318,7 +319,7 @@ def apply(self, fgraph: FunctionGraph):
318319
class MeasurableTransform(MeasurableElemwise):
319320
"""A placeholder used to specify a log-likelihood for a transformed measurable variable"""
320321

321-
valid_scalar_types = (Exp, Log, Add, Mul)
322+
valid_scalar_types = (Exp, Log, Add, Mul, Reciprocal)
322323

323324
# Cannot use `transform` as name because it would clash with the property added by
324325
# the `TransformValuesRewrite`
@@ -354,7 +355,36 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
354355
return input_logprob + jacobian
355356

356357

357-
@node_rewriter([exp, log, add, mul])
358+
@node_rewriter([true_div])
359+
def measurable_div_to_reciprocal_product(fgraph, node):
360+
"""Convert divisions involving `MeasurableVariable`s to product with reciprocal."""
361+
362+
measurable_vars = [
363+
var for var in node.inputs if (var.owner and isinstance(var.owner.op, MeasurableVariable))
364+
]
365+
if not measurable_vars:
366+
return None # pragma: no cover
367+
368+
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
369+
if rv_map_feature is None:
370+
return None # pragma: no cover
371+
372+
# Only apply this rewrite if there is one unvalued MeasurableVariable involved
373+
if all(measurable_var in rv_map_feature.rv_values for measurable_var in measurable_vars):
374+
return None # pragma: no cover
375+
376+
numerator, denominator = node.inputs
377+
378+
# Check if numerator is 1
379+
try:
380+
if at.get_scalar_constant_value(numerator) == 1:
381+
return [at.reciprocal(denominator)]
382+
except NotScalarConstantError:
383+
pass
384+
return [at.mul(numerator, at.reciprocal(denominator))]
385+
386+
387+
@node_rewriter([exp, log, add, mul, reciprocal])
358388
def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
359389
"""Find measurable transformations from Elemwise operators."""
360390

@@ -414,6 +444,8 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
414444
transform = ExpTransform()
415445
elif isinstance(scalar_op, Log):
416446
transform = LogTransform()
447+
elif isinstance(scalar_op, Reciprocal):
448+
transform = ReciprocalTransform()
417449
elif isinstance(scalar_op, Add):
418450
transform_inputs = (measurable_input, at.add(*other_inputs))
419451
transform = LocTransform(
@@ -436,6 +468,14 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
436468
return [transform_out]
437469

438470

471+
measurable_ir_rewrites_db.register(
472+
"measurable_div_to_reciprocal_product",
473+
measurable_div_to_reciprocal_product,
474+
"basic",
475+
"transform",
476+
)
477+
478+
439479
measurable_ir_rewrites_db.register(
440480
"find_measurable_transforms",
441481
find_measurable_transforms,
@@ -507,6 +547,19 @@ def log_jac_det(self, value, *inputs):
507547
return -at.log(value)
508548

509549

550+
class ReciprocalTransform(RVTransform):
551+
name = "reciprocal"
552+
553+
def forward(self, value, *inputs):
554+
return at.reciprocal(value)
555+
556+
def backward(self, value, *inputs):
557+
return at.reciprocal(value)
558+
559+
def log_jac_det(self, value, *inputs):
560+
return -2 * at.log(value)
561+
562+
510563
class IntervalTransform(RVTransform):
511564
name = "interval"
512565

pymc/tests/logprob/test_transforms.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -692,17 +692,20 @@ def test_loc_transform_rv(rv_size, loc_type):
692692

693693

694694
@pytest.mark.parametrize(
695-
"rv_size, scale_type",
695+
"rv_size, scale_type, product",
696696
[
697-
(None, at.scalar),
698-
(1, at.TensorType("floatX", (True,))),
699-
((2, 3), at.matrix),
697+
(None, at.scalar, True),
698+
(1, at.TensorType("floatX", (True,)), True),
699+
((2, 3), at.matrix, False),
700700
],
701701
)
702-
def test_scale_transform_rv(rv_size, scale_type):
702+
def test_scale_transform_rv(rv_size, scale_type, product):
703703

704704
scale = scale_type("scale")
705-
y_rv = at.random.normal(0, 1, size=rv_size, name="base_rv") * scale
705+
if product:
706+
y_rv = at.random.normal(0, 1, size=rv_size, name="base_rv") * scale
707+
else:
708+
y_rv = at.random.normal(0, 1, size=rv_size, name="base_rv") / at.reciprocal(scale)
706709
y_rv.name = "y"
707710
y_vv = y_rv.clone()
708711

@@ -784,6 +787,23 @@ def test_invalid_broadcasted_transform_rv_fails():
784787
assert False, "Should have failed before"
785788

786789

790+
@pytest.mark.parametrize("numerator", (1.0, 2.0))
791+
def test_reciprocal_rv_transform(numerator):
792+
shape = 3
793+
scale = 5
794+
x_rv = numerator / at.random.gamma(shape, scale)
795+
x_rv.name = "x"
796+
797+
x_vv = x_rv.clone()
798+
x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}))
799+
800+
x_test_val = 1.5
801+
assert np.isclose(
802+
x_logp_fn(x_test_val),
803+
sp.stats.invgamma(shape, scale=scale * numerator).logpdf(x_test_val),
804+
)
805+
806+
787807
def test_scan_transform():
788808
"""Test that Scan valued variables can be transformed"""
789809

0 commit comments

Comments
 (0)