Skip to content

Commit 310a4d9

Browse files
committed
Fix DiscreteUniform dropping degenerate dimension
1 parent e07eea7 commit 310a4d9

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

pymc/distributions/discrete.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from aesara.tensor.random.basic import (
2020
RandomVariable,
21+
ScipyRandomVariable,
2122
bernoulli,
2223
betabinom,
2324
binomial,
@@ -1117,15 +1118,15 @@ def logcdf(value, good, bad, n):
11171118
)
11181119

11191120

1120-
class DiscreteUniformRV(RandomVariable):
1121+
class DiscreteUniformRV(ScipyRandomVariable):
11211122
name = "discrete_uniform"
11221123
ndim_supp = 0
11231124
ndims_params = [0, 0]
11241125
dtype = "int64"
11251126
_print_name = ("DiscreteUniform", "\\operatorname{DiscreteUniform}")
11261127

11271128
@classmethod
1128-
def rng_fn(cls, rng, lower, upper, size=None):
1129+
def rng_fn_scipy(cls, rng, lower, upper, size=None):
11291130
return stats.randint.rvs(lower, upper + 1, size=size, random_state=rng)
11301131

11311132

pymc/tests/distributions/test_discrete.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,6 +1042,10 @@ def discrete_uniform_rng_fn(self, size, lower, upper, rng):
10421042
"check_rv_size",
10431043
]
10441044

1045+
def test_implied_degenerate_shape(self):
1046+
x = pm.DiscreteUniform.dist(0, [1])
1047+
assert x.eval().shape == (1,)
1048+
10451049

10461050
class TestDiracDelta(BaseTestDistributionRandom):
10471051
def diracdelta_rng_fn(self, size, c):

0 commit comments

Comments
 (0)