Commit c8f3704 (parent: 48e56c3)

Remove non-generative incsubtensor logp inference

This is now properly done by PartiallyObservedRV
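For context, missing-value imputation was the main use of the removed rewrite, and it is now handled at the model level. Below is a minimal sketch of that path, assuming the current automatic-imputation behavior; the "y_unobserved" naming is an assumption, not something verified against this commit:

import numpy as np
import pymc as pm

# NaNs in `observed` trigger automatic imputation: PyMC splits the
# variable into observed and unobserved parts (a PartiallyObservedRV),
# so no IncSubtensor logp rewrite is needed.
with pm.Model() as model:
    y = pm.Normal("y", mu=0.0, sigma=1.0, observed=np.array([0.5, np.nan, 1.2]))

# The missing entry becomes a free value variable (typically named
# "y_unobserved") that is sampled like any other parameter.
print(model.value_vars)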

File tree: 3 files changed (+1, -159 lines)

pymc/logprob/rewriting.py

Lines changed: 1 addition & 48 deletions
@@ -38,8 +38,6 @@
 from collections import deque
 from collections.abc import Collection, Sequence

-import pytensor.tensor as pt
-
 from pytensor import config
 from pytensor.compile.mode import optdb
 from pytensor.graph.basic import (
@@ -84,7 +82,7 @@
 from pytensor.tensor.variable import TensorVariable

 from pymc.logprob.abstract import MeasurableVariable
-from pymc.logprob.utils import DiracDelta, indices_from_subtensor
+from pymc.logprob.utils import DiracDelta

 inc_subtensor_ops = (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)
 subtensor_ops = (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)
@@ -314,50 +312,6 @@ def remove_DiracDelta(fgraph, node):
     return [dd_val]


-@node_rewriter(inc_subtensor_ops)
-def incsubtensor_rv_replace(fgraph, node):
-    r"""Replace `*IncSubtensor*` `Op`\s and their value variables for log-probability calculations.
-
-    This is used to derive the log-probability graph for ``Y[idx] = data``, where
-    ``Y`` is a `RandomVariable`, ``idx`` indices, and ``data`` some arbitrary data.
-
-    To compute the log-probability of a statement like ``Y[idx] = data``, we must
-    first realize that our objective is equivalent to computing ``logprob(Y, z)``,
-    where ``z = pt.set_subtensor(y[idx], data)`` and ``y`` is the value variable
-    for ``Y``.
-
-    In other words, the log-probability for an `*IncSubtensor*` is the log-probability
-    of the underlying `RandomVariable` evaluated at ``data`` for the indices
-    given by ``idx`` and at the value variable for ``~idx``.
-
-    This provides a means of specifying "missing data", for instance.
-    """
-    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
-
-    if rv_map_feature is None:
-        return None  # pragma: no cover
-
-    rv_var = node.outputs[0]
-    if rv_var not in rv_map_feature.rv_values:
-        return None  # pragma: no cover
-
-    base_rv_var = node.inputs[0]
-
-    if not rv_map_feature.request_measurable([base_rv_var]):
-        return None
-
-    data = node.inputs[1]
-    idx = indices_from_subtensor(getattr(node.op, "idx_list", None), node.inputs[2:])
-
-    # Create a new value variable with the indices `idx` set to `data`
-    value_var = rv_map_feature.rv_values[rv_var]
-    new_value_var = pt.set_subtensor(value_var[idx], data)
-    rv_map_feature.update_rv_maps(rv_var, new_value_var, base_rv_var)
-
-    # Return the `RandomVariable` being indexed
-    return [base_rv_var]
-
-
 logprob_rewrites_db = SequenceDB()
 logprob_rewrites_db.name = "logprob_rewrites_db"
 # Introduce sigmoid. We do it before canonicalization so that useless mul are removed next
@@ -378,7 +332,6 @@ def incsubtensor_rv_replace(fgraph, node):
 # (or eventually) the graph outputs. Often this is done by lifting other `Op`s
 # "up" through the random/measurable variables and into their inputs.
 measurable_ir_rewrites_db.register("subtensor_lift", local_subtensor_rv_lift, "basic")
-measurable_ir_rewrites_db.register("incsubtensor_lift", incsubtensor_rv_replace, "basic")

 # These rewrites are used to introduce specalized operations with better logprob graphs
 specialization_ir_rewrites_db = EquilibriumDB()
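The removed rewrite implemented the identity spelled out in its docstring: the log-probability of ``Y[idx] = data`` is the log-probability of the base ``Y`` evaluated at the value variable with ``idx`` overwritten by ``data``. A minimal numeric sketch of that identity, in plain NumPy/SciPy rather than the PyMC API:

import numpy as np
import scipy.stats as st

mu, sigma = np.arange(5.0), 0.5
y = np.random.default_rng(0).normal(mu, sigma)  # value variable for Y
idx = np.array([0, 2])
data = np.array([0.1, 2.3])

# z = set_subtensor(y[idx], data): data at idx, y elsewhere
z = y.copy()
z[idx] = data

# logprob of ``Y[idx] = data`` at y == logprob of Y at z
logp = st.norm.logpdf(z, mu, sigma)
print(logp.sum())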

tests/logprob/test_basic.py

Lines changed: 0 additions & 67 deletions
@@ -44,11 +44,6 @@

 from pytensor.graph.basic import ancestors, equal_computations
 from pytensor.tensor.random.op import RandomVariable
-from pytensor.tensor.subtensor import (
-    AdvancedIncSubtensor,
-    AdvancedIncSubtensor1,
-    IncSubtensor,
-)

 import pymc as pm

@@ -173,20 +168,6 @@ def test_factorized_joint_logprob_diff_dims():
     assert exp_logp_val == pytest.approx(logp_val)


-def test_incsubtensor_original_values_output_dict():
-    """
-    Test that the original un-incsubtensor value variable appears an the key of
-    the logprob factor
-    """
-
-    base_rv = pt.random.normal(0, 1, size=2)
-    rv = pt.set_subtensor(base_rv[0], 5)
-    vv = rv.clone()
-
-    logp_dict = conditional_logp({rv: vv})
-    assert vv in logp_dict
-
-
 def test_persist_inputs():
     """Make sure we don't unnecessarily clone variables."""
     x = pt.scalar("x")
@@ -276,54 +257,6 @@ def test_joint_logp_basic():
     assert a_value_var in res_ancestors


-@pytest.mark.parametrize(
-    "indices, size",
-    [
-        (slice(0, 2), 5),
-        (np.r_[True, True, False, False, True], 5),
-        (np.r_[0, 1, 4], 5),
-        ((np.array([0, 1, 4]), np.array([0, 1, 4])), (5, 5)),
-    ],
-)
-def test_joint_logp_incsubtensor(indices, size):
-    """Make sure we can compute a log-likelihood for ``Y[idx] = data`` where ``Y`` is univariate."""
-
-    mu = pm.floatX(np.power(10, np.arange(np.prod(size)))).reshape(size)
-    data = mu[indices]
-    sigma = 0.001
-    rng = np.random.default_rng(232)
-    a_val = rng.normal(mu, sigma, size=size).astype(pytensor.config.floatX)
-
-    rng = pytensor.shared(rng, borrow=False)
-    a = pm.Normal.dist(mu, sigma, size=size, rng=rng)
-    a_value_var = a.type()
-    a.name = "a"
-
-    a_idx = pt.set_subtensor(a[indices], data)
-
-    assert isinstance(a_idx.owner.op, IncSubtensor | AdvancedIncSubtensor | AdvancedIncSubtensor1)
-
-    a_idx_value_var = a_idx.type()
-    a_idx_value_var.name = "a_idx_value"
-
-    a_idx_logp = transformed_conditional_logp(
-        (a_idx,),
-        rvs_to_values={a_idx: a_value_var},
-        rvs_to_transforms={},
-    )
-
-    logp_vals = a_idx_logp[0].eval({a_value_var: a_val})
-
-    # The indices that were set should all have the same log-likelihood values,
-    # because the values they were set to correspond to the unique means along
-    # that dimension. This helps us confirm that the log-likelihood is
-    # associating the assigned values with their correct parameters.
-    a_val_idx = a_val.copy()
-    a_val_idx[indices] = data
-    exp_obs_logps = sp.norm.logpdf(a_val_idx, mu, sigma)
-    np.testing.assert_almost_equal(logp_vals, exp_obs_logps)
-
-
 def test_model_unchanged_logprob_access():
     # Issue #5007
     with pm.Model() as model:

tests/logprob/test_rewriting.py

Lines changed: 0 additions & 44 deletions
@@ -34,19 +34,13 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.

-import numpy as np
 import pytensor.tensor as pt
-import pytest
-import scipy.stats.distributions as sp

 from pytensor.graph import ancestors
 from pytensor.graph.rewriting.basic import in2out
 from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.subtensor import (
-    AdvancedIncSubtensor,
-    AdvancedIncSubtensor1,
-    IncSubtensor,
     Subtensor,
 )

@@ -105,41 +99,3 @@ def test_local_remove_TransformedVariable():
     [p_logp] = conditional_logp({p_rv: p_vv}, extra_rewrites=tr).values()

     assert not any(isinstance(v.owner.op, TransformedValue) for v in ancestors([p_logp]) if v.owner)
-
-
-@pytest.mark.parametrize(
-    "indices, size",
-    [
-        (slice(0, 2), 5),
-        (np.r_[True, True, False, False, True], 5),
-        (np.r_[0, 1, 4], 5),
-        ((np.array([0, 1, 4]), np.array([0, 1, 4])), (5, 5)),
-    ],
-)
-def test_joint_logprob_incsubtensor(indices, size):
-    """Make sure we can compute a joint log-probability for ``Y[idx] = data`` where ``Y`` is univariate."""
-
-    rng = np.random.RandomState(232)
-    mu = np.power(10, np.arange(np.prod(size))).reshape(size)
-    sigma = 0.001
-    data = rng.normal(mu[indices], 1.0)
-    y_val = rng.normal(mu, sigma, size=size)
-
-    Y_base_rv = pt.random.normal(mu, sigma, size=size)
-    Y_rv = pt.set_subtensor(Y_base_rv[indices], data)
-    Y_rv.name = "Y"
-    y_value_var = Y_rv.clone()
-    y_value_var.name = "y"
-
-    assert isinstance(Y_rv.owner.op, IncSubtensor | AdvancedIncSubtensor | AdvancedIncSubtensor1)
-
-    Y_rv_logp = conditional_logp({Y_rv: y_value_var})
-    Y_rv_logp_combined = pt.add(*Y_rv_logp.values())
-
-    obs_logps = Y_rv_logp_combined.eval({y_value_var: y_val})
-
-    y_val_idx = y_val.copy()
-    y_val_idx[indices] = data
-    exp_obs_logps = sp.norm.logpdf(y_val_idx, mu, sigma)
-
-    np.testing.assert_almost_equal(obs_logps, exp_obs_logps)
