Skip to content

Commit 35ccca5

Browse files
authored
Fix error in dataset_to_point_list when chain, draw are not the leading dims (#7180)
1 parent 9bf2190 commit 35ccca5

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

pymc/util.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,10 @@ def dataset_to_point_list(
249249
raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.")
250250
num_sample_dims = len(sample_dims)
251251
stacked_dims = {dim_name: ds[var_names[0]][dim_name] for dim_name in sample_dims}
252+
transposed_dict = {vn: da.transpose(*sample_dims, ...) for vn, da in ds.items()}
252253
stacked_dict = {
253-
vn: da.transpose(*sample_dims, ...).values.reshape((-1, *da.shape[num_sample_dims:]))
254-
for vn, da in ds.items()
254+
vn: da.values.reshape((-1, *da.shape[num_sample_dims:]))
255+
for vn, da in transposed_dict.items()
255256
}
256257
points = [
257258
{vn: stacked_dict[vn][i, ...] for vn in var_names}

tests/test_util.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,16 @@ def test_dataset_to_point_list(input_type):
170170
assert isinstance(pl[0]["A"], np.ndarray)
171171

172172

173+
def test_transposed_dataset_to_point_list():
174+
ds = xarray.Dataset()
175+
ds["A"] = xarray.DataArray([[[1, 2, 3], [2, 3, 4]]] * 5, dims=("team", "draw", "chain"))
176+
pl, _ = dataset_to_point_list(ds, sample_dims=["chain", "draw"])
177+
assert isinstance(pl, list)
178+
assert len(pl) == 6
179+
assert isinstance(pl[0], dict)
180+
assert isinstance(pl[0]["A"], np.ndarray)
181+
182+
173183
def test_dataset_to_point_list_str_key():
174184
# Check that non-str keys are caught
175185
ds = xarray.Dataset()

0 commit comments

Comments
 (0)