Skip to content

Commit e97ad4e

Browse files
Optionally pass coords when creating pm.Data variables
Needed to create resizable dims with coordinate values, because coords passed to the model will become immutable dims. See #5763. Also, because `pm.Data` variables can be N-dimensional, the coordinate values can't reliably be taken from the value.
1 parent 18f6fe5 commit e97ad4e

File tree

3 files changed

+66
-16
lines changed

3 files changed

+66
-16
lines changed

pymc/data.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import warnings
2121

2222
from copy import copy
23-
from typing import Any, Dict, List, Optional, Sequence, Union, cast
23+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
2424

2525
import aesara
2626
import aesara.tensor as at
@@ -466,9 +466,15 @@ def align_minibatches(batches=None):
466466
rng.seed()
467467

468468

469-
def determine_coords(model, value, dims: Optional[Sequence[str]] = None) -> Dict[str, Sequence]:
469+
def determine_coords(
470+
model,
471+
value,
472+
dims: Optional[Sequence[Optional[str]]] = None,
473+
coords: Optional[Dict[str, Sequence]] = None,
474+
) -> Tuple[Dict[str, Sequence], Sequence[Optional[str]]]:
470475
"""Determines coordinate values from data or the model (via ``dims``)."""
471-
coords = {}
476+
if coords is None:
477+
coords = {}
472478

473479
# If value is a df or a series, we interpret the index as coords:
474480
if hasattr(value, "index"):
@@ -499,17 +505,22 @@ def determine_coords(model, value, dims: Optional[Sequence[str]] = None) -> Dict
499505
)
500506
for size, dim in zip(value.shape, dims):
501507
coord = model.coords.get(dim, None)
502-
if coord is None:
508+
if coord is None and dim is not None:
503509
coords[dim] = range(size)
504510

505-
return coords
511+
if dims is None:
512+
# TODO: Also determine dim names from the index
513+
dims = [None] * np.ndim(value)
514+
515+
return coords, dims
506516

507517

508518
def ConstantData(
509519
name: str,
510520
value,
511521
*,
512522
dims: Optional[Sequence[str]] = None,
523+
coords: Optional[Dict[str, Sequence]] = None,
513524
export_index_as_coords=False,
514525
**kwargs,
515526
) -> TensorConstant:
@@ -522,6 +533,7 @@ def ConstantData(
522533
name,
523534
value,
524535
dims=dims,
536+
coords=coords,
525537
export_index_as_coords=export_index_as_coords,
526538
mutable=False,
527539
**kwargs,
@@ -534,6 +546,7 @@ def MutableData(
534546
value,
535547
*,
536548
dims: Optional[Sequence[str]] = None,
549+
coords: Optional[Dict[str, Sequence]] = None,
537550
export_index_as_coords=False,
538551
**kwargs,
539552
) -> SharedVariable:
@@ -546,6 +559,7 @@ def MutableData(
546559
name,
547560
value,
548561
dims=dims,
562+
coords=coords,
549563
export_index_as_coords=export_index_as_coords,
550564
mutable=True,
551565
**kwargs,
@@ -558,6 +572,7 @@ def Data(
558572
value,
559573
*,
560574
dims: Optional[Sequence[str]] = None,
575+
coords: Optional[Dict[str, Sequence]] = None,
561576
export_index_as_coords=False,
562577
mutable: Optional[bool] = None,
563578
**kwargs,
@@ -588,9 +603,11 @@ def Data(
588603
:ref:`arviz:quickstart`.
589604
If this parameter is not specified, the random variables will not have dimension
590605
names.
606+
coords : dict, optional
607+
Coordinate values to set for new dimensions introduced by this ``Data`` variable.
591608
export_index_as_coords : bool, default=False
592-
If True, the ``Data`` container will try to infer what the coordinates should be
593-
if there is an index in ``value``.
609+
If True, the ``Data`` container will try to infer what the coordinates
610+
and dimension names should be if there is an index in ``value``.
594611
mutable : bool, optional
595612
Switches between creating a :class:`~aesara.compile.sharedvalue.SharedVariable`
596613
(``mutable=True``) vs. creating a :class:`~aesara.tensor.TensorConstant`
@@ -624,6 +641,9 @@ def Data(
624641
... model.set_data('data', data_vals)
625642
... idatas.append(pm.sample())
626643
"""
644+
if coords is None:
645+
coords = {}
646+
627647
if isinstance(value, list):
628648
value = np.array(value)
629649

@@ -665,15 +685,27 @@ def Data(
665685
expected=x.ndim,
666686
)
667687

668-
coords = determine_coords(model, value, dims)
669-
688+
# Optionally infer coords and dims from the input value.
670689
if export_index_as_coords:
671-
model.add_coords(coords)
672-
elif dims:
690+
coords, dims = determine_coords(model, value, dims)
691+
692+
if dims:
693+
if not mutable:
694+
# Use the dimension lengths from the before it was tensorified.
695+
# These can still be tensors, but in many cases they are numeric.
696+
xshape = np.shape(arr)
697+
else:
698+
xshape = x.shape
673699
# Register new dimension lengths
674700
for d, dname in enumerate(dims):
675701
if not dname in model.dim_lengths:
676-
model.add_coord(dname, values=None, length=x.shape[d])
702+
model.add_coord(
703+
name=dname,
704+
# Note: Coordinate values can't be taken from
705+
# the value, because it could be N-dimensional.
706+
values=coords.get(dname, None),
707+
length=xshape[d],
708+
)
677709

678710
model.add_random_variable(x, dims=dims)
679711

pymc/model.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,7 +1072,7 @@ def add_coord(
10721072
name: str,
10731073
values: Optional[Sequence] = None,
10741074
*,
1075-
length: Optional[Variable] = None,
1075+
length: Optional[Union[int, Variable]] = None,
10761076
):
10771077
"""Registers a dimension coordinate with the model.
10781078
@@ -1085,7 +1085,7 @@ def add_coord(
10851085
Coordinate values or ``None`` (for auto-numbering).
10861086
If ``None`` is passed, a ``length`` must be specified.
10871087
length : optional, scalar
1088-
A symbolic scalar of the dimensions length.
1088+
A scalar of the dimensions length.
10891089
Defaults to ``aesara.shared(len(values))``.
10901090
"""
10911091
if name in {"draw", "chain", "__sample__"}:
@@ -1097,7 +1097,9 @@ def add_coord(
10971097
raise ValueError(
10981098
f"Either `values` or `length` must be specified for the '{name}' dimension."
10991099
)
1100-
if length is not None and not isinstance(length, Variable):
1100+
if isinstance(length, int):
1101+
length = at.constant(length)
1102+
elif length is not None and not isinstance(length, Variable):
11011103
raise ValueError(
11021104
f"The `length` passed for the '{name}' coord must be an Aesara Variable or None."
11031105
)
@@ -1116,7 +1118,7 @@ def add_coords(
11161118
self,
11171119
coords: Dict[str, Optional[Sequence]],
11181120
*,
1119-
lengths: Optional[Dict[str, Union[Variable, None]]] = None,
1121+
lengths: Optional[Dict[str, Optional[Union[int, Variable]]]] = None,
11201122
):
11211123
"""Vectorized version of ``Model.add_coord``."""
11221124
if coords is None:

pymc/tests/test_data_container.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,22 @@ def test_explicit_coords(self):
332332
assert isinstance(pmodel.dim_lengths["columns"], ScalarSharedVariable)
333333
assert pmodel.dim_lengths["columns"].eval() == 7
334334

335+
def test_set_coords_through_pmdata(self):
336+
with pm.Model() as pmodel:
337+
pm.ConstantData(
338+
"population", [100, 200], dims="city", coords={"city": ["Tinyvil", "Minitown"]}
339+
)
340+
pm.MutableData(
341+
"temperature",
342+
[[15, 20, 22, 17], [18, 22, 21, 12]],
343+
dims=("city", "season"),
344+
coords={"season": ["winter", "spring", "summer", "fall"]},
345+
)
346+
assert "city" in pmodel.coords
347+
assert "season" in pmodel.coords
348+
assert pmodel.coords["city"] == ("Tinyvil", "Minitown")
349+
assert pmodel.coords["season"] == ("winter", "spring", "summer", "fall")
350+
335351
def test_symbolic_coords(self):
336352
"""
337353
In v4 dimensions can be created without passing coordinate values.

0 commit comments

Comments
 (0)