From 1c9e54575416858279187f6037df424eac571132 Mon Sep 17 00:00:00 2001 From: "Paulo S. Costa" Date: Sun, 24 Oct 2021 08:49:43 -0700 Subject: [PATCH 1/4] Tests kwargs are passed from Data to aesara.shared --- pymc/tests/test_data_container.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/pymc/tests/test_data_container.py b/pymc/tests/test_data_container.py index cfbdaa32a9..1abc6c0329 100644 --- a/pymc/tests/test_data_container.py +++ b/pymc/tests/test_data_container.py @@ -366,6 +366,19 @@ def test_implicit_coords_dataframe(self): assert "columns" in pmodel.coords assert pmodel.RV_dims == {"observations": ("rows", "columns")} + def test_data_kwargs(self): + strict_value = True + allow_downcast_value = False + with pm.Model(): + data = pm.Data( + "data", + value=[[1.0], [2.0], [3.0]], + strict=strict_value, + allow_downcast=allow_downcast_value, + ) + assert data.container.strict is strict_value + assert data.container.allow_downcast is allow_downcast_value + def test_data_naming(): """ From 55827e411225023ea52171dccb8fa2e00e7572dd Mon Sep 17 00:00:00 2001 From: "Paulo S. Costa" Date: Sun, 24 Oct 2021 08:51:07 -0700 Subject: [PATCH 2/4] Pass kwargs from Data to aesara.shared --- pymc/data.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/pymc/data.py b/pymc/data.py index 649d59828c..1b49aca84e 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -482,6 +482,8 @@ class Data: export_index_as_coords: bool, optional, default=False If True, the `Data` container will try to infer what the coordinates should be if there is an index in `value`. + **kwargs: dict, optional + Extra arguments passed to :func:`aesara.shared`. Examples -------- @@ -512,7 +514,15 @@ class Data: https://docs.pymc.io/notebooks/data_container.html """ - def __new__(self, name, value, *, dims=None, export_index_as_coords=False): + def __new__( + self, + name, + value, + *, + dims=None, + export_index_as_coords=False, + **kwargs, + ): if isinstance(value, list): value = np.array(value) @@ -528,7 +538,7 @@ def __new__(self, name, value, *, dims=None, export_index_as_coords=False): # `pandas_to_array` takes care of parameter `value` and # transforms it to something digestible for pymc - shared_object = aesara.shared(pandas_to_array(value), name) + shared_object = aesara.shared(pandas_to_array(value), name, **kwargs) if isinstance(dims, str): dims = (dims,) From b552b4a57bce794473a2181fc681c6332b77ea4d Mon Sep 17 00:00:00 2001 From: "Paulo S. Costa" Date: Sun, 24 Oct 2021 08:57:27 -0700 Subject: [PATCH 3/4] Document Data kwargs in release notes --- RELEASE-NOTES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index da62607513..6d58c0adca 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -88,6 +88,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01 - New features for BART: - Added linear response, increased number of trees fitted per step [5044](https://github.com/pymc-devs/pymc3/pull/5044). - Added partial dependence plots and individual conditional expectation plots [5091](https://github.com/pymc-devs/pymc3/pull/5091). +- `pm.Data` now passes additional kwargs to `aesara.shared`. [#5098](https://github.com/pymc-devs/pymc/pull/5098) - ... From 748ea70f84bf913221ea05730e13f7fce868e86f Mon Sep 17 00:00:00 2001 From: "Paulo S. Costa" Date: Thu, 4 Nov 2021 21:16:23 -0700 Subject: [PATCH 4/4] Convert Aesara references to intersphinx --- pymc/data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc/data.py b/pymc/data.py index 1b49aca84e..804595831b 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -464,8 +464,8 @@ def align_minibatches(batches=None): class Data: - """Data container class that wraps the Aesara ``SharedVariable`` class - and lets the model be aware of its inputs and outputs. + """Data container class that wraps :func:`aesara.shared` and lets + the model be aware of its inputs and outputs. Parameters ---------- @@ -478,7 +478,7 @@ class Data: random variables). Use this when `value` is a pandas Series or DataFrame. The `dims` will then be the name of the Series / DataFrame's columns. See ArviZ documentation for more information about dimensions and coordinates: - https://arviz-devs.github.io/arviz/notebooks/Introduction.html + :ref:`arviz:quickstart`. export_index_as_coords: bool, optional, default=False If True, the `Data` container will try to infer what the coordinates should be if there is an index in `value`.