Skip to content

Commit 6a75744

Browse files
Stop assigning initial values to test_value
1 parent 890486d commit 6a75744

File tree

4 files changed

+24
-12
lines changed

4 files changed

+24
-12
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
- ⚠ PyMC3 now requires Scipy version `>= 1.4.1` (see [4857](https://github.com/pymc-devs/pymc3/pull/4857)).
77
- ArviZ `plots` and `stats` *wrappers* were removed. The functions are now just available by their original names (see [#4549](https://github.com/pymc-devs/pymc3/pull/4471) and `3.11.2` release notes).
88
- The GLM submodule has been removed, please use [Bambi](https://bambinos.github.io/bambi/) instead.
9-
- The `Distribution` keyword argument `testval` has been deprecated in favor of `initval`.
9+
- The `Distribution` keyword argument `testval` has been deprecated in favor of `initval`. Furthermore `initval` no longer assigns a `tag.test_value` on tensors since the initial values are now kept track of by the model object ([see #4913](https://github.com/pymc-devs/pymc3/pull/4913)).
1010
- `pm.sample` now returns results as `InferenceData` instead of `MultiTrace` by default (see [#4744](https://github.com/pymc-devs/pymc3/pull/4744)).
1111
- `pm.sample_prior_predictive` no longer returns transformed variable values by default. Pass them by name in `var_names` if you want to obtain these draws (see [4769](https://github.com/pymc-devs/pymc3/pull/4769)).
1212
- ...

pymc3/distributions/distribution.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -219,12 +219,14 @@ def __new__(
219219
# A batch size was specified through `dims`, or implied by `observed`.
220220
rv_out = change_rv_size(rv_var=rv_out, new_size=resize_shape, expand=True)
221221

222-
if initval is not None:
223-
# Assigning the testval earlier causes trouble because the RV may not be created with the final shape already.
224-
rv_out.tag.test_value = initval
225-
226222
rv_out = model.register_rv(
227-
rv_out, name, observed, total_size, dims=dims, transform=transform
223+
rv_out,
224+
name,
225+
observed,
226+
total_size,
227+
dims=dims,
228+
transform=transform,
229+
initval=initval,
228230
)
229231

230232
# add in pretty-printing support

pymc3/model.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -944,11 +944,10 @@ def cont_vars(self):
944944
return list(typefilter(self.value_vars, continuous_types))
945945

946946
def set_initval(self, rv_var, initval):
947-
initval = (
948-
rv_var.type.filter(initval)
949-
if initval is not None
950-
else getattr(rv_var.tag, "test_value", None)
951-
)
947+
if initval is not None:
948+
initval = rv_var.type.filter(initval)
949+
950+
test_value = getattr(rv_var.tag, "test_value", None)
952951

953952
rv_value_var = self.rvs_to_values[rv_var]
954953
transform = getattr(rv_value_var.tag, "transform", None)
@@ -982,7 +981,17 @@ def initval_to_rvval(value_var, value):
982981
initval_fn = aesara.function(
983982
[], rv_var, mode=mode, givens=givens, on_unused_input="ignore"
984983
)
985-
initval = initval_fn()
984+
try:
985+
initval = initval_fn()
986+
except NotImplementedError as ex:
987+
if "Cannot sample from" in ex.args[0]:
988+
# The RV does not have a random number generator.
989+
# Our last chance is to take the test_value.
990+
# Note that this is a workaround for Flat and HalfFlat
991+
# until an initval default mechanism is implemented (#4752).
992+
initval = test_value
993+
else:
994+
raise
986995

987996
self.initial_values[rv_value_var] = initval
988997

pymc3/tests/test_initvals.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def test_new_warnings(self):
3535
with pytest.warns(DeprecationWarning, match="`testval` argument is deprecated"):
3636
rv = pm.Uniform("u", 0, 1, testval=0.75)
3737
assert pmodel.initial_values[rv.tag.value_var] == transform_fwd(rv, 0.75)
38+
assert not hasattr(rv.tag, "test_value")
3839
pass
3940

4041

0 commit comments

Comments
 (0)