Skip to content

Commit 7608e30

Browse files
committed
Improve Categorical and Multinomial checks
- Fixes bug that occurred when constructing probability parameters from lists or tuples of Aesara variables - Improves logp checks for valid probabilty parameters
1 parent 14aa3d0 commit 7608e30

File tree

4 files changed

+61
-34
lines changed

4 files changed

+61
-34
lines changed

pymc/distributions/discrete.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import aesara.tensor as at
1717
import numpy as np
1818

19+
from aesara.tensor import TensorConstant
1920
from aesara.tensor.random.basic import (
2021
RandomVariable,
2122
ScipyRandomVariable,
@@ -1285,17 +1286,21 @@ def dist(cls, p=None, logit_p=None, **kwargs):
12851286
if logit_p is not None:
12861287
p = pm.math.softmax(logit_p, axis=-1)
12871288

1288-
if isinstance(p, np.ndarray) or isinstance(p, list):
1289-
if (np.asarray(p) < 0).any():
1290-
raise ValueError(f"Negative `p` parameters are not valid, got: {p}")
1291-
p_sum = np.sum([p], axis=-1)
1292-
if not np.all(np.isclose(p_sum, 1.0)):
1289+
p = at.as_tensor_variable(p)
1290+
if isinstance(p, TensorConstant):
1291+
p_ = np.asarray(p.data)
1292+
if np.any(p_ < 0):
1293+
raise ValueError(f"Negative `p` parameters are not valid, got: {p_}")
1294+
p_sum_ = np.sum([p_], axis=-1)
1295+
if not np.all(np.isclose(p_sum_, 1.0)):
12931296
warnings.warn(
1294-
f"`p` parameters sum to {p_sum}, instead of 1.0. They will be automatically rescaled. You can rescale them directly to get rid of this warning.",
1297+
f"`p` parameters sum to {p_sum_}, instead of 1.0. "
1298+
"They will be automatically rescaled. "
1299+
"You can rescale them directly to get rid of this warning.",
12951300
UserWarning,
12961301
)
1297-
p = p / at.sum(p, axis=-1, keepdims=True)
1298-
p = at.as_tensor_variable(floatX(p))
1302+
p_ = p_ / at.sum(p_, axis=-1, keepdims=True)
1303+
p = at.as_tensor_variable(p_)
12991304
return super().dist([p], **kwargs)
13001305

13011306
def moment(rv, size, p):
@@ -1341,7 +1346,11 @@ def logp(value, p):
13411346
)
13421347

13431348
return check_parameters(
1344-
res, at.all(p_ >= 0, axis=-1), at.all(p <= 1, axis=-1), msg="0 <= p <=1"
1349+
res,
1350+
p_ >= 0,
1351+
p_ <= 1,
1352+
at.isclose(at.sum(p, axis=-1), 1),
1353+
msg="0 <= p <=1, sum(p) = 1",
13451354
)
13461355

13471356

pymc/distributions/multivariate.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from aesara.graph.op import Op
3131
from aesara.raise_op import Assert
3232
from aesara.sparse.basic import sp_sum
33-
from aesara.tensor import gammaln, sigmoid
33+
from aesara.tensor import TensorConstant, gammaln, sigmoid
3434
from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
3535
from aesara.tensor.random.basic import dirichlet, multinomial, multivariate_normal
3636
from aesara.tensor.random.op import RandomVariable, default_supp_shape_from_params
@@ -543,16 +543,21 @@ class Multinomial(Discrete):
543543

544544
@classmethod
545545
def dist(cls, n, p, *args, **kwargs):
546-
if isinstance(p, np.ndarray) or isinstance(p, list):
547-
if (np.asarray(p) < 0).any():
548-
raise ValueError(f"Negative `p` parameters are not valid, got: {p}")
549-
p_sum = np.sum([p], axis=-1)
550-
if not np.all(np.isclose(p_sum, 1.0)):
546+
p = at.as_tensor_variable(p)
547+
if isinstance(p, TensorConstant):
548+
p_ = np.asarray(p.data)
549+
if np.any(p_ < 0):
550+
raise ValueError(f"Negative `p` parameters are not valid, got: {p_}")
551+
p_sum_ = np.sum([p_], axis=-1)
552+
if not np.all(np.isclose(p_sum_, 1.0)):
551553
warnings.warn(
552-
f"`p` parameters sum up to {p_sum}, instead of 1.0. They will be automatically rescaled. You can rescale them directly to get rid of this warning.",
554+
f"`p` parameters sum to {p_sum_}, instead of 1.0. "
555+
"They will be automatically rescaled. "
556+
"You can rescale them directly to get rid of this warning.",
553557
UserWarning,
554558
)
555-
p = p / at.sum(p, axis=-1, keepdims=True)
559+
p_ = p_ / at.sum(p_, axis=-1, keepdims=True)
560+
p = at.as_tensor_variable(p_)
556561
n = at.as_tensor_variable(n)
557562
p = at.as_tensor_variable(p)
558563
return super().dist([n, p], *args, **kwargs)
@@ -591,10 +596,11 @@ def logp(value, n, p):
591596
)
592597
return check_parameters(
593598
res,
599+
p >= 0,
594600
p <= 1,
595601
at.isclose(at.sum(p, axis=-1), 1),
596602
at.ge(n, 0),
597-
msg="p <= 1, sum(p) = 1, n >= 0",
603+
msg="0 <= p <= 1, sum(p) = 1, n >= 0",
598604
)
599605

600606

pymc/tests/distributions/test_discrete.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -497,32 +497,39 @@ def test_categorical_bounds(self):
497497
# entries if there is a single or pair number of negative values
498498
# and the rest are zero
499499
np.array([-1, -1, 0, 0]),
500+
at.as_tensor_variable([-1, -1, 0, 0]),
500501
],
501502
)
502503
def test_categorical_negative_p(self, p):
503-
with pytest.raises(ValueError, match=f"{p}"):
504+
with pytest.raises(ValueError, match="Negative `p` parameters are not valid"):
504505
with pm.Model():
505506
x = pm.Categorical("x", p=p)
506507

507508
def test_categorical_p_not_normalized(self):
508509
# test UserWarning is raised for p vals that sum to more than 1
509510
# and normaliation is triggered
510-
with pytest.warns(UserWarning, match="[5]"):
511+
with pytest.warns(UserWarning, match="They will be automatically rescaled"):
511512
with pm.Model() as m:
512513
x = pm.Categorical("x", p=[1, 1, 1, 1, 1])
513514
assert np.isclose(m.x.owner.inputs[3].sum().eval(), 1.0)
514515

515516
def test_categorical_negative_p_symbolic(self):
517+
value = np.array([[1, 1, 1]])
518+
519+
x = at.scalar("x")
520+
invalid_dist = pm.Categorical.dist(p=[x, x, x])
521+
516522
with pytest.raises(ParameterValueError):
517-
value = np.array([[1, 1, 1]])
518-
invalid_dist = pm.Categorical.dist(p=at.as_tensor_variable([-1, 0.5, 0.5]))
519-
pm.logp(invalid_dist, value).eval()
523+
pm.logp(invalid_dist, value).eval({x: -1 / 3})
520524

521525
def test_categorical_p_not_normalized_symbolic(self):
526+
value = np.array([[1, 1, 1]])
527+
528+
x = at.scalar("x")
529+
invalid_dist = pm.Categorical.dist(p=(x, x, x))
530+
522531
with pytest.raises(ParameterValueError):
523-
value = np.array([[1, 1, 1]])
524-
invalid_dist = pm.Categorical.dist(p=at.as_tensor_variable([2, 2, 2]))
525-
pm.logp(invalid_dist, value).eval()
532+
pm.logp(invalid_dist, value).eval({x: 0.5})
526533

527534
@pytest.mark.parametrize("n", [2, 3, 4])
528535
def test_orderedlogistic(self, n):

pymc/tests/distributions/test_multivariate.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -548,14 +548,14 @@ def test_multinomial_invalid_value(self):
548548

549549
def test_multinomial_negative_p(self):
550550
# test passing a list/numpy with negative p raises an immediate error
551-
with pytest.raises(ValueError, match="[-1, 1, 1]"):
551+
with pytest.raises(ValueError, match="Negative `p` parameters are not valid"):
552552
with pm.Model() as model:
553553
x = pm.Multinomial("x", n=5, p=[-1, 1, 1])
554554

555555
def test_multinomial_p_not_normalized(self):
556556
# test UserWarning is raised for p vals that sum to more than 1
557557
# and normaliation is triggered
558-
with pytest.warns(UserWarning, match="[5]"):
558+
with pytest.warns(UserWarning, match="They will be automatically rescaled"):
559559
with pm.Model() as m:
560560
x = pm.Multinomial("x", n=5, p=[1, 1, 1, 1, 1])
561561
# test stored p-vals have been normalised
@@ -564,18 +564,23 @@ def test_multinomial_p_not_normalized(self):
564564
def test_multinomial_negative_p_symbolic(self):
565565
# Passing symbolic negative p does not raise an immediate error, but evaluating
566566
# logp raises a ParameterValueError
567+
value = np.array([[1, 1, 1]])
568+
569+
x = at.scalar("x")
570+
invalid_dist = pm.Multinomial.dist(n=1, p=[x, x, x])
571+
567572
with pytest.raises(ParameterValueError):
568-
value = np.array([[1, 1, 1]])
569-
invalid_dist = pm.Multinomial.dist(n=1, p=at.as_tensor_variable([-1, 0.5, 0.5]))
570-
pm.logp(invalid_dist, value).eval()
573+
pm.logp(invalid_dist, value).eval({x: -1 / 3})
571574

572575
def test_multinomial_p_not_normalized_symbolic(self):
573576
# Passing symbolic p that do not add up to on does not raise any warning, but evaluating
574577
# logp raises a ParameterValueError
578+
value = np.array([[1, 1, 1]])
579+
580+
x = at.scalar("x")
581+
invalid_dist = pm.Multinomial.dist(n=1, p=(x, x, x))
575582
with pytest.raises(ParameterValueError):
576-
value = np.array([[1, 1, 1]])
577-
invalid_dist = pm.Multinomial.dist(n=1, p=at.as_tensor_variable([1, 0.5, 0.5]))
578-
pm.logp(invalid_dist, value).eval()
583+
pm.logp(invalid_dist, value).eval({x: 0.5})
579584

580585
@pytest.mark.parametrize("n", [(10), ([10, 11]), ([[5, 6], [10, 11]])])
581586
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)