Skip to content

Commit c9c07bb

Browse files
Manage coord values as tuples
And always make them numpy arrays for the InferenceData conversion. Closes #5043
1 parent a44f739 commit c9c07bb

File tree

4 files changed

+37
-11
lines changed

4 files changed

+37
-11
lines changed

pymc/backends/arviz.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -222,10 +222,12 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
222222
aelem = arbitrary_element(get_from)
223223
self.ndraws = aelem.shape[0]
224224

225-
self.coords = {} if coords is None else coords
226-
if hasattr(self.model, "coords"):
227-
self.coords = {**self.model.coords, **self.coords}
228-
self.coords = {key: value for key, value in self.coords.items() if value is not None}
225+
self.coords = {**self.model.coords, **(coords or {})}
226+
self.coords = {
227+
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
228+
for cname, cvals in self.coords.items()
229+
if cvals is not None
230+
}
229231

230232
self.dims = {} if dims is None else dims
231233
if hasattr(self.model, "RV_dims"):

pymc/model.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -871,7 +871,7 @@ def RV_dims(self) -> Dict[str, Tuple[Union[str, None], ...]]:
871871
return self._RV_dims
872872

873873
@property
874-
def coords(self) -> Dict[str, Union[Sequence, None]]:
874+
def coords(self) -> Dict[str, Union[Tuple, None]]:
875875
"""Coordinate values for model dimensions."""
876876
return self._coords
877877

@@ -1096,8 +1096,12 @@ def add_coord(
10961096
raise ValueError(
10971097
f"The `length` passed for the '{name}' coord must be an Aesara Variable or None."
10981098
)
1099+
if values is not None:
1100+
# Conversion to a tuple ensures that the coordinate values are immutable.
1101+
# Also unlike numpy arrays the's tuple.index(...) which is handy to work with.
1102+
values = tuple(values)
10991103
if name in self.coords:
1100-
if not values.equals(self.coords[name]):
1104+
if not np.array_equal(values, self.coords[name]):
11011105
raise ValueError(f"Duplicate and incompatible coordinate: {name}.")
11021106
else:
11031107
self._coords[name] = values

pymc/tests/test_data_container.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,12 +287,12 @@ def test_explicit_coords(self):
287287
pm.Data("observations", data, dims=("rows", "columns"))
288288

289289
assert "rows" in pmodel.coords
290-
assert pmodel.coords["rows"] == ["R1", "R2", "R3", "R4", "R5"]
290+
assert pmodel.coords["rows"] == ("R1", "R2", "R3", "R4", "R5")
291291
assert "rows" in pmodel.dim_lengths
292292
assert isinstance(pmodel.dim_lengths["rows"], ScalarSharedVariable)
293293
assert pmodel.dim_lengths["rows"].eval() == 5
294294
assert "columns" in pmodel.coords
295-
assert pmodel.coords["columns"] == ["C1", "C2", "C3", "C4", "C5", "C6", "C7"]
295+
assert pmodel.coords["columns"] == ("C1", "C2", "C3", "C4", "C5", "C6", "C7")
296296
assert pmodel.RV_dims == {"observations": ("rows", "columns")}
297297
assert "columns" in pmodel.dim_lengths
298298
assert isinstance(pmodel.dim_lengths["columns"], ScalarSharedVariable)

pymc/tests/test_idata_conversion.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -604,9 +604,10 @@ def test_constant_data_coords_issue_5046(self):
604604

605605
def test_issue_5043_autoconvert_coord_values(self):
606606
coords = {
607-
"city": ("Bonn", "Berlin"),
607+
"city": pd.Series(["Bonn", "Berlin"]),
608608
}
609609
with pm.Model(coords=coords) as pmodel:
610+
# The model tracks coord values as (immutable) tuples
610611
assert isinstance(pmodel.coords["city"], tuple)
611612
pm.Normal("x", dims="city")
612613
mtrace = pm.sample(
@@ -617,9 +618,28 @@ def test_issue_5043_autoconvert_coord_values(self):
617618
tune=7,
618619
draws=15,
619620
)
621+
# The converter must convert coord values them to numpy arrays
622+
# because tuples as coordinate values causes problems with xarray.
620623
converter = InferenceDataConverter(trace=mtrace)
621-
with pytest.raises(ValueError, match="same length as the number of data dimensions"):
622-
converter.to_inference_data()
624+
assert isinstance(converter.coords["city"], np.ndarray)
625+
converter.to_inference_data()
626+
627+
# We're not automatically converting things other than tuple,
628+
# so advanced use cases remain supported at the InferenceData level.
629+
# They just can't be used in the model construction already.
630+
converter = InferenceDataConverter(
631+
trace=mtrace,
632+
coords={
633+
"city": pd.MultiIndex.from_tuples(
634+
[
635+
("Bonn", 53111),
636+
("Berlin", 10178),
637+
],
638+
names=["name", "zipcode"],
639+
)
640+
},
641+
)
642+
assert isinstance(converter.coords["city"], pd.MultiIndex)
623643

624644

625645
class TestPyMCWarmupHandling:

0 commit comments

Comments
 (0)