43 | 43 | import scipy.stats.distributions as sp
44 | 44 |
45 | 45 | from pytensor.graph.basic import ancestors, equal_computations
46 |    | -from pytensor.tensor.subtensor import (
47 |    | -    AdvancedIncSubtensor,
48 |    | -    AdvancedIncSubtensor1,
49 |    | -    AdvancedSubtensor,
50 |    | -    AdvancedSubtensor1,
51 |    | -    IncSubtensor,
52 |    | -    Subtensor,
53 |    | -)
   | 46 | +from pytensor.tensor.random.op import RandomVariable
   | 47 | +from pytensor.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
54 | 48 |
55 | 49 | from pymc.logprob.abstract import logprob
56 |    | -from pymc.logprob.joint_logprob import factorized_joint_logprob, joint_logprob
   | 50 | +from pymc.logprob.joint_logprob import factorized_joint_logprob
57 | 51 | from pymc.logprob.utils import rvs_to_value_vars, walk_model
58 | 52 | from pymc.tests.helpers import assert_no_rvs
   | 53 | +from pymc.tests.logprob.utils import joint_logprob
59 | 54 |
60 | 55 |
61 | 56 | def test_joint_logprob_basic():
@@ -160,43 +155,6 @@ def test_joint_logprob_diff_dims():
160 | 155 |     assert exp_logp_val == pytest.approx(logp_val)
161 | 156 |
162 | 157 |
163 |     | -@pytest.mark.parametrize(
164 |     | -    "indices, size",
165 |     | -    [
166 |     | -        (slice(0, 2), 5),
167 |     | -        (np.r_[True, True, False, False, True], 5),
168 |     | -        (np.r_[0, 1, 4], 5),
169 |     | -        ((np.array([0, 1, 4]), np.array([0, 1, 4])), (5, 5)),
170 |     | -    ],
171 |     | -)
172 |     | -def test_joint_logprob_incsubtensor(indices, size):
173 |     | -    """Make sure we can compute a joint log-probability for ``Y[idx] = data`` where ``Y`` is univariate."""
174 |     | -
175 |     | -    rng = np.random.RandomState(232)
176 |     | -    mu = np.power(10, np.arange(np.prod(size))).reshape(size)
177 |     | -    sigma = 0.001
178 |     | -    data = rng.normal(mu[indices], 1.0)
179 |     | -    y_val = rng.normal(mu, sigma, size=size)
180 |     | -
181 |     | -    Y_base_rv = at.random.normal(mu, sigma, size=size)
182 |     | -    Y_rv = at.set_subtensor(Y_base_rv[indices], data)
183 |     | -    Y_rv.name = "Y"
184 |     | -    y_value_var = Y_rv.clone()
185 |     | -    y_value_var.name = "y"
186 |     | -
187 |     | -    assert isinstance(Y_rv.owner.op, (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1))
188 |     | -
189 |     | -    Y_rv_logp = joint_logprob({Y_rv: y_value_var}, sum=False)
190 |     | -
191 |     | -    obs_logps = Y_rv_logp.eval({y_value_var: y_val})
192 |     | -
193 |     | -    y_val_idx = y_val.copy()
194 |     | -    y_val_idx[indices] = data
195 |     | -    exp_obs_logps = sp.norm.logpdf(y_val_idx, mu, sigma)
196 |     | -
197 |     | -    np.testing.assert_almost_equal(obs_logps, exp_obs_logps)
198 |     | -
199 |     | -
200 | 158 | def test_incsubtensor_original_values_output_dict():
201 | 159 |     """
202 | 160 |     Test that the original un-incsubtensor value variable appears as the key of
@@ -308,3 +266,230 @@ def test_multiple_rvs_to_same_value_raises():
308 | 266 |     msg = "More than one logprob factor was assigned to the value var x"
309 | 267 |     with pytest.raises(ValueError, match=msg):
310 | 268 |         joint_logprob({x_rv1: x, x_rv2: x})
    | 269 | +
    | 270 | +
    | 271 | +def test_get_scaling():
    | 272 | +
    | 273 | +    assert _get_scaling(None, (2, 3), 2).eval() == 1
    | 274 | +    # int total_size: divided by shape[0] when ndim >= 1, used as-is when ndim == 0
    | 275 | +    assert _get_scaling(45, (2, 3), 1).eval() == 22.5
    | 276 | +    assert _get_scaling(45, (2, 3), 0).eval() == 45
    | 277 | +
    | 278 | +    # list or tuple tests
    | 279 | +    # total_size containing anything other than Ellipsis, None, or int
    | 280 | +    with pytest.raises(TypeError, match="Unrecognized `total_size` type"):
    | 281 | +        _get_scaling([2, 4, 5, 9, 11.5], (2, 3), 2)
    | 282 | +    # check with Ellipsis
    | 283 | +    with pytest.raises(ValueError, match="Double Ellipsis in `total_size` is restricted"):
    | 284 | +        _get_scaling([1, 2, 5, Ellipsis, Ellipsis], (2, 3), 2)
    | 285 | +    with pytest.raises(
    | 286 | +        ValueError,
    | 287 | +        match="Length of `total_size` is too big, number of scalings is bigger that ndim",
    | 288 | +    ):
    | 289 | +        _get_scaling([1, 2, 5, Ellipsis], (2, 3), 2)
    | 290 | +
    | 291 | +    assert _get_scaling([Ellipsis], (2, 3), 2).eval() == 1
    | 292 | +
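    |     | +    # entries before/after the Ellipsis scale the leading/trailing axes and the
    |     | +    # ratios multiply: here (4/2) * (5/3) * (9/2) * (32/3) * (12/2) == 960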
    | 293 | +    assert _get_scaling([4, 5, 9, Ellipsis, 32, 12], (2, 3, 2), 5).eval() == 960
    | 294 | +    assert _get_scaling([4, 5, 9, Ellipsis], (2, 3, 2), 5).eval() == 15
    | 295 | +    # total_size with no Ellipsis (end = [])
    | 296 | +    with pytest.raises(
    | 297 | +        ValueError,
    | 298 | +        match="Length of `total_size` is too big, number of scalings is bigger that ndim",
    | 299 | +    ):
    | 300 | +        _get_scaling([1, 2, 5], (2, 3), 2)
    | 301 | +
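    |     | +    # an empty total_size contributes no scalings, so the coefficient is 1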
    | 302 | +    assert _get_scaling([], (2, 3), 2).eval() == 1
    | 303 | +    assert _get_scaling((), (2, 3), 2).eval() == 1
    | 304 | +    # total_size invalid type
    | 305 | +    with pytest.raises(
    | 306 | +        TypeError,
    | 307 | +        match="Unrecognized `total_size` type, expected int or list of ints, got {1, 2, 5}",
    | 308 | +    ):
    | 309 | +        _get_scaling({1, 2, 5}, (2, 3), 2)
    | 310 | +
    | 311 | +    # test with an RV from a model graph
    | 312 | +    with pm.Model() as m2:
    | 313 | +        rv_var = pm.Uniform("a", 0.0, 1.0)
    | 314 | +    total_size = []
    | 315 | +    assert _get_scaling(total_size, shape=rv_var.shape, ndim=rv_var.ndim).eval() == 1.0
    | 316 | +
    | 317 | +
    | 318 | +def test_joint_logp_basic():
    | 319 | +    """Make sure we can compute a log-likelihood for a hierarchical model with transforms."""
    | 320 | +
    | 321 | +    with pm.Model() as m:
    | 322 | +        a = pm.Uniform("a", 0.0, 1.0)
    | 323 | +        c = pm.Normal("c")
    | 324 | +        b_l = c * a + 2.0
    | 325 | +        b = pm.Uniform("b", b_l, b_l + 1.0)
    | 326 | +
    | 327 | +    a_value_var = m.rvs_to_values[a]
    | 328 | +    assert m.rvs_to_transforms[a]
    | 329 | +
    | 330 | +    b_value_var = m.rvs_to_values[b]
    | 331 | +    assert m.rvs_to_transforms[b]
    | 332 | +
    | 333 | +    c_value_var = m.rvs_to_values[c]
    | 334 | +
    | 335 | +    (b_logp,) = joint_logp(
    | 336 | +        (b,),
    | 337 | +        rvs_to_values=m.rvs_to_values,
    | 338 | +        rvs_to_transforms=m.rvs_to_transforms,
    | 339 | +        rvs_to_total_sizes={},
    | 340 | +    )
    | 341 | +
    | 342 | +    # There shouldn't be any `RandomVariable`s in the resulting graph
    | 343 | +    assert_no_rvs(b_logp)
    | 344 | +
    | 345 | +    res_ancestors = list(walk_model((b_logp,)))
    | 346 | +    assert b_value_var in res_ancestors
    | 347 | +    assert c_value_var in res_ancestors
    | 348 | +    assert a_value_var in res_ancestors
    | 349 | +
    | 350 | +
    | 351 | +def test_joint_logp_subtensor():
    | 352 | +    """Make sure we can compute a log-likelihood for ``Y[I]`` where ``Y`` and ``I`` are random variables."""
    | 353 | +
    | 354 | +    size = 5
    | 355 | +
    | 356 | +    mu_base = pm.floatX(np.power(10, np.arange(np.prod(size)))).reshape(size)
    | 357 | +    mu = np.stack([mu_base, -mu_base])
    | 358 | +    sigma = 0.001
    | 359 | +    rng = pytensor.shared(np.random.RandomState(232), borrow=True)
    | 360 | +
    | 361 | +    A_rv = pm.Normal.dist(mu, sigma, rng=rng)
    | 362 | +    A_rv.name = "A"
    | 363 | +
    | 364 | +    p = 0.5
    | 365 | +
    | 366 | +    I_rv = pm.Bernoulli.dist(p, size=size, rng=rng)
    | 367 | +    I_rv.name = "I"
    | 368 | +
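    |     | +    # A is 2 x 5; I picks row 0 or 1 per column, the ogrid term the column indices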
    | 369 | +    A_idx = A_rv[I_rv, at.ogrid[A_rv.shape[-1] :]]
    | 370 | +
    | 371 | +    assert isinstance(A_idx.owner.op, (Subtensor, AdvancedSubtensor, AdvancedSubtensor1))
    | 372 | +
    | 373 | +    A_idx_value_var = A_idx.type()
    | 374 | +    A_idx_value_var.name = "A_idx_value"
    | 375 | +
    | 376 | +    I_value_var = I_rv.type()
    | 377 | +    I_value_var.name = "I_value"
    | 378 | +
    | 379 | +    A_idx_logps = joint_logp(
    | 380 | +        (A_idx, I_rv),
    | 381 | +        rvs_to_values={A_idx: A_idx_value_var, I_rv: I_value_var},
    | 382 | +        rvs_to_transforms={},
    | 383 | +        rvs_to_total_sizes={},
    | 384 | +    )
    | 385 | +    A_idx_logp = at.add(*A_idx_logps)
    | 386 | +
    | 387 | +    logp_vals_fn = pytensor.function([A_idx_value_var, I_value_var], A_idx_logp)
    | 388 | +
    | 389 | +    # The compiled graph should not contain any `RandomVariables`
    | 390 | +    assert_no_rvs(logp_vals_fn.maker.fgraph.outputs[0])
    | 391 | +
    | 392 | +    decimals = select_by_precision(float64=6, float32=4)
    | 393 | +
    | 394 | +    for i in range(10):
    | 395 | +        bern_sp = sp.bernoulli(p)
    | 396 | +        I_value = bern_sp.rvs(size=size).astype(I_rv.dtype)
    | 397 | +
    | 398 | +        norm_sp = sp.norm(mu[I_value, np.ogrid[mu.shape[1] :]], sigma)
    | 399 | +        A_idx_value = norm_sp.rvs().astype(A_idx.dtype)
    | 400 | +
    | 401 | +        exp_obs_logps = norm_sp.logpdf(A_idx_value)
    | 402 | +        exp_obs_logps += bern_sp.logpmf(I_value)
    | 403 | +
    | 404 | +        logp_vals = logp_vals_fn(A_idx_value, I_value)
    | 405 | +
    | 406 | +        np.testing.assert_almost_equal(logp_vals, exp_obs_logps, decimal=decimals)
    | 407 | +
    | 408 | +
    | 409 | +def test_logp_helper():
    | 410 | +    value = at.vector("value")
    | 411 | +    x = pm.Normal.dist(0, 1)
    | 412 | +
    | 413 | +    x_logp = pm.logp(x, value)
    | 414 | +    np.testing.assert_almost_equal(x_logp.eval({value: [0, 1]}), sp.norm(0, 1).logpdf([0, 1]))
    | 415 | +
    | 416 | +    x_logp = pm.logp(x, [0, 1])
    | 417 | +    np.testing.assert_almost_equal(x_logp.eval(), sp.norm(0, 1).logpdf([0, 1]))
    | 418 | +
    | 419 | +
    | 420 | +def test_logp_helper_derived_rv():
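    |     | +    # exp of a standard normal is lognormal, so the inferred logp should match LogNormal's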
    | 421 | +    assert np.isclose(
    | 422 | +        pm.logp(at.exp(pm.Normal.dist()), 5).eval(),
    | 423 | +        pm.logp(pm.LogNormal.dist(), 5).eval(),
    | 424 | +    )
    | 425 | +
    | 426 | +
    | 427 | +def test_logp_helper_exceptions():
    | 428 | +    with pytest.raises(TypeError, match="When RV is not a pure distribution"):
    | 429 | +        pm.logp(at.exp(pm.Normal.dist()), [1, 2])
    | 430 | +
    | 431 | +    with pytest.raises(NotImplementedError, match="PyMC could not infer logp of input variable"):
    | 432 | +        pm.logp(at.cos(pm.Normal.dist()), 1)
    | 433 | +
    | 434 | +
    | 435 | +def test_model_unchanged_logprob_access():
    | 436 | +    # Issue #5007
    | 437 | +    with pm.Model() as model:
    | 438 | +        a = pm.Normal("a")
    | 439 | +        c = pm.Uniform("c", lower=a - 1, upper=1)
    | 440 | +
    | 441 | +    original_inputs = set(pytensor.graph.graph_inputs([c]))
    | 442 | +    # Extract model.logp
    | 443 | +    model.logp()
    | 444 | +    new_inputs = set(pytensor.graph.graph_inputs([c]))
    | 445 | +    assert original_inputs == new_inputs
    | 446 | +
    | 447 | +
    | 448 | +def test_unexpected_rvs():
    | 449 | +    with pm.Model() as model:
    | 450 | +        x = pm.Normal("x")
    | 451 | +        y = pm.CustomDist("y", logp=lambda *args: x)
    | 452 | +
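    |     | +    # the CustomDist logp closes over the model RV x, leaving a stray RV in the logp graph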
    | 453 | +    with pytest.raises(ValueError, match="^Random variables detected in the logp graph"):
    | 454 | +        model.logp()
    | 455 | +
    | 456 | +
    | 457 | +def test_hierarchical_logp():
    | 458 | +    """Make sure there are no random variables in a model's log-likelihood graph."""
    | 459 | +    with pm.Model() as m:
    | 460 | +        x = pm.Uniform("x", lower=0, upper=1)
    | 461 | +        y = pm.Uniform("y", lower=0, upper=x)
    | 462 | +
    | 463 | +    logp_ancestors = list(ancestors([m.logp()]))
    | 464 | +    ops = {a.owner.op for a in logp_ancestors if a.owner}
    | 465 | +    assert len(ops) > 0
    | 466 | +    assert not any(isinstance(o, RandomVariable) for o in ops)
    | 467 | +    assert m.rvs_to_values[x] in logp_ancestors
    | 468 | +    assert m.rvs_to_values[y] in logp_ancestors
    | 469 | +
    | 470 | +
    | 471 | +def test_hierarchical_obs_logp():
    | 472 | +    obs = np.array([0.5, 0.4, 5, 2])
    | 473 | +
    | 474 | +    with pm.Model() as model:
    | 475 | +        x = pm.Uniform("x", 0, 1, observed=obs)
    | 476 | +        pm.Uniform("y", x, 2, observed=obs)
    | 477 | +
    | 478 | +    logp_ancestors = list(ancestors([model.logp()]))
    | 479 | +    ops = {a.owner.op for a in logp_ancestors if a.owner}
    | 480 | +    assert len(ops) > 0
    | 481 | +    assert not any(isinstance(o, RandomVariable) for o in ops)
    | 482 | +
    | 483 | +
    | 484 | +def test_logprob_join_constant_shapes():
    | 485 | +    x = at.random.normal(size=5)
    | 486 | +    y = at.random.normal(size=3)
    | 487 | +    xy = at.join(0, x, y)
    | 488 | +    xy_vv = at.vector("xy_vv")
    | 489 | +
    | 490 | +    xy_logp = pm.logp(xy, xy_vv)
    | 491 | +    # Unlike Aeppl, PyMC should leave no RVs (via their shapes) in the logp graph
    | 492 | +    assert_no_rvs(xy_logp)
    | 493 | +
    | 494 | +    f = pytensor.function([xy_vv], xy_logp)
    | 495 | +    np.testing.assert_array_equal(f(np.zeros(8)), sp.norm.logpdf(np.zeros(8)))