Skip to content

Commit 5bc6801

Browse files
osyukselricardoV94
andauthored
Add more idata attributes for JAX samplers (#7360)
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent bbd5739 commit 5bc6801

File tree

3 files changed

+11
-0
lines changed

3 files changed

+11
-0
lines changed

pymc/sampling/jax.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,7 @@ def sample_jax_nuts(
671671

672672
attrs = {
673673
"sampling_time": (tic2 - tic1).total_seconds(),
674+
"tuning_steps": tune,
674675
}
675676

676677
coords, dims = coords_and_dims_for_inferencedata(model)
@@ -680,6 +681,7 @@ def sample_jax_nuts(
680681
coords.update(idata_kwargs.pop("coords"))
681682
if "dims" in idata_kwargs:
682683
dims.update(idata_kwargs.pop("dims"))
684+
683685
# Use 'partial' to set default arguments before passing 'idata_kwargs'
684686
to_trace = partial(
685687
az.from_dict,
@@ -690,6 +692,7 @@ def sample_jax_nuts(
690692
coords=coords,
691693
dims=dims,
692694
attrs=make_attrs(attrs, library=library),
695+
posterior_attrs=make_attrs(attrs, library=library),
693696
)
694697
az_trace = to_trace(posterior=mcmc_samples, **idata_kwargs)
695698

pymc/sampling/mcmc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ def _sample_external_nuts(
336336
attrs = make_attrs(
337337
{
338338
"sampling_time": t_sample,
339+
"tuning_steps": tune,
339340
},
340341
library=nutpie,
341342
)

tests/sampling/test_mcmc_external.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
4444
idata1 = sample(**kwargs)
4545
idata2 = sample(**kwargs)
4646

47+
reference_kwargs = kwargs.copy()
48+
reference_kwargs["nuts_sampler"] = "pymc"
49+
idata_reference = sample(**reference_kwargs)
50+
4751
warns = {
4852
(warn.category, warn.message.args[0])
4953
for warn in recwarn
@@ -64,8 +68,11 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
6468
assert "L" in idata1.observed_data
6569
assert idata1.posterior.chain.size == 2
6670
assert idata1.posterior.draw.size == 500
71+
assert idata1.posterior.tuning_steps == 500
6772
np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x)
6873

74+
assert idata_reference.posterior.attrs.keys() == idata1.posterior.attrs.keys()
75+
6976

7077
def test_step_args():
7178
with Model() as model:

0 commit comments

Comments
 (0)