Skip to content

Commit 7677f0c

Browse files
committed
Change use of Model.coords -> Model.coords_typed throughout pymc
1 parent 47c62d9 commit 7677f0c

File tree

10 files changed

+62
-52
lines changed

10 files changed

+62
-52
lines changed

pymc/backends/arviz.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@
3939

4040
from pymc.model import Model, modelcontext
4141
from pymc.pytensorf import extract_obs_data
42-
from pymc.util import get_default_varnames
42+
from pymc.util import _as_coord_vals, get_default_varnames
4343

4444
if TYPE_CHECKING:
4545
from pymc.backends.base import MultiTrace # pylint: disable=invalid-name
@@ -216,15 +216,15 @@ def __init__(
216216
" one of trace, prior, posterior_predictive or predictions."
217217
)
218218

219-
# Make coord types more rigid
220-
untyped_coords: Dict[str, Optional[Sequence[Any]]] = {**self.model.coords}
219+
coords_typed = self.model.coords_typed
221220
if coords:
222-
untyped_coords.update(coords)
223-
self.coords = {
224-
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
225-
for cname, cvals in untyped_coords.items()
226-
if cvals is not None
221+
coords_typed.update(_as_coord_vals(coords))
222+
coords_typed = {
223+
cname: cvals_typed
224+
for cname, cvals_typed in coords_typed.items()
225+
if cvals_typed is not None
227226
}
227+
self.coords = coords_typed
228228

229229
self.dims = {} if dims is None else dims
230230
model_dims = {k: list(v) for k, v in self.model.named_vars_to_dims.items()}

pymc/backends/mcbackend.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -230,9 +230,9 @@ def make_runmeta_and_point_fn(
230230
sample_stats.append(svar)
231231

232232
coordinates = [
233-
mcb.Coordinate(dname, mcb.npproto.utils.ndarray_from_numpy(np.array(cvals)))
234-
for dname, cvals in model.coords.items()
235-
if cvals is not None
233+
mcb.Coordinate(dname, mcb.npproto.utils.ndarray_from_numpy(cvals_typed))
234+
for dname, cvals_typed in model.coords_typed.items()
235+
if cvals_typed is not None
236236
]
237237
meta = mcb.RunMeta(
238238
rid=hagelkorn.random(),

pymc/distributions/bound.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -195,8 +195,8 @@ def __new__(
195195

196196
if dims is not None:
197197
model = modelcontext(None)
198-
if dims in model.coords:
199-
dim_obj = np.asarray(model.coords[dims])
198+
if dims in model.coords_typed:
199+
dim_obj = model.coords_typed[dims]
200200
size = dim_obj.shape
201201
else:
202202
raise ValueError("Given dims do not exist in model coordinates.")

pymc/model.py

Lines changed: 26 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -992,7 +992,7 @@ def coords(self) -> Dict[str, Union[Tuple, None]]:
992992
cname: tuple(cvals) if cvals is not None else None
993993
for cname, cvals in self._coords.items()
994994
}
995-
995+
996996
@property
997997
def coords_typed(self) -> Dict[str, Union[np.ndarray, None]]:
998998
"""Coordinate values (numpy array) for model dimensions."""
@@ -1014,20 +1014,20 @@ def shape_from_dims(self, dims):
10141014
if len(set(dims)) != len(dims):
10151015
raise ValueError("Can not contain the same dimension name twice.")
10161016
for dim in dims:
1017-
if dim not in self.coords:
1017+
if dim not in self.coords_typed:
10181018
raise ValueError(
10191019
f"Unknown dimension name '{dim}'. All dimension "
10201020
"names must be specified in the `coords` "
10211021
"argument of the model or through a pm.Data "
10221022
"variable."
10231023
)
1024-
shape.extend(np.shape(self.coords[dim]))
1024+
shape.extend(self.coords_typed[dim].shape)
10251025
return tuple(shape)
10261026

10271027
def add_coord(
10281028
self,
10291029
name: str,
1030-
values: Optional[Sequence] = None,
1030+
values: Optional[Union[Sequence, np.ndarray]] = None,
10311031
mutable: bool = False,
10321032
*,
10331033
length: Optional[Union[int, Variable]] = None,
@@ -1061,8 +1061,8 @@ def add_coord(
10611061
if values is not None:
10621062
# Conversion to numpy array to ensure coord vals are 1-dim
10631063
values = _as_coord_vals(values)
1064-
if name in self.coords:
1065-
if not np.array_equal(values, self.coords[name]):
1064+
if name in self.coords_typed:
1065+
if not np.array_equal(_as_coord_vals(values), self.coords_typed[name]):
10661066
raise ValueError(f"Duplicate and incompatible coordinate: {name}.")
10671067
if length is not None and not isinstance(length, (int, Variable)):
10681068
raise ValueError(
@@ -1080,7 +1080,7 @@ def add_coord(
10801080

10811081
def add_coords(
10821082
self,
1083-
coords: Dict[str, Optional[Sequence]],
1083+
coords: Dict[str, Optional[Union[Sequence, np.ndarray]]],
10841084
*,
10851085
lengths: Optional[Dict[str, Optional[Union[int, Variable]]]] = None,
10861086
):
@@ -1092,7 +1092,9 @@ def add_coords(
10921092
for name, values in coords.items():
10931093
self.add_coord(name, values, length=lengths.get(name, None))
10941094

1095-
def set_dim(self, name: str, new_length: int, coord_values: Optional[Sequence] = None):
1095+
def set_dim(
1096+
self, name: str, new_length: int, coord_values: Optional[Union[Sequence, np.ndarray]] = None
1097+
):
10961098
"""Update a mutable dimension.
10971099
10981100
Parameters
@@ -1106,7 +1108,7 @@ def set_dim(self, name: str, new_length: int, coord_values: Optional[Sequence] =
11061108
"""
11071109
if not isinstance(self.dim_lengths[name], ScalarSharedVariable):
11081110
raise ValueError(f"The dimension '{name}' is immutable.")
1109-
if coord_values is None and self.coords.get(name, None) is not None:
1111+
if coord_values is None and self.coords_typed.get(name, None) is not None:
11101112
raise ValueError(
11111113
f"'{name}' has coord values. Pass `set_dim(..., coord_values=...)` to update them."
11121114
)
@@ -1162,7 +1164,7 @@ def set_data(
11621164
self,
11631165
name: str,
11641166
values: Dict[str, Optional[Sequence]],
1165-
coords: Optional[Dict[str, Sequence]] = None,
1167+
coords: Optional[Dict[str, Union[Sequence, np.ndarray]]] = None,
11661168
):
11671169
"""Changes the values of a data variable in the model.
11681170
@@ -1192,7 +1194,7 @@ def set_data(
11921194
values = np.array(values)
11931195
values = convert_observed_data(values)
11941196
dims = self.named_vars_to_dims.get(name, None) or ()
1195-
coords = coords or {}
1197+
coords_untyped = coords or {}
11961198

11971199
if values.ndim != shared_object.ndim:
11981200
raise ValueError(
@@ -1203,19 +1205,21 @@ def set_data(
12031205
length_tensor = self.dim_lengths[dname]
12041206
old_length = length_tensor.eval()
12051207
new_length = values.shape[d]
1206-
original_coords = self.coords.get(dname, None)
1207-
new_coords = coords.get(dname, None)
1208+
original_coord_vals = self.coords_typed.get(dname, None)
1209+
new_coord_vals = coords_untyped.get(dname, None)
1210+
if new_coord_vals is not None:
1211+
new_coord_vals = _as_coord_vals(new_coord_vals)
12081212

12091213
length_changed = new_length != old_length
12101214

12111215
# Reject resizing if we already know that it would create shape problems.
12121216
# NOTE: If there are multiple pm.MutableData containers sharing this dim, but the user only
12131217
# changes the values for one of them, they will run into shape problems nonetheless.
12141218
if length_changed:
1215-
if original_coords is not None:
1216-
if new_coords is None:
1219+
if original_coord_vals is not None:
1220+
if new_coord_vals is None:
12171221
raise ValueError(
1218-
f"The '{name}' variable already had {len(original_coords)} coord values defined for "
1222+
f"The '{name}' variable already had {len(original_coord_vals)} coord values defined for "
12191223
f"its {dname} dimension. With the new values this dimension changes to length "
12201224
f"{new_length}, so new coord values for the {dname} dimension are required."
12211225
)
@@ -1279,17 +1283,17 @@ def set_data(
12791283
if isinstance(length_tensor, ScalarSharedVariable):
12801284
# The dimension is mutable, but was defined without being linked
12811285
# to a shared variable. This is allowed, but a little less robust.
1282-
self.set_dim(dname, new_length, coord_values=new_coords)
1286+
self.set_dim(dname, new_length, coord_values=new_coord_vals)
12831287

1284-
if new_coords is not None:
1288+
if new_coord_vals is not None:
12851289
# Update the registered coord values (also if they were None)
1286-
if len(new_coords) != new_length:
1290+
if len(new_coord_vals) != new_length:
12871291
raise ShapeError(
12881292
f"Length of new coordinate values for dimension '{dname}' does not match the provided values.",
1289-
actual=len(new_coords),
1293+
actual=len(new_coord_vals),
12901294
expected=new_length,
12911295
)
1292-
self._coords[dname] = _as_coord_vals(new_coords)
1296+
self._coords[dname] = _as_coord_vals(new_coord_vals)
12931297

12941298
shared_object.set_value(values)
12951299

@@ -1560,7 +1564,7 @@ def add_named_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]] =
15601564
if isinstance(dims, str):
15611565
dims = (dims,)
15621566
for dim in dims:
1563-
if dim not in self.coords and dim is not None:
1567+
if dim not in self.coords_typed and dim is not None:
15641568
raise ValueError(f"Dimension {dim} is not specified in `coords`.")
15651569
if any(var.name == dim for dim in dims if dim is not None):
15661570
raise ValueError(f"Variable `{var.name}` has the same name as its dimension label.")

pymc/sampling/forward.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -575,7 +575,7 @@ def sample_posterior_predictive(
575575

576576
constant_coords = set()
577577
for dim, coord in trace_coords.items():
578-
current_coord = model.coords.get(dim, None)
578+
current_coord = model.coords_typed.get(dim, None)
579579
if (
580580
current_coord is not None
581581
and len(coord) == len(current_coord)

pymc/sampling/jax.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -384,9 +384,9 @@ def sample_blackjax_nuts(
384384
vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))
385385

386386
coords = {
387-
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
388-
for cname, cvals in model.coords.items()
389-
if cvals is not None
387+
cname: cvals_typed
388+
for cname, cvals_typed in model.coords_typed.items()
389+
if cvals_typed is not None
390390
}
391391

392392
dims = {
@@ -605,9 +605,9 @@ def sample_numpyro_nuts(
605605
vars_to_sample = list(get_default_varnames(var_names, include_transformed=keep_untransformed))
606606

607607
coords = {
608-
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
609-
for cname, cvals in model.coords.items()
610-
if cvals is not None
608+
cname: cvals_typed
609+
for cname, cvals_typed in model.coords_typed.items()
610+
if cvals_typed is not None
611611
}
612612

613613
dims = {

pymc/stats/log_likelihood.py

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,6 @@
1313
# limitations under the License.
1414
from typing import Optional, Sequence, cast
1515

16-
import numpy as np
17-
1816
from arviz import InferenceData, dict_to_dataset
1917
from fastprogress import progress_bar
2018

@@ -117,10 +115,7 @@ def compute_log_likelihood(
117115
loglike_trace,
118116
library=pymc,
119117
dims={dname: list(dvals) for dname, dvals in model.named_vars_to_dims.items()},
120-
coords={
121-
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
122-
for cname, cvals in model.coords.items()
123-
},
118+
coords={cname: cvals_typed for cname, cvals_typed in model.coords_typed.items()},
124119
default_dims=list(sample_dims),
125120
skip_event_dims=True,
126121
)

pymc/util.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,8 @@
2727
from pytensor.compile import SharedVariable
2828
from pytensor.graph.utils import ValidatingScratchpad
2929

30+
from pymc.exceptions import ShapeError
31+
3032

3133
class _UnsetType:
3234
"""Type for the `UNSET` object to make it look nice in `help(...)` outputs."""
@@ -511,8 +513,17 @@ def _add_future_warning_tag(var) -> None:
511513
new_tag.__dict__.setdefault(k, v)
512514
var.tag = new_tag
513515

514-
def _as_coord_vals(values: Sequence) -> np.ndarray:
515-
"""Coerce a sequence coordinate values into a 1-dim array of values"""
516+
517+
def _as_coord_vals(values: Union[Sequence, np.ndarray]) -> np.ndarray:
518+
"""Coerce a sequence of coordinate values into a 1-dim array"""
519+
if isinstance(values, np.ndarray):
520+
if values.ndim != 1:
521+
raise ShapeError(
522+
"Coordinate values passed as a numpy array must be 1-dimensional",
523+
actual=values.ndim,
524+
expected=1,
525+
)
526+
return values
516527
arr = np.array(values)
517528
if arr.ndim > 1:
518529
arr = np.empty(len(values), dtype="O")

pymc/variational/opvi.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1135,7 +1135,7 @@ def var_to_data(self, shared: at.TensorVariable) -> xarray.Dataset:
11351135
for name, s, shape, dtype in self.ordering.values():
11361136
dims = self.model.named_vars_to_dims.get(name, None)
11371137
if dims is not None:
1138-
coords = {d: np.array(self.model.coords[d]) for d in dims}
1138+
coords = {d: self.model.coords_typed[d] for d in dims}
11391139
else:
11401140
coords = None
11411141
values = shared_nda[s].reshape(shape).astype(dtype)

tests/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -733,7 +733,7 @@ def test_nested_model_coords():
733733
c = pm.HalfNormal("c", dims="dim3")
734734
d = pm.Normal("d", b, c, dims="dim2")
735735
e = pm.Normal("e", a[None] + d[:, None], dims=("dim2", "dim1"))
736-
assert m1.coords is m2.coords
736+
assert m1.coords == m2.coords
737737
assert m1.dim_lengths is m2.dim_lengths
738738
assert set(m2.named_vars_to_dims) < set(m1.named_vars_to_dims)
739739

0 commit comments

Comments
 (0)