Skip to content

Commit 253513b

Browse files
committed
Make zip strict in apply_function_over_dataset
1 parent 90f20a2 commit 253513b

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

pymc/backends/arviz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ def apply_function_over_dataset(
656656
for idx in indices:
657657
out = fn(posterior_pts[idx])
658658
fn.f.trust_input = True # If we arrive here the dtypes are valid
659-
for var_name, val in zip(output_var_names, out):
659+
for var_name, val in zip(output_var_names, out, strict=True):
660660
out_dict.insert(var_name, val, idx)
661661

662662
progress.advance(task)

tests/stats/test_log_density.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,15 @@ def test_compilation_kwargs(self):
184184
Normal("y", x, observed=[0, 1, 2])
185185

186186
idata = InferenceData(posterior=dict_to_dataset({"x": np.arange(100).reshape(4, 25)}))
187-
with patch("pymc.model.core.compile_pymc") as patched_compile_pymc:
188-
compute_log_prior(idata, compile_kwargs={"mode": "JAX"})
189-
compute_log_likelihood(idata, compile_kwargs={"mode": "NUMBA"})
187+
with (
188+
# apply_function_over_dataset fails with patched `compile_pymc`
189+
patch("pymc.stats.log_density.apply_function_over_dataset"),
190+
patch("pymc.model.core.compile_pymc") as patched_compile_pymc,
191+
):
192+
compute_log_prior(idata, compile_kwargs={"mode": "JAX"}, extend_inferencedata=False)
193+
compute_log_likelihood(
194+
idata, compile_kwargs={"mode": "NUMBA"}, extend_inferencedata=False
195+
)
190196
assert len(patched_compile_pymc.call_args_list) == 2
191197
assert patched_compile_pymc.call_args_list[0].kwargs["mode"] == "JAX"
192198
assert patched_compile_pymc.call_args_list[1].kwargs["mode"] == "NUMBA"

0 commit comments

Comments
 (0)