Skip to content

Commit 419af06

Browse files
LukeLBLuke LB
and
Luke LB
authored
Support logp derivation of power(base, rv) (#6962)
Co-authored-by: Ricardo Vieira <[email protected]> Co-authored-by: Luke LB <[email protected]>
1 parent c53277b commit 419af06

File tree

2 files changed

+79
-2
lines changed

2 files changed

+79
-2
lines changed

pymc/logprob/transforms.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@
127127
cleanup_ir_rewrites_db,
128128
measurable_ir_rewrites_db,
129129
)
130-
from pymc.logprob.utils import check_potential_measurability
130+
from pymc.logprob.utils import CheckParameterValue, check_potential_measurability
131131

132132

133133
class TransformedVariable(Op):
@@ -617,6 +617,21 @@ def measurable_special_exp_to_exp(fgraph, node):
617617
return [1 / (1 + pt.exp(-inp))]
618618

619619

620+
@node_rewriter([pow])
621+
def measurable_power_exponent_to_exp(fgraph, node):
622+
"""Convert power(base, rv) of `MeasurableVariable`s to exp(log(base) * rv) form."""
623+
base, inp_exponent = node.inputs
624+
625+
# When the base is measurable we have `power(rv, exponent)`, which should be handled by `PowerTransform` and needs no further rewrite.
626+
# Here we change only the cases where exponent is measurable `power(base, rv)` which is not supported by the `PowerTransform`
627+
if check_potential_measurability([base], fgraph.preserve_rv_mappings.rv_values.keys()):
628+
return None
629+
630+
base = CheckParameterValue("base >= 0")(base, pt.all(pt.ge(base, 0.0)))
631+
632+
return [pt.exp(pt.log(base) * inp_exponent)]
633+
634+
620635
@node_rewriter(
621636
[
622637
exp,
@@ -693,7 +708,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
693708
try:
694709
(power,) = other_inputs
695710
power = pt.get_underlying_scalar_constant_value(power).item()
696-
# Power needs to be a constant
711+
# Power needs to be a constant, if not then proceed to the other case power(base, rv)
697712
except NotScalarConstantError:
698713
return None
699714
transform_inputs = (measurable_input, power)
@@ -769,6 +784,12 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
769784
"transform",
770785
)
771786

787+
measurable_ir_rewrites_db.register(
788+
"measurable_power_expotent_to_exp",
789+
measurable_power_exponent_to_exp,
790+
"basic",
791+
"transform",
792+
)
772793

773794
measurable_ir_rewrites_db.register(
774795
"find_measurable_transforms",

tests/logprob/test_transforms.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
TransformValuesMapping,
7373
TransformValuesRewrite,
7474
)
75+
from pymc.logprob.utils import ParameterValueError
7576
from pymc.testing import Rplusbig, Vector, assert_no_rvs
7677
from tests.distributions.test_transform import check_jacobian_det
7778

@@ -1159,6 +1160,61 @@ def test_special_log_exp_transforms(transform):
11591160
assert equal_computations([logp_test], [logp_ref])
11601161

11611162

1163+
def test_measurable_power_exponent_with_constant_base():
1164+
# test power(2, rv) = exp2(rv)
1165+
# test negative base fails
1166+
x_rv_pow = pt.pow(2, pt.random.normal())
1167+
x_rv_exp2 = pt.exp2(pt.random.normal())
1168+
1169+
x_vv_pow = x_rv_pow.clone()
1170+
x_vv_exp2 = x_rv_exp2.clone()
1171+
1172+
x_logp_fn_pow = pytensor.function([x_vv_pow], pt.sum(logp(x_rv_pow, x_vv_pow)))
1173+
x_logp_fn_exp2 = pytensor.function([x_vv_exp2], pt.sum(logp(x_rv_exp2, x_vv_exp2)))
1174+
1175+
np.testing.assert_allclose(x_logp_fn_pow(0.1), x_logp_fn_exp2(0.1))
1176+
1177+
with pytest.raises(ParameterValueError, match="base >= 0"):
1178+
x_rv_neg = pt.pow(-2, pt.random.normal())
1179+
x_vv_neg = x_rv_neg.clone()
1180+
logp(x_rv_neg, x_vv_neg)
1181+
1182+
1183+
def test_measurable_power_exponent_with_variable_base():
1184+
# test with RV when logp(<0) we raise error
1185+
base_rv = pt.random.normal([2])
1186+
x_raw_rv = pt.random.normal()
1187+
x_rv = pt.power(base_rv, x_raw_rv)
1188+
1189+
x_rv.name = "x"
1190+
base_rv.name = "base"
1191+
base_vv = base_rv.clone()
1192+
x_vv = x_rv.clone()
1193+
1194+
res = conditional_logp({base_rv: base_vv, x_rv: x_vv})
1195+
x_logp = res[x_vv]
1196+
logp_vals_fn = pytensor.function([base_vv, x_vv], x_logp)
1197+
1198+
with pytest.raises(ParameterValueError, match="base >= 0"):
1199+
logp_vals_fn(np.array([-2]), np.array([2]))
1200+
1201+
1202+
def test_base_exponent_non_measurable():
1203+
# test dual sources of measuravility fails
1204+
base_rv = pt.random.normal([2])
1205+
x_raw_rv = pt.random.normal()
1206+
x_rv = pt.power(base_rv, x_raw_rv)
1207+
x_rv.name = "x"
1208+
1209+
x_vv = x_rv.clone()
1210+
1211+
with pytest.raises(
1212+
RuntimeError,
1213+
match="The logprob terms of the following value variables could not be derived: {x}",
1214+
):
1215+
conditional_logp({x_rv: x_vv})
1216+
1217+
11621218
@pytest.mark.parametrize("shift", [1.5, np.array([-0.5, 1, 0.3])])
11631219
@pytest.mark.parametrize("scale", [2.0, np.array([1.5, 3.3, 1.0])])
11641220
def test_multivariate_rv_transform(shift, scale):

0 commit comments

Comments
 (0)