
Commit 585962d

Remove unused naive_bcast_rv_lift rewrite
1 parent 04fe3cd commit 585962d

File tree

2 files changed: +1, -102 lines changed


pymc/logprob/tensor.py

Lines changed: 1 addition & 68 deletions
```diff
@@ -37,19 +37,14 @@
 
 from pathlib import Path
 
-import pytensor
-
 from pytensor import tensor as pt
-from pytensor.graph.fg import FunctionGraph
-from pytensor.graph.op import compute_test_value
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.tensor import TensorVariable
-from pytensor.tensor.basic import Alloc, Join, MakeVector
+from pytensor.tensor.basic import Join, MakeVector
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.rewriting import (
     local_dimshuffle_rv_lift,
-    local_rv_size_lift,
 )
 
 from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper
@@ -62,68 +57,6 @@
 from pymc.pytensorf import constant_fold
 
 
-@node_rewriter([Alloc])
-def naive_bcast_rv_lift(fgraph: FunctionGraph, node):
-    """Lift an ``Alloc`` through a ``RandomVariable`` ``Op``.
-
-    XXX: This implementation simply broadcasts the ``RandomVariable``'s
-    parameters, which won't always work (e.g. multivariate distributions).
-
-    TODO: Instead, it should use ``RandomVariable.ndim_supp``--and the like--to
-    determine which dimensions of each parameter need to be broadcasted.
-    Also, this doesn't need to remove ``size`` to perform the lifting, like it
-    currently does.
-    """
-
-    if not (
-        isinstance(node.op, Alloc)
-        and node.inputs[0].owner
-        and isinstance(node.inputs[0].owner.op, RandomVariable)
-    ):
-        return None  # pragma: no cover
-
-    bcast_shape = node.inputs[1:]
-
-    rv_var = node.inputs[0]
-    rv_node = rv_var.owner
-
-    if hasattr(fgraph, "dont_touch_vars") and rv_var in fgraph.dont_touch_vars:
-        return None  # pragma: no cover
-
-    # Do not replace RV if it is associated with a value variable
-    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
-    if rv_map_feature is not None and rv_var in rv_map_feature.rv_values:
-        return None
-
-    if not bcast_shape:
-        # The `Alloc` is broadcasting a scalar to a scalar (i.e. doing nothing)
-        assert rv_var.ndim == 0
-        return [rv_var]
-
-    size_lift_res = local_rv_size_lift.transform(fgraph, rv_node)
-    if size_lift_res is None:
-        lifted_node = rv_node
-    else:
-        _, lifted_rv = size_lift_res
-        lifted_node = lifted_rv.owner
-
-    rng, size, *dist_params = lifted_node.inputs
-
-    new_dist_params = [
-        pt.broadcast_to(
-            param,
-            pt.broadcast_shape(tuple(param.shape), tuple(bcast_shape), arrays_are_shapes=True),
-        )
-        for param in dist_params
-    ]
-    bcasted_node = lifted_node.op.make_node(rng, size, *new_dist_params)
-
-    if pytensor.config.compute_test_value != "off":
-        compute_test_value(bcasted_node)
-
-    return [bcasted_node.outputs[1]]
-
-
 class MeasurableMakeVector(MeasurableOp, MakeVector):
     """A placeholder used to specify a log-likelihood for a cumsum sub-graph."""
```
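For context, the deleted rewrite replaced an `Alloc` of a `RandomVariable` with a new `RandomVariable` whose parameters were broadcast to the `Alloc` shape. Below is a minimal hand-written sketch of that transformation on a toy graph; the shapes and parameter values are illustrative, not taken from the repository.

```python
import pytensor.tensor as pt

# Original graph: a single scalar normal draw tiled to shape (3, 2) by `Alloc`.
x_rv = pt.random.normal(0.0, 1.0)
tiled = pt.alloc(x_rv, 3, 2)

# Roughly what the removed rewrite built instead: the `Alloc` lifted above
# the RV by broadcasting the distribution parameters to the target shape.
lifted_rv = pt.random.normal(
    pt.broadcast_to(pt.as_tensor(0.0), (3, 2)),
    pt.broadcast_to(pt.as_tensor(1.0), (3, 2)),
)

# Caveat flagged in the removed docstring: broadcasting every parameter
# dimension is wrong for multivariate RVs, whose core dimensions
# (`RandomVariable.ndim_supp`) must be left alone. Note also that the two
# graphs differ as samplers: `tiled` repeats one draw, while `lifted_rv`
# yields (3, 2) independent draws.
```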

tests/logprob/test_tensor.py

Lines changed: 0 additions & 34 deletions
```diff
@@ -40,47 +40,13 @@
 
 from pytensor import tensor as pt
 from pytensor.graph import RewriteDatabaseQuery
-from pytensor.graph.rewriting.basic import in2out
-from pytensor.graph.rewriting.utils import rewrite_graph
-from pytensor.tensor.basic import Alloc
 from scipy import stats as st
 
 from pymc.logprob.basic import conditional_logp, logp
 from pymc.logprob.rewriting import logprob_rewrites_db
-from pymc.logprob.tensor import naive_bcast_rv_lift
 from pymc.testing import assert_no_rvs
 
 
-def test_naive_bcast_rv_lift():
-    r"""Make sure `naive_bcast_rv_lift` can handle useless scalar `Alloc`\s."""
-    X_rv = pt.random.normal()
-    Z_at = Alloc()(X_rv, *())
-
-    # Make sure we're testing what we intend to test
-    assert isinstance(Z_at.owner.op, Alloc)
-
-    res = rewrite_graph(Z_at, custom_rewrite=in2out(naive_bcast_rv_lift), clone=False)
-    assert res is X_rv
-
-
-def test_naive_bcast_rv_lift_valued_var():
-    r"""Check that `naive_bcast_rv_lift` won't touch valued variables"""
-
-    x_rv = pt.random.normal(name="x")
-    broadcasted_x_rv = pt.broadcast_to(x_rv, (2,))
-
-    y_rv = pt.random.normal(broadcasted_x_rv, name="y")
-
-    x_vv = x_rv.clone()
-    y_vv = y_rv.clone()
-    logp_map = conditional_logp({x_rv: x_vv, y_rv: y_vv})
-    assert x_vv in logp_map
-    assert y_vv in logp_map
-    assert len(logp_map) == 2
-    assert np.allclose(logp_map[x_vv].eval({x_vv: 0}), st.norm(0).logpdf(0))
-    assert np.allclose(logp_map[y_vv].eval({x_vv: 0, y_vv: [0, 0]}), st.norm(0).logpdf([0, 0]))
-
-
 @pytest.mark.xfail(RuntimeError, reason="logprob for broadcasted RVs not implemented")
 def test_bcast_rv_logp():
     """Test that derived logp for broadcasted RV is correct"""
```

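The first deleted test also documents the general pattern for applying a single node rewriter outside the registered rewrite databases: wrap it in `in2out` and pass it to `rewrite_graph` as `custom_rewrite`. Here is a sketch of the same pattern using `local_dimshuffle_rv_lift`, a rewrite that survives this commit; the toy graph is illustrative, and whether the lift actually fires depends on pytensor's applicability checks.

```python
from pytensor import tensor as pt
from pytensor.graph.rewriting.basic import in2out
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor.random.rewriting import local_dimshuffle_rv_lift

# A transpose (a DimShuffle) sitting on top of a RandomVariable...
x_rv = pt.random.normal(size=(3, 2))
transposed = x_rv.T

# ...can be lifted so the RV itself produces draws in the transposed layout,
# leaving no DimShuffle between the RV and its consumers.
res = rewrite_graph(
    transposed,
    custom_rewrite=in2out(local_dimshuffle_rv_lift),
    clone=False,
)
```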