Skip to content

Commit c9660b6

Browse files
committed
Fix _make_nice_attr_error_ for random and add it to SymbolicDistributions
1 parent fdfedd7 commit c9660b6

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

pymc/distributions/distribution.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def __new__(
292292

293293
rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
294294
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
295-
rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
295+
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
296296
return rv_out
297297

298298
@classmethod
@@ -351,7 +351,7 @@ def dist(
351351

352352
rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
353353
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
354-
rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
354+
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
355355
return rv_out
356356

357357

@@ -488,6 +488,10 @@ def __new__(
488488
functools.partial(str_for_symbolic_dist, formatting="latex"), rv_out
489489
)
490490

491+
rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
492+
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
493+
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
494+
491495
return rv_out
492496

493497
@classmethod
@@ -546,10 +550,9 @@ def dist(
546550
# This is needed for resizing from dims in `__new__`
547551
rv_out.tag.ndim_supp = ndim_supp
548552

549-
# TODO: Create new attr error stating that these are not available for DerivedDistribution
550-
# rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
551-
# rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
552-
# rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
553+
rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
554+
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
555+
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
553556
return rv_out
554557

555558

pymc/tests/test_distributions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3145,7 +3145,7 @@ def test_distinct_rvs():
31453145
[
31463146
("logp", r"pm.logp\(rv, x\)"),
31473147
("logcdf", r"pm.logcdf\(rv, x\)"),
3148-
("random", r"rv.eval\(\)"),
3148+
("random", r"pm.draw\(rv\)"),
31493149
],
31503150
)
31513151
def test_logp_gives_migration_instructions(method, newcode):

0 commit comments

Comments
 (0)