
Commit 042c9f3

Derive probability for transforms with implicit broadcasting
A warning is issued, as this graph is unlikely to be desired by most users.
1 parent 9f8ea52 commit 042c9f3
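
For orientation, here is a minimal sketch of the behavior this commit enables, adapted from the updated test further below. It assumes the `logp` helper exercised by the test is reachable via the public `pymc.logp` entry point:

import pytensor
import pytensor.tensor as pt
import pymc as pm

loc = pt.vector("loc")
base_rv = pt.random.normal(0, 1, size=1, name="base_rv")
y_rv = loc + base_rv  # base_rv is implicitly broadcast to loc's shape

y_vv = y_rv.clone()
# Previously this derivation was not supported; now it succeeds but emits a
# UserWarning about implicit broadcasting.
logprob = pm.logp(y_rv, y_vv)
logprob_fn = pytensor.function([loc, y_vv], logprob)

# All entries share a single underlying draw, so only values with a constant
# offset from `loc` have non-zero density.
print(logprob_fn([1, 1, 1, 1], [0, 0, 0, 0]))  # finite logp
print(logprob_fn([1, 2, 3, 4], [0, 0, 0, 0]))  # -inf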

2 files changed (+73, −28 lines): pymc/logprob/transforms.py, tests/logprob/test_transforms.py

pymc/logprob/transforms.py
Lines changed: 47 additions & 22 deletions

@@ -35,6 +35,7 @@
 # SOFTWARE.

 import abc
+import warnings

 from copy import copy
 from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
@@ -111,6 +112,7 @@
     cleanup_ir_rewrites_db,
     measurable_ir_rewrites_db,
 )
+from pymc.logprob.shape import measurable_broadcast
 from pymc.logprob.utils import check_potential_measurability


@@ -564,10 +566,13 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li

     scalar_op = node.op.scalar_op
     measurable_input_idx = 0
+    measurable_input_broadcast = (
+        measurable_input.type.broadcastable != node.default_output().type.broadcastable
+    )
     transform_inputs: Tuple[TensorVariable, ...] = (measurable_input,)
     transform: RVTransform

-    transform_dict = {
+    unary_transforms_dict = {
         Exp: ExpTransform(),
         Log: LogTransform(),
         Abs: AbsTransform(),
@@ -581,29 +586,49 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
         Erfc: ErfcTransform(),
         Erfcx: ErfcxTransform(),
     }
-    transform = transform_dict.get(type(scalar_op), None)
-    if isinstance(scalar_op, Pow):
-        # We only allow for the base to be measurable
-        if measurable_input_idx != 0:
-            return None
-        try:
-            (power,) = other_inputs
-            power = pt.get_underlying_scalar_constant_value(power).item()
-        # Power needs to be a constant
-        except NotScalarConstantError:
+    transform = unary_transforms_dict.get(type(scalar_op), None)
+    if transform is None:
+        if isinstance(scalar_op, Pow):
+            # We only allow for the base to be measurable
+            if measurable_input_idx != 0:
+                return None
+            try:
+                (power,) = other_inputs
+                base_power = pt.get_underlying_scalar_constant_value(power).item()
+            # Power needs to be a constant
+            except NotScalarConstantError:
+                return None
+            transform_inputs = (measurable_input, power)
+            transform = PowerTransform(power=base_power)
+        elif isinstance(scalar_op, Add):
+            transform_inputs = (measurable_input, pt.add(*other_inputs))
+            transform = LocTransform(
+                transform_args_fn=lambda *inputs: inputs[-1],
+            )
+        elif isinstance(scalar_op, Mul):
+            transform_inputs = (measurable_input, pt.mul(*other_inputs))
+            transform = ScaleTransform(
+                transform_args_fn=lambda *inputs: inputs[-1],
+            )
+        else:
+            raise TypeError(
+                f"Scalar Op not supported: {scalar_op}. Rewrite should not have been triggered"
+            )  # pragma: no cover
+
+    if measurable_input_broadcast:
+        # This rewrite logic only supports broadcasting for transforms with two inputs, where the first is measurable.
+        # This covers all current cases, update if other cases are supported in the future.
+        if len(transform_inputs) != 2 or measurable_input_idx != 0:
             return None
-        transform_inputs = (measurable_input, power)
-        transform = PowerTransform(power=power)
-    elif isinstance(scalar_op, Add):
-        transform_inputs = (measurable_input, pt.add(*other_inputs))
-        transform = LocTransform(
-            transform_args_fn=lambda *inputs: inputs[-1],
-        )
-    elif transform is None:
-        transform_inputs = (measurable_input, pt.mul(*other_inputs))
-        transform = ScaleTransform(
-            transform_args_fn=lambda *inputs: inputs[-1],
+        warnings.warn(
+            "MeasurableTransform with implicit broadcasting detected. This corresponds to a potentially degenerate probability graph.\n"
+            "If you did not intend this, make sure the base measurable variable is created with all the dimensions from the start."
+            "Otherwise, an explicit `broadcast_to` operation can be used to silence this warning.\n",
+            UserWarning,
        )
+        measurable_input, other_input = transform_inputs
+        measurable_input = measurable_broadcast(measurable_input, other_input.shape)
+        transform_inputs = (measurable_input, other_input)

     transform_op = MeasurableTransform(
         scalar_op=scalar_op,
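
The new `measurable_input_broadcast` flag is just a static comparison of broadcastable patterns. A small illustrative sketch, assuming the same graph shape as the updated test, of what that comparison sees:

import pytensor.tensor as pt

loc = pt.vector("loc")                    # static shape (None,) -> broadcastable (False,)
base_rv = pt.random.normal(0, 1, size=1)  # static shape (1,)    -> broadcastable (True,)
out = loc + base_rv                       # Elemwise Add output  -> broadcastable (False,)

print(base_rv.type.broadcastable)  # (True,)
print(out.type.broadcastable)      # (False,)
# The patterns differ, so find_measurable_transforms takes the new broadcast branch:
# it warns and wraps the measurable input with measurable_broadcast before building
# the MeasurableTransform.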

tests/logprob/test_transforms.py
Lines changed: 26 additions & 6 deletions

@@ -807,16 +807,36 @@ def test_discrete_rv_multinary_transform_fails():
         conditional_logp({y_rv: y_rv.clone()})


-@pytest.mark.xfail(reason="Check not implemented yet")
-def test_invalid_broadcasted_transform_rv_fails():
+@pytest.mark.filterwarnings("error")  # Fail if unexpected warning is issued
+@pytest.mark.parametrize("implicit_broadcast", (True, False))
+def test_broadcasted_transform_rv(implicit_broadcast):
     loc = pt.vector("loc")
-    y_rv = loc + pt.random.normal(0, 1, size=1, name="base_rv")
+    base_rv = pt.random.normal(0, 1, size=1, name="base_rv")
+    if implicit_broadcast:
+        y_rv = loc + base_rv
+    else:
+        y_rv = loc + pt.broadcast_to(base_rv, shape=loc.shape)
     y_rv.name = "y"
     y_vv = y_rv.clone()

-    # This logp derivation should fail or count only once the values that are broadcasted
-    logprob = logp(y_rv, y_vv)
-    assert logprob.eval({y_vv: [0, 0, 0, 0], loc: [0, 0, 0, 0]}).shape == ()
+    if implicit_broadcast:
+        with pytest.warns(UserWarning, match="implicit broadcasting detected"):
+            logprob = logp(y_rv, y_vv)
+    else:
+        logprob = logp(y_rv, y_vv)
+    logprob_fn = pytensor.function([loc, y_vv], logprob)
+
+    # All values must have the same offset from `loc`
+    np.testing.assert_allclose(
+        logprob_fn([1, 1, 1, 1], [0, 0, 0, 0]), sp.stats.norm.logpdf([0], loc=1)
+    )
+    np.testing.assert_allclose(
+        logprob_fn([1, 2, 3, 4], [0, 1, 2, 3]), sp.stats.norm.logpdf([0], loc=1)
+    )
+
+    # Otherwise probability is 0
+    np.testing.assert_array_equal(logprob_fn([1, 1, 1, 1], [0, 0, 0, 1]), [-np.inf])
+    np.testing.assert_array_equal(logprob_fn([1, 2, 3, 4], [0, 0, 0, 0]), [-np.inf])


 @pytest.mark.parametrize("numerator", (1.0, 2.0))