Skip to content

Commit 890486d

Browse files
Remove initval from dist API and add dedicated test suite
Closes #4893
1 parent 16502aa commit 890486d

File tree

4 files changed

+70
-21
lines changed

4 files changed

+70
-21
lines changed

.github/workflows/pytest.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ jobs:
3030
# → pytest will run only these files
3131
- |
3232
--ignore=pymc3/tests/test_distributions_timeseries.py
33+
--ignore=pymc3/tests/test_initvals.py
3334
--ignore=pymc3/tests/test_mixture.py
3435
--ignore=pymc3/tests/test_model_graph.py
3536
--ignore=pymc3/tests/test_modelcontext.py
@@ -60,7 +61,9 @@ jobs:
6061
--ignore=pymc3/tests/test_distributions_random.py
6162
--ignore=pymc3/tests/test_idata_conversion.py
6263
63-
- pymc3/tests/test_distributions.py
64+
- |
65+
pymc3/tests/test_initvals.py
66+
pymc3/tests/test_distributions.py
6467
6568
- |
6669
pymc3/tests/test_modelcontext.py
@@ -153,6 +156,7 @@ jobs:
153156
floatx: [float32, float64]
154157
test-subset:
155158
- |
159+
pymc3/tests/test_initvals.py
156160
pymc3/tests/test_distributions_random.py
157161
pymc3/tests/test_distributions_timeseries.py
158162
- |

pymc3/distributions/continuous.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -361,11 +361,9 @@ class Flat(Continuous):
361361
rv_op = flat
362362

363363
@classmethod
364-
def dist(cls, *, size=None, initval=None, **kwargs):
365-
if initval is None:
366-
initval = np.full(size, floatX(0.0))
364+
def dist(cls, *, size=None, **kwargs):
367365
res = super().dist([], size=size, **kwargs)
368-
res.tag.test_value = initval
366+
res.tag.test_value = np.full(size, floatX(0.0))
369367
return res
370368

371369
def logp(value):
@@ -425,11 +423,9 @@ class HalfFlat(PositiveContinuous):
425423
rv_op = halfflat
426424

427425
@classmethod
428-
def dist(cls, *, size=None, initval=None, **kwargs):
429-
if initval is None:
430-
initval = np.full(size, floatX(1.0))
426+
def dist(cls, *, size=None, **kwargs):
431427
res = super().dist([], size=size, **kwargs)
432-
res.tag.test_value = initval
428+
res.tag.test_value = np.full(size, floatX(1.0))
433429
return res
434430

435431
def logp(value):

pymc3/distributions/distribution.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,9 @@ def __new__(
202202
)
203203
dims = convert_dims(dims)
204204

205-
# Create the RV without specifying initval, because the initval may have a shape
206-
# that only matches after replicating with a size implied by dims (see below).
207-
rv_out = cls.dist(*args, rng=rng, initval=None, **kwargs)
205+
# Create the RV without dims information, because that's not something tracked at the Aesara level.
206+
# If necessary we'll later replicate to a different size implied by already known dims.
207+
rv_out = cls.dist(*args, rng=rng, **kwargs)
208208
ndim_actual = rv_out.ndim
209209
resize_shape = None
210210

@@ -242,7 +242,6 @@ def dist(
242242
*,
243243
shape: Optional[Shape] = None,
244244
size: Optional[Size] = None,
245-
initval=None,
246245
**kwargs,
247246
) -> RandomVariable:
248247
"""Creates a RandomVariable corresponding to the `cls` distribution.
@@ -258,25 +257,27 @@ def dist(
258257
all the dimensions that the RV would get if no shape/size/dims were passed at all.
259258
size : int, tuple, Variable, optional
260259
For creating the RV like in Aesara/NumPy.
261-
initival : optional
262-
Test value to be attached to the output RV.
263-
Must match its shape exactly.
264260
265261
Returns
266262
-------
267263
rv : RandomVariable
268264
The created RV.
269265
"""
270266
if "testval" in kwargs:
271-
initval = kwargs.pop("testval")
267+
kwargs.pop("testval")
272268
warnings.warn(
273-
"The `testval` argument is deprecated. "
274-
"Use `initval` to set initial values for a `Model`; "
275-
"otherwise, set test values on Aesara parameters explicitly "
276-
"when attempting to use Aesara's test value debugging features.",
269+
"The `.dist(testval=...)` argument is deprecated and has no effect. "
270+
"Initial values for sampling/optimization can be specified with `initval` in a modelcontext. "
271+
"For using Aesara's test value features, you must assign the `.tag.test_value` yourself.",
277272
DeprecationWarning,
278273
stacklevel=2,
279274
)
275+
if "initval" in kwargs:
276+
raise TypeError(
277+
"Unexpected keyword argument `initval`. "
278+
"This argument is not available for the `.dist()` API."
279+
)
280+
280281
if "dims" in kwargs:
281282
raise NotImplementedError("The use of a `.dist(dims=...)` API is not supported.")
282283
if shape is not None and size is not None:

pymc3/tests/test_initvals.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2020 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import pytest
15+
16+
import pymc3 as pm
17+
18+
19+
def transform_fwd(rv, expected_untransformed):
20+
return rv.tag.value_var.tag.transform.forward(rv, expected_untransformed).eval()
21+
22+
23+
class TestInitvalAssignment:
24+
def test_dist_warnings_and_errors(self):
25+
with pytest.warns(DeprecationWarning, match="argument is deprecated and has no effect"):
26+
rv = pm.Exponential.dist(lam=1, testval=0.5)
27+
assert not hasattr(rv.tag, "test_value")
28+
29+
with pytest.raises(TypeError, match="Unexpected keyword argument `initval`."):
30+
pm.Normal.dist(1, 2, initval=None)
31+
pass
32+
33+
def test_new_warnings(self):
34+
with pm.Model() as pmodel:
35+
with pytest.warns(DeprecationWarning, match="`testval` argument is deprecated"):
36+
rv = pm.Uniform("u", 0, 1, testval=0.75)
37+
assert pmodel.initial_values[rv.tag.value_var] == transform_fwd(rv, 0.75)
38+
pass
39+
40+
41+
class TestSpecialDistributions:
42+
def test_automatically_assigned_test_values(self):
43+
# ...because they don't have random number generators.
44+
rv = pm.Flat.dist()
45+
assert hasattr(rv.tag, "test_value")
46+
rv = pm.HalfFlat.dist()
47+
assert hasattr(rv.tag, "test_value")
48+
pass

0 commit comments

Comments
 (0)