|
36 | 36 | from pymc.pytensorf import compile_pymc
|
37 | 37 | from pymc.sampling.forward import (
|
38 | 38 | compile_forward_sampling_function,
|
| 39 | + get_constant_coords, |
39 | 40 | get_vars_in_point_list,
|
40 | 41 | observed_dependent_deterministics,
|
41 | 42 | )
|
@@ -428,6 +429,45 @@ def test_mutable_coords_volatile(self):
|
428 | 429 | "offsets",
|
429 | 430 | }
|
430 | 431 |
|
| 432 | + def test_length_coords_volatile(self): |
| 433 | + with pm.Model() as model: |
| 434 | + model.add_coord("trial", length=3) |
| 435 | + x = pm.Normal("x", dims="trial") |
| 436 | + y = pm.Deterministic("y", x.mean()) |
| 437 | + |
| 438 | + # Same coord length -- `x` is not volatile |
| 439 | + trace_same_len = az_from_dict( |
| 440 | + posterior={"x": [[[np.pi] * 3]]}, |
| 441 | + coords={"trial": range(3)}, |
| 442 | + dims={"x": ["trial"]}, |
| 443 | + ) |
| 444 | + with model: |
| 445 | + pp_same_len = pm.sample_posterior_predictive( |
| 446 | + trace_same_len, var_names=["y"] |
| 447 | + ).posterior_predictive |
| 448 | + assert pp_same_len["y"] == np.pi |
| 449 | + |
| 450 | + # Coord length changed -- `x` is volatile |
| 451 | + trace_diff_len = az_from_dict( |
| 452 | + posterior={"x": [[[np.pi] * 2]]}, |
| 453 | + coords={"trial": range(2)}, |
| 454 | + dims={"x": ["trial"]}, |
| 455 | + ) |
| 456 | + with model: |
| 457 | + pp_diff_len = pm.sample_posterior_predictive( |
| 458 | + trace_diff_len, var_names=["y"] |
| 459 | + ).posterior_predictive |
| 460 | + assert pp_diff_len["y"] != np.pi |
| 461 | + |
| 462 | + # Changing the dim length on the model itself |
| 463 | + # -- `x` is volatile because trace has same len as original model |
| 464 | + model.set_dim("trial", new_length=7) |
| 465 | + with model: |
| 466 | + pp_diff_len_model_set = pm.sample_posterior_predictive( |
| 467 | + trace_same_len, var_names=["y"] |
| 468 | + ).posterior_predictive |
| 469 | + assert pp_diff_len_model_set["y"] != np.pi |
| 470 | + |
431 | 471 |
|
432 | 472 | class TestSamplePPC:
|
433 | 473 | def test_normal_scalar(self):
|
@@ -1670,6 +1710,20 @@ def test_Triangular(
|
1670 | 1710 | assert prior["target"].shape == (prior_samples, *shape)
|
1671 | 1711 |
|
1672 | 1712 |
|
| 1713 | +def test_get_constant_coords(): |
| 1714 | + with pm.Model() as model: |
| 1715 | + model.add_coord("length_coord", length=1) |
| 1716 | + model.add_coord("value_coord", values=(3,)) |
| 1717 | + |
| 1718 | + trace_coords_same = {"length_coord": np.array([0]), "value_coord": np.array([3])} |
| 1719 | + constant_coords_same = get_constant_coords(trace_coords_same, model) |
| 1720 | + assert constant_coords_same == {"length_coord", "value_coord"} |
| 1721 | + |
| 1722 | + trace_coords_diff = {"length_coord": np.array([0, 1]), "value_coord": np.array([4])} |
| 1723 | + constant_coords_diff = get_constant_coords(trace_coords_diff, model) |
| 1724 | + assert constant_coords_diff == set() |
| 1725 | + |
| 1726 | + |
1673 | 1727 | def test_get_vars_in_point_list():
|
1674 | 1728 | with pm.Model() as modelA:
|
1675 | 1729 | pm.Normal("a", 0, 1)
|
|
0 commit comments