Commit e74b58d

Armavica authored and ricardoV94 committed
Move joint_logprob to test/logprob/utils.py
1 parent c5e4497 · commit e74b58d

12 files changed: +329 additions, -95 deletions

pymc/logprob/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@
 
 from pymc.logprob.abstract import logprob  # isort: split
 
-from pymc.logprob.joint_logprob import factorized_joint_logprob, joint_logprob
+from pymc.logprob.joint_logprob import factorized_joint_logprob
 
 # isort: off
 # Add rewrites to the DBs

pymc/logprob/joint_logprob.py

Lines changed: 0 additions & 32 deletions
@@ -39,8 +39,6 @@
 from collections import deque
 from typing import Dict, Optional, Union
 
-import pytensor.tensor as at
-
 from pytensor import config
 from pytensor.graph.basic import graph_inputs, io_toposort
 from pytensor.graph.op import compute_test_value
@@ -221,33 +219,3 @@ def factorized_joint_logprob(
         )
 
     return logprob_vars
-
-
-def joint_logprob(*args, sum: bool = True, **kwargs) -> Optional[TensorVariable]:
-    """Create a graph representing the joint log-probability/measure of a graph.
-
-    This function calls `factorized_joint_logprob` and returns the combined
-    log-probability factors as a single graph.
-
-    Parameters
-    ----------
-    sum: bool
-        If ``True`` each factor is collapsed to a scalar via ``sum`` before
-        being joined with the remaining factors. This may be necessary to
-        avoid incorrect broadcasting among independent factors.
-
-    """
-    logprob = factorized_joint_logprob(*args, **kwargs)
-    if not logprob:
-        return None
-    elif len(logprob) == 1:
-        logprob = tuple(logprob.values())[0]
-        if sum:
-            return at.sum(logprob)
-        else:
-            return logprob
-    else:
-        if sum:
-            return at.sum([at.sum(factor) for factor in logprob.values()])
-        else:
-            return at.add(*logprob.values())
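
Per the commit title, the function deleted above is re-created in pymc/tests/logprob/utils.py (one of the 12 changed files; that diff is not reproduced in this excerpt). A minimal sketch of the relocated helper, assuming it was moved verbatim and keeps the same pytensor imports (the `TensorVariable` import path is an assumption):

    from typing import Optional

    import pytensor.tensor as at

    # Import path assumed; not shown by the visible hunks.
    from pytensor.tensor.var import TensorVariable

    from pymc.logprob.joint_logprob import factorized_joint_logprob


    def joint_logprob(*args, sum: bool = True, **kwargs) -> Optional[TensorVariable]:
        """Combine the factors from `factorized_joint_logprob` into one graph.

        With ``sum=True`` each factor is collapsed to a scalar before being
        joined, which avoids incorrect broadcasting among independent factors.
        """
        logprob = factorized_joint_logprob(*args, **kwargs)
        if not logprob:
            return None
        elif len(logprob) == 1:
            logprob = tuple(logprob.values())[0]
            if sum:
                return at.sum(logprob)
            else:
                return logprob
        else:
            if sum:
                return at.sum([at.sum(factor) for factor in logprob.values()])
            else:
                return at.add(*logprob.values())

The test modules below then import it as `from pymc.tests.logprob.utils import joint_logprob`.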

pymc/tests/logprob/test_censoring.py

Lines changed: 2 additions & 1 deletion
@@ -41,9 +41,10 @@
 import scipy as sp
 import scipy.stats as st
 
-from pymc.logprob import factorized_joint_logprob, joint_logprob
+from pymc.logprob import factorized_joint_logprob
 from pymc.logprob.transforms import LogTransform, TransformValuesRewrite
 from pymc.tests.helpers import assert_no_rvs
+from pymc.tests.logprob.utils import joint_logprob
 
 
 @pytensor.config.change_flags(compute_test_value="raise")

pymc/tests/logprob/test_composite_logprob.py

Lines changed: 1 addition & 1 deletion
@@ -39,10 +39,10 @@
 import pytensor.tensor as at
 import scipy.stats as st
 
-from pymc.logprob import joint_logprob
 from pymc.logprob.censoring import MeasurableClip
 from pymc.logprob.rewriting import construct_ir_fgraph
 from pymc.tests.helpers import assert_no_rvs
+from pymc.tests.logprob.utils import joint_logprob
 
 
 def test_scalar_clipped_mixture():

pymc/tests/logprob/test_cumsum.py

Lines changed: 1 addition & 1 deletion
@@ -40,8 +40,8 @@
 import pytest
 import scipy.stats as st
 
-from pymc.logprob import joint_logprob
 from pymc.tests.helpers import assert_no_rvs
+from pymc.tests.logprob.utils import joint_logprob
 
 
 @pytest.mark.parametrize(

pymc/tests/logprob/test_joint_logprob.py

Lines changed: 231 additions & 46 deletions
@@ -43,19 +43,14 @@
 import scipy.stats.distributions as sp
 
 from pytensor.graph.basic import ancestors, equal_computations
-from pytensor.tensor.subtensor import (
-    AdvancedIncSubtensor,
-    AdvancedIncSubtensor1,
-    AdvancedSubtensor,
-    AdvancedSubtensor1,
-    IncSubtensor,
-    Subtensor,
-)
+from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
 
 from pymc.logprob.abstract import logprob
-from pymc.logprob.joint_logprob import factorized_joint_logprob, joint_logprob
+from pymc.logprob.joint_logprob import factorized_joint_logprob
 from pymc.logprob.utils import rvs_to_value_vars, walk_model
 from pymc.tests.helpers import assert_no_rvs
+from pymc.tests.logprob.utils import joint_logprob
 
 
 def test_joint_logprob_basic():
@@ -160,43 +155,6 @@ def test_joint_logprob_diff_dims():
     assert exp_logp_val == pytest.approx(logp_val)
 
 
-@pytest.mark.parametrize(
-    "indices, size",
-    [
-        (slice(0, 2), 5),
-        (np.r_[True, True, False, False, True], 5),
-        (np.r_[0, 1, 4], 5),
-        ((np.array([0, 1, 4]), np.array([0, 1, 4])), (5, 5)),
-    ],
-)
-def test_joint_logprob_incsubtensor(indices, size):
-    """Make sure we can compute a joint log-probability for ``Y[idx] = data`` where ``Y`` is univariate."""
-
-    rng = np.random.RandomState(232)
-    mu = np.power(10, np.arange(np.prod(size))).reshape(size)
-    sigma = 0.001
-    data = rng.normal(mu[indices], 1.0)
-    y_val = rng.normal(mu, sigma, size=size)
-
-    Y_base_rv = at.random.normal(mu, sigma, size=size)
-    Y_rv = at.set_subtensor(Y_base_rv[indices], data)
-    Y_rv.name = "Y"
-    y_value_var = Y_rv.clone()
-    y_value_var.name = "y"
-
-    assert isinstance(Y_rv.owner.op, (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1))
-
-    Y_rv_logp = joint_logprob({Y_rv: y_value_var}, sum=False)
-
-    obs_logps = Y_rv_logp.eval({y_value_var: y_val})
-
-    y_val_idx = y_val.copy()
-    y_val_idx[indices] = data
-    exp_obs_logps = sp.norm.logpdf(y_val_idx, mu, sigma)
-
-    np.testing.assert_almost_equal(obs_logps, exp_obs_logps)
-
-
 def test_incsubtensor_original_values_output_dict():
     """
     Test that the original un-incsubtensor value variable appears an the key of
@@ -308,3 +266,230 @@ def test_multiple_rvs_to_same_value_raises():
     msg = "More than one logprob factor was assigned to the value var x"
     with pytest.raises(ValueError, match=msg):
         joint_logprob({x_rv1: x, x_rv2: x})
+
+
+def test_get_scaling():
+
+    assert _get_scaling(None, (2, 3), 2).eval() == 1
+    # ndim >=1 & ndim<1
+    assert _get_scaling(45, (2, 3), 1).eval() == 22.5
+    assert _get_scaling(45, (2, 3), 0).eval() == 45
+
+    # list or tuple tests
+    # total_size contains other than Ellipsis, None and Int
+    with pytest.raises(TypeError, match="Unrecognized `total_size` type"):
+        _get_scaling([2, 4, 5, 9, 11.5], (2, 3), 2)
+    # check with Ellipsis
+    with pytest.raises(ValueError, match="Double Ellipsis in `total_size` is restricted"):
+        _get_scaling([1, 2, 5, Ellipsis, Ellipsis], (2, 3), 2)
+    with pytest.raises(
+        ValueError,
+        match="Length of `total_size` is too big, number of scalings is bigger that ndim",
+    ):
+        _get_scaling([1, 2, 5, Ellipsis], (2, 3), 2)
+
+    assert _get_scaling([Ellipsis], (2, 3), 2).eval() == 1
+
+    assert _get_scaling([4, 5, 9, Ellipsis, 32, 12], (2, 3, 2), 5).eval() == 960
+    assert _get_scaling([4, 5, 9, Ellipsis], (2, 3, 2), 5).eval() == 15
+    # total_size with no Ellipsis (end = [ ])
+    with pytest.raises(
+        ValueError,
+        match="Length of `total_size` is too big, number of scalings is bigger that ndim",
+    ):
+        _get_scaling([1, 2, 5], (2, 3), 2)
+
+    assert _get_scaling([], (2, 3), 2).eval() == 1
+    assert _get_scaling((), (2, 3), 2).eval() == 1
+    # total_size invalid type
+    with pytest.raises(
+        TypeError,
+        match="Unrecognized `total_size` type, expected int or list of ints, got {1, 2, 5}",
+    ):
+        _get_scaling({1, 2, 5}, (2, 3), 2)
+
+    # test with rvar from model graph
+    with pm.Model() as m2:
+        rv_var = pm.Uniform("a", 0.0, 1.0)
+    total_size = []
+    assert _get_scaling(total_size, shape=rv_var.shape, ndim=rv_var.ndim).eval() == 1.0
+
+
+def test_joint_logp_basic():
+    """Make sure we can compute a log-likelihood for a hierarchical model with transforms."""
+
+    with pm.Model() as m:
+        a = pm.Uniform("a", 0.0, 1.0)
+        c = pm.Normal("c")
+        b_l = c * a + 2.0
+        b = pm.Uniform("b", b_l, b_l + 1.0)
+
+    a_value_var = m.rvs_to_values[a]
+    assert m.rvs_to_transforms[a]
+
+    b_value_var = m.rvs_to_values[b]
+    assert m.rvs_to_transforms[b]
+
+    c_value_var = m.rvs_to_values[c]
+
+    (b_logp,) = joint_logp(
+        (b,),
+        rvs_to_values=m.rvs_to_values,
+        rvs_to_transforms=m.rvs_to_transforms,
+        rvs_to_total_sizes={},
+    )
+
+    # There shouldn't be any `RandomVariable`s in the resulting graph
+    assert_no_rvs(b_logp)
+
+    res_ancestors = list(walk_model((b_logp,)))
+    assert b_value_var in res_ancestors
+    assert c_value_var in res_ancestors
+    assert a_value_var in res_ancestors
+
+
+def test_joint_logp_subtensor():
+    """Make sure we can compute a log-likelihood for ``Y[I]`` where ``Y`` and ``I`` are random variables."""
+
+    size = 5
+
+    mu_base = pm.floatX(np.power(10, np.arange(np.prod(size)))).reshape(size)
+    mu = np.stack([mu_base, -mu_base])
+    sigma = 0.001
+    rng = pytensor.shared(np.random.RandomState(232), borrow=True)
+
+    A_rv = pm.Normal.dist(mu, sigma, rng=rng)
+    A_rv.name = "A"
+
+    p = 0.5
+
+    I_rv = pm.Bernoulli.dist(p, size=size, rng=rng)
+    I_rv.name = "I"
+
+    A_idx = A_rv[I_rv, at.ogrid[A_rv.shape[-1] :]]
+
+    assert isinstance(A_idx.owner.op, (Subtensor, AdvancedSubtensor, AdvancedSubtensor1))
+
+    A_idx_value_var = A_idx.type()
+    A_idx_value_var.name = "A_idx_value"
+
+    I_value_var = I_rv.type()
+    I_value_var.name = "I_value"
+
+    A_idx_logps = joint_logp(
+        (A_idx, I_rv),
+        rvs_to_values={A_idx: A_idx_value_var, I_rv: I_value_var},
+        rvs_to_transforms={},
+        rvs_to_total_sizes={},
+    )
+    A_idx_logp = at.add(*A_idx_logps)
+
+    logp_vals_fn = pytensor.function([A_idx_value_var, I_value_var], A_idx_logp)
+
+    # The compiled graph should not contain any `RandomVariables`
+    assert_no_rvs(logp_vals_fn.maker.fgraph.outputs[0])
+
+    decimals = select_by_precision(float64=6, float32=4)
+
+    for i in range(10):
+        bern_sp = sp.bernoulli(p)
+        I_value = bern_sp.rvs(size=size).astype(I_rv.dtype)
+
+        norm_sp = sp.norm(mu[I_value, np.ogrid[mu.shape[1] :]], sigma)
+        A_idx_value = norm_sp.rvs().astype(A_idx.dtype)
+
+        exp_obs_logps = norm_sp.logpdf(A_idx_value)
+        exp_obs_logps += bern_sp.logpmf(I_value)
+
+        logp_vals = logp_vals_fn(A_idx_value, I_value)
+
+        np.testing.assert_almost_equal(logp_vals, exp_obs_logps, decimal=decimals)
+
+
+def test_logp_helper():
+    value = at.vector("value")
+    x = pm.Normal.dist(0, 1)
+
+    x_logp = pm.logp(x, value)
+    np.testing.assert_almost_equal(x_logp.eval({value: [0, 1]}), sp.norm(0, 1).logpdf([0, 1]))
+
+    x_logp = pm.logp(x, [0, 1])
+    np.testing.assert_almost_equal(x_logp.eval(), sp.norm(0, 1).logpdf([0, 1]))
+
+
+def test_logp_helper_derived_rv():
+    assert np.isclose(
+        pm.logp(at.exp(pm.Normal.dist()), 5).eval(),
+        pm.logp(pm.LogNormal.dist(), 5).eval(),
+    )
+
+
+def test_logp_helper_exceptions():
+    with pytest.raises(TypeError, match="When RV is not a pure distribution"):
+        pm.logp(at.exp(pm.Normal.dist()), [1, 2])
+
+    with pytest.raises(NotImplementedError, match="PyMC could not infer logp of input variable"):
+        pm.logp(at.cos(pm.Normal.dist()), 1)
+
+
+def test_model_unchanged_logprob_access():
+    # Issue #5007
+    with pm.Model() as model:
+        a = pm.Normal("a")
+        c = pm.Uniform("c", lower=a - 1, upper=1)
+
+    original_inputs = set(pytensor.graph.graph_inputs([c]))
+    # Extract model.logp
+    model.logp()
+    new_inputs = set(pytensor.graph.graph_inputs([c]))
+    assert original_inputs == new_inputs
+
+
+def test_unexpected_rvs():
+    with pm.Model() as model:
+        x = pm.Normal("x")
+        y = pm.CustomDist("y", logp=lambda *args: x)
+
+    with pytest.raises(ValueError, match="^Random variables detected in the logp graph"):
+        model.logp()
+
+
+def test_hierarchical_logp():
+    """Make sure there are no random variables in a model's log-likelihood graph."""
+    with pm.Model() as m:
+        x = pm.Uniform("x", lower=0, upper=1)
+        y = pm.Uniform("y", lower=0, upper=x)
+
+    logp_ancestors = list(ancestors([m.logp()]))
+    ops = {a.owner.op for a in logp_ancestors if a.owner}
+    assert len(ops) > 0
+    assert not any(isinstance(o, RandomVariable) for o in ops)
+    assert m.rvs_to_values[x] in logp_ancestors
+    assert m.rvs_to_values[y] in logp_ancestors
+
+
+def test_hierarchical_obs_logp():
+    obs = np.array([0.5, 0.4, 5, 2])
+
+    with pm.Model() as model:
+        x = pm.Uniform("x", 0, 1, observed=obs)
+        pm.Uniform("y", x, 2, observed=obs)
+
+    logp_ancestors = list(ancestors([model.logp()]))
+    ops = {a.owner.op for a in logp_ancestors if a.owner}
+    assert len(ops) > 0
+    assert not any(isinstance(o, RandomVariable) for o in ops)
+
+
+def test_logprob_join_constant_shapes():
+    x = at.random.normal(size=5)
+    y = at.random.normal(size=3)
+    xy = at.join(x, y)
+    xy_vv = at.vector("xy_vv")
+
+    xy_logp = pm.logp(xy, xy_vv)
+    # This is what Aeppl does not do!
+    assert_no_rvs(xy_logp)
+
+    f = pytensor.function([xy_vv], xy_logp)
+    np.testing.assert_array_equal(f(np.zeros(8)), sp.norm.logpdf(np.zeros(8)))
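
The tests added above also reference `pm`, `joint_logp`, `_get_scaling`, and `select_by_precision`, whose imports are added in hunks of this file that the excerpt does not show. A hypothetical import block consistent with the test bodies (the module paths are assumptions, not confirmed by the visible diff):

    import pymc as pm

    # Assumed locations of the helpers exercised by the moved tests:
    from pymc.logprob.joint_logprob import _get_scaling, joint_logp
    from pymc.tests.helpers import assert_no_rvs, select_by_precision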

pymc/tests/logprob/test_mixture.py

Lines changed: 2 additions & 2 deletions
@@ -45,12 +45,12 @@
 from pytensor.tensor.shape import shape_tuple
 from pytensor.tensor.subtensor import as_index_constant
 
-from pymc.logprob.joint_logprob import factorized_joint_logprob, joint_logprob
+from pymc.logprob.joint_logprob import factorized_joint_logprob
 from pymc.logprob.mixture import MixtureRV, expand_indices
 from pymc.logprob.rewriting import construct_ir_fgraph
 from pymc.logprob.utils import dirac_delta
 from pymc.tests.helpers import assert_no_rvs
-from pymc.tests.logprob.utils import scipy_logprob
+from pymc.tests.logprob.utils import joint_logprob, scipy_logprob
 
 
 def test_mixture_basics():
