Skip to content

Commit ab680a5

Browse files
ArmavicaricardoV94
authored andcommitted
Remove test redundant with test_measurable_join_*
1 parent 5524d6a commit ab680a5

File tree

2 files changed

+2
-14
lines changed

2 files changed

+2
-14
lines changed

pymc/tests/logprob/test_joint_logprob.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -540,17 +540,3 @@ def test_hierarchical_obs_logp():
540540
ops = {a.owner.op for a in logp_ancestors if a.owner}
541541
assert len(ops) > 0
542542
assert not any(isinstance(o, RandomVariable) for o in ops)
543-
544-
545-
def test_logprob_join_constant_shapes():
546-
x = at.random.normal(size=5)
547-
y = at.random.normal(size=3)
548-
xy = at.join(x, y)
549-
xy_vv = at.vector("xy_vv")
550-
551-
xy_logp = pm.logp(xy, xy_vv)
552-
# This is what Aeppl does not do!
553-
assert_no_rvs(xy_logp)
554-
555-
f = pytensor.function([xy_vv], xy_logp)
556-
np.testing.assert_array_equal(f(np.zeros(8)), sp.norm.logpdf(np.zeros(8)))

pymc/tests/logprob/test_tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ def test_measurable_join_univariate(size1, size2, axis, concatenate):
247247
else:
248248
base_logps = at.stack(base_logps, axis=axis)
249249
y_logp = joint_logprob({y_rv: y_vv}, sum=False)
250+
assert_no_rvs(y_logp)
250251

251252
base1_testval = base1_rv.eval()
252253
base2_testval = base2_rv.eval()
@@ -314,6 +315,7 @@ def test_measurable_join_multivariate(size1, supp_size1, size2, supp_size2, axis
314315
axis_norm = np.core.numeric.normalize_axis_index(axis, base1_rv.ndim + 1)
315316
base_logps = at.stack(base_logps, axis=axis_norm - 1)
316317
y_logp = joint_logprob({y_rv: y_vv}, sum=False)
318+
assert_no_rvs(y_logp)
317319

318320
base1_testval = base1_rv.eval()
319321
base2_testval = base2_rv.eval()

0 commit comments

Comments
 (0)