File tree Expand file tree Collapse file tree 3 files changed +11
-0
lines changed Expand file tree Collapse file tree 3 files changed +11
-0
lines changed Original file line number Diff line number Diff line change @@ -671,6 +671,7 @@ def sample_jax_nuts(
671
671
672
672
attrs = {
673
673
"sampling_time" : (tic2 - tic1 ).total_seconds (),
674
+ "tuning_steps" : tune ,
674
675
}
675
676
676
677
coords , dims = coords_and_dims_for_inferencedata (model )
@@ -680,6 +681,7 @@ def sample_jax_nuts(
680
681
coords .update (idata_kwargs .pop ("coords" ))
681
682
if "dims" in idata_kwargs :
682
683
dims .update (idata_kwargs .pop ("dims" ))
684
+
683
685
# Use 'partial' to set default arguments before passing 'idata_kwargs'
684
686
to_trace = partial (
685
687
az .from_dict ,
@@ -690,6 +692,7 @@ def sample_jax_nuts(
690
692
coords = coords ,
691
693
dims = dims ,
692
694
attrs = make_attrs (attrs , library = library ),
695
+ posterior_attrs = make_attrs (attrs , library = library ),
693
696
)
694
697
az_trace = to_trace (posterior = mcmc_samples , ** idata_kwargs )
695
698
Original file line number Diff line number Diff line change @@ -336,6 +336,7 @@ def _sample_external_nuts(
336
336
attrs = make_attrs (
337
337
{
338
338
"sampling_time" : t_sample ,
339
+ "tuning_steps" : tune ,
339
340
},
340
341
library = nutpie ,
341
342
)
Original file line number Diff line number Diff line change @@ -44,6 +44,10 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
44
44
idata1 = sample (** kwargs )
45
45
idata2 = sample (** kwargs )
46
46
47
+ reference_kwargs = kwargs .copy ()
48
+ reference_kwargs ["nuts_sampler" ] = "pymc"
49
+ idata_reference = sample (** reference_kwargs )
50
+
47
51
warns = {
48
52
(warn .category , warn .message .args [0 ])
49
53
for warn in recwarn
@@ -64,8 +68,11 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
64
68
assert "L" in idata1 .observed_data
65
69
assert idata1 .posterior .chain .size == 2
66
70
assert idata1 .posterior .draw .size == 500
71
+ assert idata1 .posterior .tuning_steps == 500
67
72
np .testing .assert_array_equal (idata1 .posterior .x , idata2 .posterior .x )
68
73
74
+ assert idata_reference .posterior .attrs .keys () == idata1 .posterior .attrs .keys ()
75
+
69
76
70
77
def test_step_args ():
71
78
with Model () as model :
You can’t perform that action at this time.
0 commit comments