Skip to content

Commit 6f8f9ee

Browse files
authored
Extend dataset_to_point_list to accept both a Dataset and a dict of DataArrays (#7097)
1 parent 2051d0b commit 6f8f9ee

File tree

2 files changed: +14 / -7 lines changed

2 files changed: +14 / -7 lines changed

pymc/util.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -239,18 +239,18 @@ def enhanced(*args, **kwargs):
239239

240240

241241
def dataset_to_point_list(
242-
ds: xarray.Dataset, sample_dims: Sequence[str]
242+
ds: Union[xarray.Dataset, dict[str, xarray.DataArray]], sample_dims: Sequence[str]
243243
) -> Tuple[List[Dict[str, np.ndarray]], Dict[str, Any]]:
244244
# All keys of the dataset must be a str
245-
var_names = list(ds.keys())
245+
var_names = cast(List[str], list(ds.keys()))
246246
for vn in var_names:
247247
if not isinstance(vn, str):
248248
raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.")
249249
num_sample_dims = len(sample_dims)
250-
stacked_dims = {dim_name: ds[dim_name] for dim_name in sample_dims}
251-
ds = ds.transpose(*sample_dims, ...)
250+
stacked_dims = {dim_name: ds[var_names[0]][dim_name] for dim_name in sample_dims}
252251
stacked_dict = {
253-
vn: da.values.reshape((-1, *da.shape[num_sample_dims:])) for vn, da in ds.items()
252+
vn: da.transpose(*sample_dims, ...).values.reshape((-1, *da.shape[num_sample_dims:]))
253+
for vn, da in ds.items()
254254
}
255255
points = [
256256
{vn: stacked_dict[vn][i, ...] for vn in var_names}

tests/test_util.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -156,16 +156,23 @@ def fn(a=UNSET):
156156
assert "a=UNSET" in captured.out
157157

158158

159-
def test_dataset_to_point_list():
160-
ds = xarray.Dataset()
159+
@pytest.mark.parametrize("input_type", ("dict", "Dataset"))
160+
def test_dataset_to_point_list(input_type):
161+
if input_type == "dict":
162+
ds = {}
163+
elif input_type == "Dataset":
164+
ds = xarray.Dataset()
161165
ds["A"] = xarray.DataArray([[1, 2, 3]] * 2, dims=("chain", "draw"))
162166
pl, _ = dataset_to_point_list(ds, sample_dims=["chain", "draw"])
163167
assert isinstance(pl, list)
164168
assert len(pl) == 6
165169
assert isinstance(pl[0], dict)
166170
assert isinstance(pl[0]["A"], np.ndarray)
167171

172+
173+
def test_dataset_to_point_list_str_key():
168174
# Check that non-str keys are caught
175+
ds = xarray.Dataset()
169176
ds[3] = xarray.DataArray([1, 2, 3])
170177
with pytest.raises(ValueError, match="must be str"):
171178
dataset_to_point_list(ds, sample_dims=["chain", "draw"])

0 commit comments

Comments
 (0)