Skip to content

Commit 07679ec

Browse files
kc611ricardoV94
andauthored
Avoid unclear TypeError when using theano.shared variables as input to distribution parameters (#4445)
* Added default testvalue support for theano.shared Co-authored-by: Ricardo <[email protected]>
1 parent 2d3ec8f commit 07679ec

File tree

3 files changed

+30
-7
lines changed

3 files changed

+30
-7
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
- We upgraded to `Theano-PyMC v1.1.2` which [includes bugfixes](https://github.com/pymc-devs/aesara/compare/rel-1.1.0...rel-1.1.2) for warning floods and compiledir locking (see [#4444](https://github.com/pymc-devs/pymc3/pull/4444))
1212
- `Theano-PyMC v1.1.2` also fixed an important issue in `tt.switch` that affected the behavior of several PyMC distributions, including at least the `Bernoulli` and `TruncatedNormal` (see[#4448](https://github.com/pymc-devs/pymc3/pull/4448))
1313
- `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)).
14+
- `ScalarSharedVariable` can now be used as an input to other RVs directly (see [#4445](https://github.com/pymc-devs/pymc3/pull/4445)).
1415

1516
## PyMC3 3.11.0 (21 January 2021)
1617

pymc3/distributions/distribution.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -148,17 +148,17 @@ def default(self):
148148
def get_test_val(self, val, defaults):
149149
if val is None:
150150
for v in defaults:
151-
if hasattr(self, v) and np.all(np.isfinite(self.getattr_value(v))):
152-
return self.getattr_value(v)
153-
else:
154-
return self.getattr_value(val)
155-
156-
if val is None:
151+
if hasattr(self, v):
152+
attr_val = self.getattr_value(v)
153+
if np.all(np.isfinite(attr_val)):
154+
return attr_val
157155
raise AttributeError(
158156
"%s has no finite default value to use, "
159157
"checked: %s. Pass testval argument or "
160158
"adjust so value is finite." % (self, str(defaults))
161159
)
160+
else:
161+
return self.getattr_value(val)
162162

163163
def getattr_value(self, val):
164164
if isinstance(val, string_types):
@@ -167,7 +167,7 @@ def getattr_value(self, val):
167167
if isinstance(val, tt.TensorVariable):
168168
return val.tag.test_value
169169

170-
if isinstance(val, tt.sharedvar.TensorSharedVariable):
170+
if isinstance(val, tt.sharedvar.SharedVariable):
171171
return val.get_value()
172172

173173
if isinstance(val, theano_constant):

pymc3/tests/test_data_container.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import pandas as pd
1717
import pytest
1818

19+
from theano import shared
20+
1921
import pymc3 as pm
2022

2123
from pymc3.tests.helpers import SeededTest
@@ -156,6 +158,26 @@ def test_shared_data_as_rv_input(self):
156158
np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), x.get_value(), atol=1e-1)
157159
np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), trace["y"].mean(0), atol=1e-1)
158160

161+
def test_shared_scalar_as_rv_input(self):
162+
# See https://github.com/pymc-devs/pymc3/issues/3139
163+
with pm.Model() as m:
164+
shared_var = shared(5.0)
165+
v = pm.Normal("v", mu=shared_var, shape=1)
166+
167+
np.testing.assert_allclose(
168+
v.logp({"v": [5.0]}),
169+
-0.91893853,
170+
rtol=1e-5,
171+
)
172+
173+
shared_var.set_value(10.0)
174+
175+
np.testing.assert_allclose(
176+
v.logp({"v": [10.0]}),
177+
-0.91893853,
178+
rtol=1e-5,
179+
)
180+
159181
def test_creation_of_data_outside_model_context(self):
160182
with pytest.raises((IndexError, TypeError)) as error:
161183
pm.Data("data", [1.1, 2.2, 3.3])

0 commit comments

Comments
 (0)