Skip to content

Commit 0402aab

Browse files
authored
Add bound to HyperGeometric logp (resolves #4366) (#4367)
* - Add bound to HyperGeometric logp - Pass unit tests when scipy logpmf returns nan * - Add release-note * - Replace tt.max and tt.min with tt.switch
1 parent 0ec65e5 commit 0402aab

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ It also brings some dreadfully awaited fixes, so be sure to go through the chang
2525
- The notebook gallery has been moved to https://github.com/pymc-devs/pymc-examples (see [#4348](https://github.com/pymc-devs/pymc3/pull/4348)).
2626
- `math.logsumexp` now matches `scipy.special.logsumexp` when arrays contain infinite values (see [#4360](https://github.com/pymc-devs/pymc3/pull/4360)).
2727
- Fixed mathematical formulation in `MvStudentT` random method. (see [#4359](https://github.com/pymc-devs/pymc3/pull/4359))
28+
- Fix issue in `logp` method of `HyperGeometric`. It now returns `-inf` for invalid parameters (see [4367](https://github.com/pymc-devs/pymc3/pull/4367))
2829

2930
## PyMC3 3.10.0 (7 December 2020)
3031

pymc3/distributions/discrete.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -930,7 +930,10 @@ def logp(self, value):
930930
- betaln(n - value + 1, bad - n + value + 1)
931931
- betaln(tot + 1, 1)
932932
)
933-
return result
933+
# value in [max(0, n - N + k), min(k, n)]
934+
lower = tt.switch(tt.gt(n - N + k, 0), n - N + k, 0)
935+
upper = tt.switch(tt.lt(k, n), k, n)
936+
return bound(result, lower <= value, value <= upper)
934937

935938

936939
class DiscreteUniform(Discrete):

pymc3/tests/test_distributions.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -805,11 +805,15 @@ def test_geometric(self):
805805
)
806806

807807
def test_hypergeometric(self):
808+
def modified_scipy_hypergeom_logpmf(value, N, k, n):
809+
original_res = sp.hypergeom.logpmf(value, N, k, n)
810+
return original_res if not np.isnan(original_res) else -np.inf
811+
808812
self.pymc3_matches_scipy(
809813
HyperGeometric,
810814
Nat,
811815
{"N": NatSmall, "k": NatSmall, "n": NatSmall},
812-
lambda value, N, k, n: sp.hypergeom.logpmf(value, N, k, n),
816+
lambda value, N, k, n: modified_scipy_hypergeom_logpmf(value, N, k, n),
813817
)
814818

815819
def test_negative_binomial(self):

0 commit comments

Comments
 (0)