diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index da62607513..6d58c0adca 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -88,6 +88,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01 - New features for BART: - Added linear response, increased number of trees fitted per step [5044](https://github.com/pymc-devs/pymc3/pull/5044). - Added partial dependence plots and individual conditional expectation plots [5091](https://github.com/pymc-devs/pymc3/pull/5091). +- `pm.Data` now passes additional kwargs to `aesara.shared`. [#5098](https://github.com/pymc-devs/pymc/pull/5098) - ... diff --git a/pymc/data.py b/pymc/data.py index 649d59828c..804595831b 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -464,8 +464,8 @@ def align_minibatches(batches=None): class Data: - """Data container class that wraps the Aesara ``SharedVariable`` class - and lets the model be aware of its inputs and outputs. + """Data container class that wraps :func:`aesara.shared` and lets + the model be aware of its inputs and outputs. Parameters ---------- @@ -478,10 +478,12 @@ class Data: random variables). Use this when `value` is a pandas Series or DataFrame. The `dims` will then be the name of the Series / DataFrame's columns. See ArviZ documentation for more information about dimensions and coordinates: - https://arviz-devs.github.io/arviz/notebooks/Introduction.html + :ref:`arviz:quickstart`. export_index_as_coords: bool, optional, default=False If True, the `Data` container will try to infer what the coordinates should be if there is an index in `value`. + **kwargs: dict, optional + Extra arguments passed to :func:`aesara.shared`. Examples -------- @@ -512,7 +514,15 @@ class Data: https://docs.pymc.io/notebooks/data_container.html """ - def __new__(self, name, value, *, dims=None, export_index_as_coords=False): + def __new__( + self, + name, + value, + *, + dims=None, + export_index_as_coords=False, + **kwargs, + ): if isinstance(value, list): value = np.array(value) @@ -528,7 +538,7 @@ def __new__(self, name, value, *, dims=None, export_index_as_coords=False): # `pandas_to_array` takes care of parameter `value` and # transforms it to something digestible for pymc - shared_object = aesara.shared(pandas_to_array(value), name) + shared_object = aesara.shared(pandas_to_array(value), name, **kwargs) if isinstance(dims, str): dims = (dims,) diff --git a/pymc/tests/test_data_container.py b/pymc/tests/test_data_container.py index cfbdaa32a9..1abc6c0329 100644 --- a/pymc/tests/test_data_container.py +++ b/pymc/tests/test_data_container.py @@ -366,6 +366,19 @@ def test_implicit_coords_dataframe(self): assert "columns" in pmodel.coords assert pmodel.RV_dims == {"observations": ("rows", "columns")} + def test_data_kwargs(self): + strict_value = True + allow_downcast_value = False + with pm.Model(): + data = pm.Data( + "data", + value=[[1.0], [2.0], [3.0]], + strict=strict_value, + allow_downcast=allow_downcast_value, + ) + assert data.container.strict is strict_value + assert data.container.allow_downcast is allow_downcast_value + def test_data_naming(): """