diff --git a/pymc/data.py b/pymc/data.py
index b7d0dcac11..19216f4550 100644
--- a/pymc/data.py
+++ b/pymc/data.py
@@ -22,8 +22,10 @@
 from typing import Dict, Optional, Sequence, Tuple, Union, cast
 
 import numpy as np
+import pandas as pd
 import pytensor
 import pytensor.tensor as at
+import xarray as xr
 
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.raise_op import Assert
@@ -205,17 +207,17 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
 
 def determine_coords(
     model,
-    value,
+    value: Union[pd.DataFrame, pd.Series, xr.DataArray, np.ndarray],
     dims: Optional[Sequence[Optional[str]]] = None,
-    coords: Optional[Dict[str, Sequence]] = None,
-) -> Tuple[Dict[str, Sequence], Sequence[Optional[str]]]:
+    coords: Optional[Dict[str, Union[Sequence, np.ndarray]]] = None,
+) -> Tuple[Dict[str, Union[Sequence, np.ndarray]], Sequence[Optional[str]]]:
     """Determines coordinate values from data or the model (via ``dims``)."""
     if coords is None:
         coords = {}
 
+    dim_name = None
     # If value is a df or a series, we interpret the index as coords:
     if hasattr(value, "index"):
-        dim_name = None
         if dims is not None:
             dim_name = dims[0]
         if dim_name is None and value.index.name is not None:
@@ -233,6 +235,13 @@ def determine_coords(
         if dim_name is not None:
             coords[dim_name] = value.columns
 
+    if isinstance(value, xr.DataArray):
+        if dims is not None:
+            for dim_name in dims:
+                # Unnamed (None) dims have no coordinate to look up on the DataArray.
+                if dim_name is not None:
+                    coords[dim_name] = value[dim_name].to_numpy()
+
     if isinstance(value, np.ndarray) and dims is not None:
         if len(dims) != value.ndim:
             raise pm.exceptions.ShapeError(
@@ -257,8 +266,9 @@ def ConstantData(
     value,
     *,
     dims: Optional[Sequence[str]] = None,
-    coords: Optional[Dict[str, Sequence]] = None,
+    coords: Optional[Dict[str, Union[Sequence, np.ndarray]]] = None,
     export_index_as_coords=False,
+    infer_dims_and_coords=False,
     **kwargs,
 ) -> TensorConstant:
     """Alias for ``pm.Data(..., mutable=False)``.
@@ -266,12 +276,20 @@ def ConstantData(
     Registers the ``value`` as a :class:`~pytensor.tensor.TensorConstant` with the model.
     For more information, please reference :class:`pymc.Data`.
     """
+    if export_index_as_coords:
+        infer_dims_and_coords = export_index_as_coords
+        warnings.warn(
+            "Deprecation warning: 'export_index_as_coords' is deprecated and will be removed in future versions. Please use 'infer_dims_and_coords' instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
     var = Data(
         name,
         value,
         dims=dims,
         coords=coords,
-        export_index_as_coords=export_index_as_coords,
+        infer_dims_and_coords=infer_dims_and_coords,
         mutable=False,
         **kwargs,
     )
@@ -283,8 +301,9 @@ def MutableData(
     value,
     *,
     dims: Optional[Sequence[str]] = None,
-    coords: Optional[Dict[str, Sequence]] = None,
+    coords: Optional[Dict[str, Union[Sequence, np.ndarray]]] = None,
     export_index_as_coords=False,
+    infer_dims_and_coords=False,
     **kwargs,
 ) -> SharedVariable:
     """Alias for ``pm.Data(..., mutable=True)``.
@@ -292,12 +311,20 @@ def MutableData(
     Registers the ``value`` as a :class:`~pytensor.compile.sharedvalue.SharedVariable` with the model.
     For more information, please reference :class:`pymc.Data`.
     """
+    if export_index_as_coords:
+        infer_dims_and_coords = export_index_as_coords
+        warnings.warn(
+            "Deprecation warning: 'export_index_as_coords' is deprecated and will be removed in future versions. Please use 'infer_dims_and_coords' instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
     var = Data(
         name,
         value,
         dims=dims,
         coords=coords,
-        export_index_as_coords=export_index_as_coords,
+        infer_dims_and_coords=infer_dims_and_coords,
         mutable=True,
         **kwargs,
     )
@@ -309,8 +336,9 @@ def Data(
     value,
     *,
     dims: Optional[Sequence[str]] = None,
-    coords: Optional[Dict[str, Sequence]] = None,
+    coords: Optional[Dict[str, Union[Sequence, np.ndarray]]] = None,
     export_index_as_coords=False,
+    infer_dims_and_coords=False,
     mutable: Optional[bool] = None,
     **kwargs,
 ) -> Union[SharedVariable, TensorConstant]:
@@ -347,7 +375,9 @@ def Data(
         names.
     coords : dict, optional
         Coordinate values to set for new dimensions introduced by this ``Data`` variable.
-    export_index_as_coords : bool, default=False
+    export_index_as_coords : bool, default=False
+        Deprecated alias of ``infer_dims_and_coords``; use that argument instead.
+    infer_dims_and_coords : bool, default=False
         If True, the ``Data`` container will try to infer what the coordinates
         and dimension names should be if there is an index in ``value``.
     mutable : bool, optional
@@ -427,6 +457,14 @@ def Data(
 
     # Optionally infer coords and dims from the input value.
     if export_index_as_coords:
+        infer_dims_and_coords = export_index_as_coords
+        warnings.warn(
+            "Deprecation warning: 'export_index_as_coords' is deprecated and will be removed in future versions. Please use 'infer_dims_and_coords' instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+    if infer_dims_and_coords:
         coords, dims = determine_coords(model, value, dims)
 
     if dims:
diff --git a/pymc/tests/test_data.py b/pymc/tests/test_data.py
index 1a13b2176a..88db4bb2b9 100644
--- a/pymc/tests/test_data.py
+++ b/pymc/tests/test_data.py
@@ -405,6 +405,18 @@ def test_implicit_coords_dataframe(self):
         assert "columns" in pmodel.coords
         assert pmodel.named_vars_to_dims == {"observations": ("rows", "columns")}
 
+    def test_implicit_coords_xarray(self):
+        xr = pytest.importorskip("xarray")
+        data = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=("y", "x"))
+        with pm.Model() as pmodel:
+            with pytest.warns(DeprecationWarning):
+                pm.ConstantData("observations", data, dims=("x", "y"), export_index_as_coords=True)
+        assert "x" in pmodel.coords
+        assert "y" in pmodel.coords
+        assert pmodel.named_vars_to_dims == {"observations": ("x", "y")}
+        assert tuple(pmodel.coords["x"]) == tuple(data.coords["x"].to_numpy())
+        assert tuple(pmodel.coords["y"]) == tuple(data.coords["y"].to_numpy())
+
     def test_data_kwargs(self):
         strict_value = True
         allow_downcast_value = False