Skip to content

Commit 2a9e86c

Browse files
authored
Fix error when passing coords and dims in sampling_jax (#5983)
* Use 'partial' to construct inference data (#5932) * Add test cases for dims & coords * Comment on use of 'partial' * Extend docstring and pass posterior after partial
1 parent 78da937 commit 2a9e86c

File tree

2 files changed

+38
-14
lines changed

2 files changed

+38
-14
lines changed

pymc/sampling_jax.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,9 @@ def sample_blackjax_nuts(
254254
idata_kwargs : dict, optional
255255
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
256256
for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
257-
not be included in the returned object.
257+
not be included in the returned object. Values for ``observed_data``, ``constant_data``,
258+
``coords``, and ``dims`` are inferred from the ``model`` argument if not provided
259+
in ``idata_kwargs``.
258260
259261
Returns
260262
-------
@@ -365,16 +367,17 @@ def sample_blackjax_nuts(
365367
}
366368

367369
posterior = mcmc_samples
368-
az_trace = az.from_dict(
369-
posterior=posterior,
370+
# Use 'partial' to set default arguments before passing 'idata_kwargs'
371+
to_trace = partial(
372+
az.from_dict,
370373
log_likelihood=log_likelihood,
371374
observed_data=find_observations(model),
372375
constant_data=find_constants(model),
373376
coords=coords,
374377
dims=dims,
375378
attrs=make_attrs(attrs, library=blackjax),
376-
**idata_kwargs,
377379
)
380+
az_trace = to_trace(posterior=posterior, **idata_kwargs)
378381

379382
return az_trace
380383

@@ -431,7 +434,9 @@ def sample_numpyro_nuts(
431434
idata_kwargs : dict, optional
432435
Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as value
433436
for the ``log_likelihood`` key to indicate that the pointwise log likelihood should
434-
not be included in the returned object.
437+
not be included in the returned object. Values for ``observed_data``, ``constant_data``,
438+
``coords``, and ``dims`` are inferred from the ``model`` argument if not provided
439+
in ``idata_kwargs``.
435440
nuts_kwargs: dict, optional
436441
Keyword arguments for :func:`numpyro.infer.NUTS`.
437442
@@ -560,16 +565,17 @@ def sample_numpyro_nuts(
560565
}
561566

562567
posterior = mcmc_samples
563-
az_trace = az.from_dict(
564-
posterior=posterior,
568+
# Use 'partial' to set default arguments before passing 'idata_kwargs'
569+
to_trace = partial(
570+
az.from_dict,
565571
log_likelihood=log_likelihood,
566572
observed_data=find_observations(model),
567573
constant_data=find_constants(model),
568574
sample_stats=_sample_stats_to_xarray(pmap_numpyro),
569575
coords=coords,
570576
dims=dims,
571577
attrs=make_attrs(attrs, library=numpyro),
572-
**idata_kwargs,
573578
)
579+
az_trace = to_trace(posterior=posterior, **idata_kwargs)
574580

575581
return az_trace

pymc/tests/test_sampling_jax.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,16 @@ def test_get_jaxified_logp():
153153
assert not np.isinf(jax_fn((np.array(5000.0), np.array(5000.0))))
154154

155155

156+
@pytest.fixture
157+
def model_test_idata_kwargs(scope="module"):
158+
with pm.Model(coords={"x_coord": ["a", "b"], "x_coord2": [1, 2]}) as m:
159+
x = pm.Normal("x", shape=(2,), dims=["x_coord"])
160+
y = pm.Normal("y", x, observed=[0, 0])
161+
pm.ConstantData("constantdata", [1, 2, 3])
162+
pm.MutableData("mutabledata", 2)
163+
return m
164+
165+
156166
@pytest.mark.parametrize(
157167
"sampler",
158168
[
@@ -165,15 +175,17 @@ def test_get_jaxified_logp():
165175
[
166176
dict(),
167177
dict(log_likelihood=False),
178+
# Overwrite models coords
179+
dict(coords={"x_coord": ["x1", "x2"]}),
180+
# Overwrite dims from dist specification in model
181+
dict(dims={"x": ["x_coord2"]}),
182+
# Overwrite both coords and dims
183+
dict(coords={"x_coord3": ["A", "B"]}, dims={"x": ["x_coord3"]}),
168184
],
169185
)
170186
@pytest.mark.parametrize("postprocessing_backend", [None, "cpu"])
171-
def test_idata_kwargs(sampler, idata_kwargs, postprocessing_backend):
172-
with pm.Model() as m:
173-
x = pm.Normal("x")
174-
y = pm.Normal("y", x, observed=0)
175-
pm.ConstantData("constantdata", [1, 2, 3])
176-
pm.MutableData("mutabledata", 2)
187+
def test_idata_kwargs(model_test_idata_kwargs, sampler, idata_kwargs, postprocessing_backend):
188+
with model_test_idata_kwargs:
177189
idata = sampler(
178190
tune=50,
179191
draws=50,
@@ -189,6 +201,12 @@ def test_idata_kwargs(sampler, idata_kwargs, postprocessing_backend):
189201
else:
190202
assert "log_likelihood" not in idata
191203

204+
x_dim_expected = idata_kwargs.get("dims", model_test_idata_kwargs.RV_dims)["x"][0]
205+
assert idata.posterior.x.dims[-1] == x_dim_expected
206+
207+
x_coords_expected = idata_kwargs.get("coords", model_test_idata_kwargs.coords)[x_dim_expected]
208+
assert list(x_coords_expected) == list(idata.posterior.x.coords[x_dim_expected].values)
209+
192210

193211
def test_get_batched_jittered_initial_points():
194212
with pm.Model() as model:

0 commit comments

Comments
 (0)