Skip to content

issue #5791, dims & coords inference from xarray #6514

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
45 changes: 39 additions & 6 deletions pymc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
from typing import Dict, Optional, Sequence, Tuple, Union, cast

import numpy as np
import pandas as pd
import pytensor
import pytensor.tensor as at
import xarray as xr

from pytensor.compile.sharedvalue import SharedVariable
from pytensor.raise_op import Assert
Expand Down Expand Up @@ -205,17 +207,17 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:

def determine_coords(
model,
value,
value: Union[pd.DataFrame, pd.Series, xr.DataArray],
dims: Optional[Sequence[Optional[str]]] = None,
coords: Optional[Dict[str, Sequence]] = None,
) -> Tuple[Dict[str, Sequence], Sequence[Optional[str]]]:
"""Determines coordinate values from data or the model (via ``dims``)."""
if coords is None:
coords = {}

dim_name = None
# If value is a df or a series, we interpret the index as coords:
if hasattr(value, "index"):
dim_name = None
if dims is not None:
dim_name = dims[0]
if dim_name is None and value.index.name is not None:
Expand All @@ -225,14 +227,19 @@ def determine_coords(

# If value is a df, we also interpret the columns as coords:
if hasattr(value, "columns"):
dim_name = None
if dims is not None:
dim_name = dims[1]
if dim_name is None and value.columns.name is not None:
dim_name = value.columns.name
if dim_name is not None:
coords[dim_name] = value.columns

if isinstance(value, xr.DataArray):
if dims is not None:
for dim in dims:
dim_name = dim
coords[dim_name] = value[dim]

if isinstance(value, np.ndarray) and dims is not None:
if len(dims) != value.ndim:
raise pm.exceptions.ShapeError(
Expand All @@ -259,19 +266,27 @@ def ConstantData(
dims: Optional[Sequence[str]] = None,
coords: Optional[Dict[str, Sequence]] = None,
export_index_as_coords=False,
infer_dims_and_coords=False,
**kwargs,
) -> TensorConstant:
"""Alias for ``pm.Data(..., mutable=False)``.

Registers the ``value`` as a :class:`~pytensor.tensor.TensorConstant` with the model.
For more information, please reference :class:`pymc.Data`.
"""
if export_index_as_coords:
infer_dims_and_coords = export_index_as_coords
warnings.warn(
"Deprecation warning: 'export_index_as_coords; is deprecated and will be removed in future versions. Please use 'infer_dims_and_coords' instead.",
DeprecationWarning,
)

var = Data(
name,
value,
dims=dims,
coords=coords,
export_index_as_coords=export_index_as_coords,
infer_dims_and_coords=infer_dims_and_coords,
mutable=False,
**kwargs,
)
Expand All @@ -285,19 +300,27 @@ def MutableData(
dims: Optional[Sequence[str]] = None,
coords: Optional[Dict[str, Sequence]] = None,
export_index_as_coords=False,
infer_dims_and_coords=False,
**kwargs,
) -> SharedVariable:
"""Alias for ``pm.Data(..., mutable=True)``.

Registers the ``value`` as a :class:`~pytensor.compile.sharedvalue.SharedVariable`
with the model. For more information, please reference :class:`pymc.Data`.
"""
if export_index_as_coords:
infer_dims_and_coords = export_index_as_coords
warnings.warn(
"Deprecation warning: 'export_index_as_coords; is deprecated and will be removed in future versions. Please use 'infer_dims_and_coords' instead.",
DeprecationWarning,
)

var = Data(
name,
value,
dims=dims,
coords=coords,
export_index_as_coords=export_index_as_coords,
infer_dims_and_coords=infer_dims_and_coords,
mutable=True,
**kwargs,
)
Expand All @@ -311,6 +334,7 @@ def Data(
dims: Optional[Sequence[str]] = None,
coords: Optional[Dict[str, Sequence]] = None,
export_index_as_coords=False,
infer_dims_and_coords=False,
mutable: Optional[bool] = None,
**kwargs,
) -> Union[SharedVariable, TensorConstant]:
Expand Down Expand Up @@ -347,7 +371,8 @@ def Data(
names.
coords : dict, optional
Coordinate values to set for new dimensions introduced by this ``Data`` variable.
export_index_as_coords : bool, default=False
export_index_as_coords : deprecated, previous version of "infer_dims_and_coords"
infer_dims_and_coords : bool, default=False
If True, the ``Data`` container will try to infer what the coordinates
and dimension names should be if there is an index in ``value``.
mutable : bool, optional
Expand Down Expand Up @@ -426,7 +451,15 @@ def Data(
)

# Optionally infer coords and dims from the input value.

if export_index_as_coords:
infer_dims_and_coords = export_index_as_coords
warnings.warn(
"Deprecation warning: 'export_index_as_coords; is deprecated and will be removed in future versions. Please use 'infer_dims_and_coords' instead.",
DeprecationWarning,
)

if infer_dims_and_coords:
coords, dims = determine_coords(model, value, dims)

if dims:
Expand Down
12 changes: 12 additions & 0 deletions pymc/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,18 @@ def test_implicit_coords_dataframe(self):
assert "columns" in pmodel.coords
assert pmodel.named_vars_to_dims == {"observations": ("rows", "columns")}

def test_implicit_coords_xarray(self):
xr = pytest.importorskip("xarray")
data = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=("y", "x"))
with pm.Model() as pmodel:
with pytest.warns(DeprecationWarning):
pm.ConstantData("observations", data, dims=("x", "y"), export_index_as_coords=True)
assert "x" in pmodel.coords
assert "y" in pmodel.coords
assert pmodel.named_vars_to_dims == {"observations": ("x", "y")}
assert tuple(pmodel.coords["x"]) == (data.coords["x"],)
assert tuple(pmodel.coords["y"]) == (data.coords["y"],)

def test_data_kwargs(self):
strict_value = True
allow_downcast_value = False
Expand Down