Skip to content

Commit 1c532f8

Browse files
JasonTamricardoV94
andauthored
Do not consider dims without coords volatile if length has not changed (#7381)
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent f719796 commit 1c532f8

File tree

2 files changed

+79
-10
lines changed

2 files changed

+79
-10
lines changed

pymc/sampling/forward.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@
3939
)
4040
from pytensor.graph.fg import FunctionGraph
4141
from pytensor.tensor.random.var import RandomGeneratorSharedVariable
42-
from pytensor.tensor.sharedvar import SharedVariable
42+
from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable
43+
from pytensor.tensor.variable import TensorConstant
4344
from rich.console import Console
4445
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
4546
from rich.theme import Theme
@@ -73,6 +74,28 @@
7374
_log = logging.getLogger(__name__)
7475

7576

77+
def get_constant_coords(trace_coords: dict[str, np.ndarray], model: Model) -> set:
78+
"""Get the set of coords that have remained constant between the trace and model"""
79+
constant_coords = set()
80+
for dim, coord in trace_coords.items():
81+
current_coord = model.coords.get(dim, None)
82+
current_length = model.dim_lengths.get(dim, None)
83+
if isinstance(current_length, TensorSharedVariable):
84+
current_length = current_length.get_value()
85+
elif isinstance(current_length, TensorConstant):
86+
current_length = current_length.data
87+
if (
88+
current_coord is not None
89+
and len(coord) == len(current_coord)
90+
and np.all(coord == current_coord)
91+
) or (
92+
# Coord was defined without values (only length)
93+
current_coord is None and len(coord) == current_length
94+
):
95+
constant_coords.add(dim)
96+
return constant_coords
97+
98+
7699
def get_vars_in_point_list(trace, model):
77100
"""Get the list of Variable instances in the model that have values stored in the trace."""
78101
if not isinstance(trace, MultiTrace):
@@ -789,15 +812,7 @@ def sample_posterior_predictive(
789812
stacklevel=2,
790813
)
791814

792-
constant_coords = set()
793-
for dim, coord in trace_coords.items():
794-
current_coord = model.coords.get(dim, None)
795-
if (
796-
current_coord is not None
797-
and len(coord) == len(current_coord)
798-
and np.all(coord == current_coord)
799-
):
800-
constant_coords.add(dim)
815+
constant_coords = get_constant_coords(trace_coords, model)
801816

802817
if var_names is not None:
803818
vars_ = [model[x] for x in var_names]

tests/sampling/test_forward.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from pymc.pytensorf import compile_pymc
3737
from pymc.sampling.forward import (
3838
compile_forward_sampling_function,
39+
get_constant_coords,
3940
get_vars_in_point_list,
4041
observed_dependent_deterministics,
4142
)
@@ -428,6 +429,45 @@ def test_mutable_coords_volatile(self):
428429
"offsets",
429430
}
430431

432+
def test_length_coords_volatile(self):
433+
with pm.Model() as model:
434+
model.add_coord("trial", length=3)
435+
x = pm.Normal("x", dims="trial")
436+
y = pm.Deterministic("y", x.mean())
437+
438+
# Same coord length -- `x` is not volatile
439+
trace_same_len = az_from_dict(
440+
posterior={"x": [[[np.pi] * 3]]},
441+
coords={"trial": range(3)},
442+
dims={"x": ["trial"]},
443+
)
444+
with model:
445+
pp_same_len = pm.sample_posterior_predictive(
446+
trace_same_len, var_names=["y"]
447+
).posterior_predictive
448+
assert pp_same_len["y"] == np.pi
449+
450+
# Coord length changed -- `x` is volatile
451+
trace_diff_len = az_from_dict(
452+
posterior={"x": [[[np.pi] * 2]]},
453+
coords={"trial": range(2)},
454+
dims={"x": ["trial"]},
455+
)
456+
with model:
457+
pp_diff_len = pm.sample_posterior_predictive(
458+
trace_diff_len, var_names=["y"]
459+
).posterior_predictive
460+
assert pp_diff_len["y"] != np.pi
461+
462+
# Changing the dim length on the model itself
463+
# -- `x` is volatile because trace has same len as original model
464+
model.set_dim("trial", new_length=7)
465+
with model:
466+
pp_diff_len_model_set = pm.sample_posterior_predictive(
467+
trace_same_len, var_names=["y"]
468+
).posterior_predictive
469+
assert pp_diff_len_model_set["y"] != np.pi
470+
431471

432472
class TestSamplePPC:
433473
def test_normal_scalar(self):
@@ -1670,6 +1710,20 @@ def test_Triangular(
16701710
assert prior["target"].shape == (prior_samples, *shape)
16711711

16721712

1713+
def test_get_constant_coords():
1714+
with pm.Model() as model:
1715+
model.add_coord("length_coord", length=1)
1716+
model.add_coord("value_coord", values=(3,))
1717+
1718+
trace_coords_same = {"length_coord": np.array([0]), "value_coord": np.array([3])}
1719+
constant_coords_same = get_constant_coords(trace_coords_same, model)
1720+
assert constant_coords_same == {"length_coord", "value_coord"}
1721+
1722+
trace_coords_diff = {"length_coord": np.array([0, 1]), "value_coord": np.array([4])}
1723+
constant_coords_diff = get_constant_coords(trace_coords_diff, model)
1724+
assert constant_coords_diff == set()
1725+
1726+
16731727
def test_get_vars_in_point_list():
16741728
with pm.Model() as modelA:
16751729
pm.Normal("a", 0, 1)

0 commit comments

Comments
 (0)