Skip to content

Commit 562e51e

Browse files
michaelraczyckiMichal Raczyckimichaelosthege
authored
issue #5791, dims & coords inference from xarray (#6514)
* added dim inference from xarray, deprecation warning and unittest for the new feature * fixed typo in warning * fixed accidental quotation around dim * fixed failing assertions * found and fixed cause of the failing test * changed the coords assertion according to suggested form * fixing mypy type mismatch * working on getting the test to work * removed typecasting to string on dim_name, was causing the mypy to fail * took care locally of mypy errors * Typo/formatting fixes --------- Co-authored-by: Michal Raczycki <[email protected]> Co-authored-by: Michael Osthege <[email protected]>
1 parent e45e6c2 commit 562e51e

File tree

2 files changed

+57
-11
lines changed

2 files changed

+57
-11
lines changed

pymc/data.py

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
from typing import Dict, Optional, Sequence, Tuple, Union, cast
2323

2424
import numpy as np
25+
import pandas as pd
2526
import pytensor
2627
import pytensor.tensor as at
28+
import xarray as xr
2729

2830
from pytensor.compile.sharedvalue import SharedVariable
2931
from pytensor.raise_op import Assert
@@ -205,17 +207,17 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
205207

206208
def determine_coords(
207209
model,
208-
value,
210+
value: Union[pd.DataFrame, pd.Series, xr.DataArray],
209211
dims: Optional[Sequence[Optional[str]]] = None,
210-
coords: Optional[Dict[str, Sequence]] = None,
211-
) -> Tuple[Dict[str, Sequence], Sequence[Optional[str]]]:
212+
coords: Optional[Dict[str, Union[Sequence, np.ndarray]]] = None,
213+
) -> Tuple[Dict[str, Union[Sequence, np.ndarray]], Sequence[Optional[str]]]:
212214
"""Determines coordinate values from data or the model (via ``dims``)."""
213215
if coords is None:
214216
coords = {}
215217

218+
dim_name = None
216219
# If value is a df or a series, we interpret the index as coords:
217220
if hasattr(value, "index"):
218-
dim_name = None
219221
if dims is not None:
220222
dim_name = dims[0]
221223
if dim_name is None and value.index.name is not None:
@@ -225,14 +227,20 @@ def determine_coords(
225227

226228
# If value is a df, we also interpret the columns as coords:
227229
if hasattr(value, "columns"):
228-
dim_name = None
229230
if dims is not None:
230231
dim_name = dims[1]
231232
if dim_name is None and value.columns.name is not None:
232233
dim_name = value.columns.name
233234
if dim_name is not None:
234235
coords[dim_name] = value.columns
235236

237+
if isinstance(value, xr.DataArray):
238+
if dims is not None:
239+
for dim in dims:
240+
dim_name = dim
241+
# str is applied because dim entries may be None
242+
coords[str(dim_name)] = value[dim].to_numpy()
243+
236244
if isinstance(value, np.ndarray) and dims is not None:
237245
if len(dims) != value.ndim:
238246
raise pm.exceptions.ShapeError(
@@ -257,21 +265,29 @@ def ConstantData(
257265
value,
258266
*,
259267
dims: Optional[Sequence[str]] = None,
260-
coords: Optional[Dict[str, Sequence]] = None,
268+
coords: Optional[Dict[str, Union[Sequence, np.ndarray]]] = None,
261269
export_index_as_coords=False,
270+
infer_dims_and_coords=False,
262271
**kwargs,
263272
) -> TensorConstant:
264273
"""Alias for ``pm.Data(..., mutable=False)``.
265274
266275
Registers the ``value`` as a :class:`~pytensor.tensor.TensorConstant` with the model.
267276
For more information, please reference :class:`pymc.Data`.
268277
"""
278+
if export_index_as_coords:
279+
infer_dims_and_coords = export_index_as_coords
280+
warnings.warn(
281+
"Deprecation warning: 'export_index_as_coords; is deprecated and will be removed in future versions. Please use 'infer_dims_and_coords' instead.",
282+
DeprecationWarning,
283+
)
284+
269285
var = Data(
270286
name,
271287
value,
272288
dims=dims,
273289
coords=coords,
274-
export_index_as_coords=export_index_as_coords,
290+
infer_dims_and_coords=infer_dims_and_coords,
275291
mutable=False,
276292
**kwargs,
277293
)
@@ -283,21 +299,29 @@ def MutableData(
283299
value,
284300
*,
285301
dims: Optional[Sequence[str]] = None,
286-
coords: Optional[Dict[str, Sequence]] = None,
302+
coords: Optional[Dict[str, Union[Sequence, np.ndarray]]] = None,
287303
export_index_as_coords=False,
304+
infer_dims_and_coords=False,
288305
**kwargs,
289306
) -> SharedVariable:
290307
"""Alias for ``pm.Data(..., mutable=True)``.
291308
292309
Registers the ``value`` as a :class:`~pytensor.compile.sharedvalue.SharedVariable`
293310
with the model. For more information, please reference :class:`pymc.Data`.
294311
"""
312+
if export_index_as_coords:
313+
infer_dims_and_coords = export_index_as_coords
314+
warnings.warn(
315+
"Deprecation warning: 'export_index_as_coords; is deprecated and will be removed in future versions. Please use 'infer_dims_and_coords' instead.",
316+
DeprecationWarning,
317+
)
318+
295319
var = Data(
296320
name,
297321
value,
298322
dims=dims,
299323
coords=coords,
300-
export_index_as_coords=export_index_as_coords,
324+
infer_dims_and_coords=infer_dims_and_coords,
301325
mutable=True,
302326
**kwargs,
303327
)
@@ -309,8 +333,9 @@ def Data(
309333
value,
310334
*,
311335
dims: Optional[Sequence[str]] = None,
312-
coords: Optional[Dict[str, Sequence]] = None,
336+
coords: Optional[Dict[str, Union[Sequence, np.ndarray]]] = None,
313337
export_index_as_coords=False,
338+
infer_dims_and_coords=False,
314339
mutable: Optional[bool] = None,
315340
**kwargs,
316341
) -> Union[SharedVariable, TensorConstant]:
@@ -347,7 +372,9 @@ def Data(
347372
names.
348373
coords : dict, optional
349374
Coordinate values to set for new dimensions introduced by this ``Data`` variable.
350-
export_index_as_coords : bool, default=False
375+
export_index_as_coords : bool
376+
Deprecated, previous version of "infer_dims_and_coords"
377+
infer_dims_and_coords : bool, default=False
351378
If True, the ``Data`` container will try to infer what the coordinates
352379
and dimension names should be if there is an index in ``value``.
353380
mutable : bool, optional
@@ -427,6 +454,13 @@ def Data(
427454

428455
# Optionally infer coords and dims from the input value.
429456
if export_index_as_coords:
457+
infer_dims_and_coords = export_index_as_coords
458+
warnings.warn(
459+
"Deprecation warning: 'export_index_as_coords; is deprecated and will be removed in future versions. Please use 'infer_dims_and_coords' instead.",
460+
DeprecationWarning,
461+
)
462+
463+
if infer_dims_and_coords:
430464
coords, dims = determine_coords(model, value, dims)
431465

432466
if dims:

pymc/tests/test_data.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,18 @@ def test_implicit_coords_dataframe(self):
405405
assert "columns" in pmodel.coords
406406
assert pmodel.named_vars_to_dims == {"observations": ("rows", "columns")}
407407

408+
def test_implicit_coords_xarray(self):
409+
xr = pytest.importorskip("xarray")
410+
data = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=("y", "x"))
411+
with pm.Model() as pmodel:
412+
with pytest.warns(DeprecationWarning):
413+
pm.ConstantData("observations", data, dims=("x", "y"), export_index_as_coords=True)
414+
assert "x" in pmodel.coords
415+
assert "y" in pmodel.coords
416+
assert pmodel.named_vars_to_dims == {"observations": ("x", "y")}
417+
assert tuple(pmodel.coords["x"]) == tuple(data.coords["x"].to_numpy())
418+
assert tuple(pmodel.coords["y"]) == tuple(data.coords["y"].to_numpy())
419+
408420
def test_data_kwargs(self):
409421
strict_value = True
410422
allow_downcast_value = False

0 commit comments

Comments
 (0)