Skip to content

Commit fb73bde

Browse files
Suppport for discrete max/min
1 parent 88b4abc commit fb73bde

File tree

2 files changed

+9
-10
lines changed

2 files changed

+9
-10
lines changed

pymc/logprob/order.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,16 +111,15 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
111111
if base_var.owner.op.dtype.startswith("int"):
112112
if isinstance(base_var.owner.op, RandomVariable):
113113
measurable_max = MeasurableMaxDiscrete(list(axis))
114-
max_rv_node = measurable_max.make_node(base_var)
115-
max_rv = max_rv_node.outputs
116-
117-
return max_rv
114+
else:
115+
return None
118116
else:
119117
measurable_max = MeasurableMax(list(axis))
120-
max_rv_node = measurable_max.make_node(base_var)
121-
max_rv = max_rv_node.outputs
122118

123-
return max_rv
119+
max_rv_node = measurable_max.make_node(base_var)
120+
max_rv = max_rv_node.outputs
121+
122+
return max_rv
124123

125124

126125
measurable_ir_rewrites_db.register(

tests/logprob/test_order.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def test_min_non_mul_elemwise_fails():
236236

237237
@pytest.mark.parametrize(
238238
"mu, size, value, axis",
239-
[(2, 3, 0.85, -1), (2, 3, 0.01, 0), (1, 2, 0.2, None), (0, 4, 0, 0)],
239+
[(2, 3, 0.85, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)],
240240
)
241241
def test_max_discrete(mu, size, value, axis):
242242
x = pm.Poisson.dist(name="x", mu=mu, size=(size))
@@ -247,8 +247,8 @@ def test_max_discrete(mu, size, value, axis):
247247
test_value = value
248248

249249
n = size
250-
exp_rv = np.exp(sp.poisson(mu).logcdf(test_value)) ** n
251-
exp_rv_prev = np.exp(sp.poisson(mu).logcdf(test_value - 1)) ** n
250+
exp_rv = sp.poisson(mu).cdf(test_value) ** n
251+
exp_rv_prev = sp.poisson(mu).cdf(test_value - 1) ** n
252252

253253
np.testing.assert_allclose(
254254
np.log(exp_rv - exp_rv_prev),

0 commit comments

Comments
 (0)