Skip to content

Commit 47c62d9

Browse files
committed
Change underlying representation of Model._coords to ndarray
1 parent b6521f2 commit 47c62d9

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

pymc/model.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
UNSET,
7676
WithMemoization,
7777
_add_future_warning_tag,
78+
_as_coord_vals,
7879
get_transformed_name,
7980
get_value_vars_from_user_vars,
8081
get_var_name,
@@ -986,8 +987,19 @@ def RV_dims(self) -> Dict[str, Tuple[Union[str, None], ...]]:
986987

987988
@property
988989
def coords(self) -> Dict[str, Union[Tuple, None]]:
989-
"""Coordinate values for model dimensions."""
990-
return self._coords
990+
"""Coordinate values (tuple) for model dimensions."""
991+
return {
992+
cname: tuple(cvals) if cvals is not None else None
993+
for cname, cvals in self._coords.items()
994+
}
995+
996+
@property
997+
def coords_typed(self) -> Dict[str, Union[np.ndarray, None]]:
998+
"""Coordinate values (numpy array) for model dimensions."""
999+
return {
1000+
cname: cvals.copy() if cvals is not None else None
1001+
for cname, cvals in self._coords.items()
1002+
}
9911003

9921004
@property
9931005
def dim_lengths(self) -> Dict[str, Variable]:
@@ -1047,9 +1059,8 @@ def add_coord(
10471059
f"Either `values` or `length` must be specified for the '{name}' dimension."
10481060
)
10491061
if values is not None:
1050-
# Conversion to a tuple ensures that the coordinate values are immutable.
1051-
# Also unlike numpy arrays the's tuple.index(...) which is handy to work with.
1052-
values = tuple(values)
1062+
# Conversion to numpy array to ensure coord vals are 1-dim
1063+
values = _as_coord_vals(values)
10531064
if name in self.coords:
10541065
if not np.array_equal(values, self.coords[name]):
10551066
raise ValueError(f"Duplicate and incompatible coordinate: {name}.")
@@ -1107,7 +1118,7 @@ def set_dim(self, name: str, new_length: int, coord_values: Optional[Sequence] =
11071118
actual=len_cvals,
11081119
expected=new_length,
11091120
)
1110-
self._coords[name] = tuple(coord_values)
1121+
self._coords[name] = _as_coord_vals(coord_values)
11111122
self.dim_lengths[name].set_value(new_length)
11121123
return
11131124

@@ -1278,8 +1289,7 @@ def set_data(
12781289
actual=len(new_coords),
12791290
expected=new_length,
12801291
)
1281-
# store it as tuple for immutability as in add_coord
1282-
self._coords[dname] = tuple(new_coords)
1292+
self._coords[dname] = _as_coord_vals(new_coords)
12831293

12841294
shared_object.set_value(values)
12851295

pymc/util.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,3 +510,11 @@ def _add_future_warning_tag(var) -> None:
510510
for k, v in old_tag.__dict__.items():
511511
new_tag.__dict__.setdefault(k, v)
512512
var.tag = new_tag
513+
514+
def _as_coord_vals(values: Sequence) -> np.ndarray:
515+
"""Coerce a sequence coordinate values into a 1-dim array of values"""
516+
arr = np.array(values)
517+
if arr.ndim > 1:
518+
arr = np.empty(len(values), dtype="O")
519+
arr[:] = values
520+
return arr

0 commit comments

Comments
 (0)