Skip to content

Commit d596afb

Browse files
Remove deprecated Distribution kwargs (#7488)
* Remove deprecated Distribution kwargs Removing these to reduce cognitive load for an eventual migration to a function-based distribution. These were deprecated in #5109, in 2021, as part of pymc 4 being released. We're on 5.x so these should be safe. * Replace deprecated testval arg with initval in test I keep the test, since it seems to cover behaviour not tested elsewhere.
1 parent 253513b commit d596afb

File tree

4 files changed

+2
-100
lines changed

4 files changed

+2
-100
lines changed

pymc/distributions/distribution.py

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -97,19 +97,6 @@ class DistributionMeta(ABCMeta):
9797
"""
9898

9999
def __new__(cls, name, bases, clsdict):
100-
# Forcefully deprecate old v3 `Distribution`s
101-
if "random" in clsdict:
102-
103-
def _random(*args, **kwargs):
104-
warnings.warn(
105-
"The old `Distribution.random` interface is deprecated.",
106-
FutureWarning,
107-
stacklevel=2,
108-
)
109-
return clsdict["random"](*args, **kwargs)
110-
111-
clsdict["random"] = _random
112-
113100
rv_op = clsdict.setdefault("rv_op", None)
114101
rv_type = clsdict.setdefault("rv_type", None)
115102

@@ -206,13 +193,6 @@ def support_point(op, rv, *dist_params):
206193
return new_cls
207194

208195

209-
def _make_nice_attr_error(oldcode: str, newcode: str):
210-
def fn(*args, **kwargs):
211-
raise AttributeError(f"The `{oldcode}` method was removed. Instead use `{newcode}`.`")
212-
213-
return fn
214-
215-
216196
class _class_or_instancemethod(classmethod):
217197
"""Allow a method to be called both as a classmethod and an instancemethod,
218198
giving priority to the instancemethod.
@@ -510,14 +490,6 @@ def __new__(
510490
"for a standalone distribution."
511491
)
512492

513-
if "testval" in kwargs:
514-
initval = kwargs.pop("testval")
515-
warnings.warn(
516-
"The `testval` argument is deprecated; use `initval`.",
517-
FutureWarning,
518-
stacklevel=2,
519-
)
520-
521493
if not isinstance(name, string_types):
522494
raise TypeError(f"Name needs to be a string but got: {name}")
523495

@@ -551,10 +523,6 @@ def __new__(
551523
rv_out._repr_latex_ = types.MethodType(
552524
functools.partial(str_for_dist, formatting="latex"), rv_out
553525
)
554-
555-
rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
556-
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
557-
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
558526
return rv_out
559527

560528
@classmethod
@@ -582,15 +550,6 @@ def dist(
582550
rv : TensorVariable
583551
The created random variable tensor.
584552
"""
585-
if "testval" in kwargs:
586-
kwargs.pop("testval")
587-
warnings.warn(
588-
"The `.dist(testval=...)` argument is deprecated and has no effect. "
589-
"Initial values for sampling/optimization can be specified with `initval` in a modelcontext. "
590-
"For using PyTensor's test value features, you must assign the `.tag.test_value` yourself.",
591-
FutureWarning,
592-
stacklevel=2,
593-
)
594553
if "initval" in kwargs:
595554
raise TypeError(
596555
"Unexpected keyword argument `initval`. "
@@ -617,9 +576,6 @@ def dist(
617576
create_size = find_size(shape=shape, size=size, ndim_supp=ndim_supp)
618577
rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
619578

620-
rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
621-
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
622-
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
623579
_add_future_warning_tag(rv_out)
624580
return rv_out
625581

tests/distributions/test_distribution.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -84,31 +84,6 @@ def test_issue_4499(self):
8484
npt.assert_almost_equal(m.compile_logp()({"x": np.ones(10)}), 0 * 10)
8585

8686

87-
@pytest.mark.parametrize(
88-
"method,newcode",
89-
[
90-
("logp", r"pm.logp\(rv, x\)"),
91-
("logcdf", r"pm.logcdf\(rv, x\)"),
92-
("random", r"pm.draw\(rv\)"),
93-
],
94-
)
95-
def test_logp_gives_migration_instructions(method, newcode):
96-
rv = pm.Normal.dist()
97-
f = getattr(rv, method)
98-
with pytest.raises(AttributeError, match=rf"use `{newcode}`"):
99-
f()
100-
101-
# A dim-induced resize of the rv created by the `.dist()` API,
102-
# happening in Distribution.__new__ would make us loose the monkeypatches.
103-
# So this triggers it to test if the monkeypatch still works.
104-
with pm.Model(coords={"year": [2019, 2021, 2022]}):
105-
rv = pm.Normal("n", dims="year")
106-
f = getattr(rv, method)
107-
with pytest.raises(AttributeError, match=rf"use `{newcode}`"):
108-
f()
109-
pass
110-
111-
11287
def test_all_distributions_have_support_points():
11388
import pymc.distributions as dist_module
11489

tests/model/test_core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -660,8 +660,8 @@ def test_initial_point():
660660

661661
b_initval = np.array(0.3, dtype=pytensor.config.floatX)
662662

663-
with pytest.warns(FutureWarning), model:
664-
b = pm.Uniform("b", testval=b_initval)
663+
with model:
664+
b = pm.Uniform("b", initval=b_initval)
665665

666666
b_initval_trans = model.rvs_to_transforms[b].forward(b_initval, *b.owner.inputs).eval()
667667

tests/test_initial_point.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414
import cloudpickle
1515
import numpy as np
16-
import numpy.testing as npt
1716
import pytensor
1817
import pytensor.tensor as pt
1918
import pytest
@@ -34,34 +33,6 @@ def transform_back(rv, transformed, model) -> np.ndarray:
3433
return model.rvs_to_transforms[rv].backward(transformed, *rv.owner.inputs).eval()
3534

3635

37-
class TestInitvalAssignment:
38-
def test_dist_warnings_and_errors(self):
39-
with pytest.warns(FutureWarning, match="argument is deprecated and has no effect"):
40-
rv = pm.Exponential.dist(lam=1, testval=0.5)
41-
assert not hasattr(rv.tag, "test_value")
42-
43-
with pytest.raises(TypeError, match="Unexpected keyword argument `initval`."):
44-
pm.Normal.dist(1, 2, initval=None)
45-
pass
46-
47-
def test_new_warnings(self):
48-
with pm.Model() as pmodel:
49-
with pytest.warns(FutureWarning, match="`testval` argument is deprecated"):
50-
rv = pm.Uniform("u", 0, 1, testval=0.75)
51-
initial_point = pmodel.initial_point(random_seed=0)
52-
npt.assert_allclose(
53-
initial_point["u_interval__"], transform_fwd(rv, 0.75, model=pmodel)
54-
)
55-
assert not hasattr(rv.tag, "test_value")
56-
pass
57-
58-
def test_valid_string_strategy(self):
59-
with pm.Model() as pmodel:
60-
pm.Uniform("x", 0, 1, size=2, initval="unknown")
61-
with pytest.raises(ValueError, match="Invalid string strategy: unknown"):
62-
pmodel.initial_point(random_seed=0)
63-
64-
6536
class TestInitvalEvaluation:
6637
def test_make_initial_point_fns_per_chain_checks_kwargs(self):
6738
with pm.Model() as pmodel:

0 commit comments

Comments
 (0)