Skip to content

Commit 4c54b7d

Browse files
committed
Add rewrites for measurable negation and subtraction
1 parent 95c589a commit 4c54b7d

File tree

2 files changed

+84
-7
lines changed

2 files changed

+84
-7
lines changed

pymc/logprob/transforms.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from pytensor.scalar import Add, Exp, Log, Mul, Reciprocal
5151
from pytensor.scan.op import Scan
5252
from pytensor.tensor.exceptions import NotScalarConstantError
53-
from pytensor.tensor.math import add, exp, log, mul, reciprocal, true_div
53+
from pytensor.tensor.math import add, exp, log, mul, neg, reciprocal, sub, true_div
5454
from pytensor.tensor.rewriting.basic import (
5555
register_specialize,
5656
register_stabilize,
@@ -384,6 +384,46 @@ def measurable_div_to_reciprocal_product(fgraph, node):
384384
return [at.mul(numerator, at.reciprocal(denominator))]
385385

386386

387+
@node_rewriter([neg])
388+
def measurable_neg_to_product(fgraph, node):
389+
"""Convert negation of `MeasurableVariable`s to product with `-1`."""
390+
391+
inp = node.inputs[0]
392+
if not (inp.owner and isinstance(inp.owner.op, MeasurableVariable)):
393+
return None
394+
395+
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
396+
if rv_map_feature is None:
397+
return None # pragma: no cover
398+
399+
# Only apply this rewrite if the variable is unvalued
400+
if inp in rv_map_feature.rv_values:
401+
return None # pragma: no cover
402+
403+
return [at.mul(inp, -1.0)]
404+
405+
406+
@node_rewriter([sub])
407+
def measurable_sub_to_neg(fgraph, node):
408+
"""Convert subtraction involving `MeasurableVariable`s to addition with neg"""
409+
measurable_vars = [
410+
var for var in node.inputs if (var.owner and isinstance(var.owner.op, MeasurableVariable))
411+
]
412+
if not measurable_vars:
413+
return None # pragma: no cover
414+
415+
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
416+
if rv_map_feature is None:
417+
return None # pragma: no cover
418+
419+
# Only apply this rewrite if there is one unvalued MeasurableVariable involved
420+
if all(measurable_var in rv_map_feature.rv_values for measurable_var in measurable_vars):
421+
return None # pragma: no cover
422+
423+
minuend, subtrahend = node.inputs
424+
return [at.add(minuend, at.neg(subtrahend))]
425+
426+
387427
@node_rewriter([exp, log, add, mul, reciprocal])
388428
def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
389429
"""Find measurable transformations from Elemwise operators."""
@@ -475,6 +515,19 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
475515
"transform",
476516
)
477517

518+
measurable_ir_rewrites_db.register(
519+
"measurable_neg_to_product",
520+
measurable_neg_to_product,
521+
"basic",
522+
"transform",
523+
)
524+
525+
measurable_ir_rewrites_db.register(
526+
"measurable_sub_to_neg",
527+
measurable_sub_to_neg,
528+
"basic",
529+
"transform",
530+
)
478531

479532
measurable_ir_rewrites_db.register(
480533
"find_measurable_transforms",

pymc/tests/logprob/test_transforms.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -664,17 +664,20 @@ def test_log_transform_rv():
664664

665665

666666
@pytest.mark.parametrize(
667-
"rv_size, loc_type",
667+
"rv_size, loc_type, addition",
668668
[
669-
(None, at.scalar),
670-
(2, at.vector),
671-
((2, 1), at.col),
669+
(None, at.scalar, True),
670+
(2, at.vector, False),
671+
((2, 1), at.col, True),
672672
],
673673
)
674-
def test_loc_transform_rv(rv_size, loc_type):
674+
def test_loc_transform_rv(rv_size, loc_type, addition):
675675

676676
loc = loc_type("loc")
677-
y_rv = loc + at.random.normal(0, 1, size=rv_size, name="base_rv")
677+
if addition:
678+
y_rv = loc + at.random.normal(0, 1, size=rv_size, name="base_rv")
679+
else:
680+
y_rv = at.random.normal(0, 1, size=rv_size, name="base_rv") - at.neg(loc)
678681
y_rv.name = "y"
679682
y_vv = y_rv.clone()
680683

@@ -804,6 +807,27 @@ def test_reciprocal_rv_transform(numerator):
804807
)
805808

806809

810+
def test_negated_rv_transform():
811+
x_rv = -at.random.halfnormal()
812+
x_rv.name = "x"
813+
814+
x_vv = x_rv.clone()
815+
x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}))
816+
817+
assert np.isclose(x_logp_fn(-1.5), sp.stats.halfnorm.logpdf(1.5))
818+
819+
820+
def test_subtracted_rv_transform():
821+
# Choose base RV that is assymetric around zero
822+
x_rv = 5.0 - at.random.normal(1.0)
823+
x_rv.name = "x"
824+
825+
x_vv = x_rv.clone()
826+
x_logp_fn = pytensor.function([x_vv], joint_logprob({x_rv: x_vv}))
827+
828+
assert np.isclose(x_logp_fn(7.3), sp.stats.norm.logpdf(5.0 - 7.3, 1.0))
829+
830+
807831
def test_scan_transform():
808832
"""Test that Scan valued variables can be transformed"""
809833

0 commit comments

Comments
 (0)