
Commit c613106

lucianopaz authored and twiecki committed
Changed Categorical to work with multidim p at the logp level reloaded (#3386)
* Changed Categorical to work with multidim p at the logp level.
* Fixed problems with OrderedLogistic.
* Use np.moveaxis instead of transposing. Also added some more tests.
* Changed np.moveaxis to dimshuffle. Removed isinstance conditions that were impossible.
* Fixed lint error
* Force a normalized p in logp
* Fixed lint error
* Edited release notes. Categorical.p can have zeros, and the normalization of p is always carried out.
* Fixed orderedlogistic test logpdf to check for ps >= 0 and not just ps > 0.
1 parent 6f7fafa commit c613106
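
As a quick orientation before the diff, here is a minimal sketch of the usage this commit enables, assuming pymc3 at this commit; the model, variable name, shapes, and probability values below are illustrative only, not taken from the diff:

```python
import numpy as np
import pymc3 as pm

# A 2-D p: one row of category probabilities per element of the variable.
# The last axis encodes the categories; a zero entry is now accepted.
p = np.array([[0.1, 0.9],
              [0.5, 0.5],
              [1.0, 0.0]])

with pm.Model():
    # After this commit, logp handles the extra leading dimension of p.
    x = pm.Categorical('x', p=p, shape=3)
```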

3 files changed: +7 -20 lines changed


RELEASE-NOTES.md

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 - `dist_math.random_choice` now handles nd-arrays of category probabilities, and also handles sizes that are not `None`. Also removed unused `k` kwarg from `dist_math.random_choice`.
 - Changed `Categorical.mode` to preserve all the dimensions of `p` except the last one, which encodes each category's probability.
 - Changed initialization of `Categorical.p`. `p` is now normalized to sum to `1` inside `logp` and `random`, but not during initialization. This could hide negative values supplied to `p` as mentioned in #2082.
-- To be able to test for negative `p` values supplied to `Categorical`, `Categorical.logp` was changed to check for `sum(self.p, axis=-1) == 1` only if `self.p` is not a `Number`, `np.ndarray`, `TensorConstant` or `SharedVariable`. These cases are automatically normalized to sum to `1`. The other condition may originate from a `step_method` proposal, where `self.p` tensor's value may be set, but must sum to 1 nevertheless. This may break old code which intialized `p` with a theano expression and relied on the default normalization to get it to sum to 1. `Categorical.logp` now also checks that the used `p` has values lower than 1.
+- `Categorical` now accepts elements of `p` equal to `0`. `logp` will return `-inf` if there are `values` that index to the zero probability categories.
 
 ### Deprecations
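
A small sketch of the behaviour described by the new release note, assuming pymc3 at this commit with a working theano backend; the probabilities and test values are illustrative:

```python
import numpy as np
import pymc3 as pm

p = np.array([0.5, 0.5, 0.0])   # a zero-probability category is accepted
cat = pm.Categorical.dist(p=p)

print(cat.logp(1).eval())       # log(0.5): an ordinary category
print(cat.logp(2).eval())       # -inf: the value indexes the zero-probability category
```

Negative entries in `p` are still rejected by the `tt.all(p_ >= 0, axis=-1)` term in `bound`, which also drives the logp to `-inf`.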

pymc3/distributions/discrete.py

Lines changed: 5 additions & 18 deletions
@@ -1,6 +1,4 @@
-import numbers
 import numpy as np
-import theano
 import theano.tensor as tt
 from scipy import stats
 import warnings
@@ -736,27 +734,16 @@ def logp(self, value):
         # Clip values before using them for indexing
         value_clip = tt.clip(value, 0, k - 1)
 
-        # We must only check that the values sum to 1 if p comes from a
-        # tensor variable, i.e. when p is a step_method proposal. In the other
-        # cases we normalize ourselves
-        if not isinstance(p_, (numbers.Number,
-                               np.ndarray,
-                               tt.TensorConstant,
-                               tt.sharedvar.SharedVariable)):
-            sumto1 = theano.gradient.zero_grad(
-                tt.le(abs(tt.sum(p_, axis=-1) - 1), 1e-5))
-            p = p_
-        else:
-            p = p_ / tt.sum(p_, axis=-1, keepdims=True)
-            sumto1 = True
+        p = p_ / tt.sum(p_, axis=-1, keepdims=True)
 
         if p.ndim > 1:
-            a = tt.log(np.moveaxis(p, -1, 0)[value_clip])
+            pattern = (p.ndim - 1,) + tuple(range(p.ndim - 1))
+            a = tt.log(p.dimshuffle(pattern)[value_clip])
         else:
             a = tt.log(p[value_clip])
 
-        return bound(a, value >= 0, value <= (k - 1), sumto1,
-                     tt.all(p_ > 0, axis=-1), tt.all(p <= 1, axis=-1))
+        return bound(a, value >= 0, value <= (k - 1),
+                     tt.all(p_ >= 0, axis=-1), tt.all(p <= 1, axis=-1))
 
     def _repr_latex_(self, name=None, dist=None):
         if dist is None:
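
The `pattern` built in the new `logp` simply moves the category axis of `p` to the front so that `value_clip` can index it; `dimshuffle` is a theano-native way to express the same reordering as the replaced `np.moveaxis(p, -1, 0)`. A numpy stand-in, illustrative only, with `np.transpose` playing the role of `dimshuffle`:

```python
import numpy as np

p = np.random.dirichlet(np.ones(4), size=(2, 3))     # shape (2, 3, 4), categories on the last axis

pattern = (p.ndim - 1,) + tuple(range(p.ndim - 1))    # (2, 0, 1): category axis first
reordered = np.transpose(p, pattern)                  # what p.dimshuffle(pattern) does in theano

assert reordered.shape == (4, 2, 3)
assert np.allclose(reordered, np.moveaxis(p, -1, 0))  # same reordering as before the commit
```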

pymc3/tests/test_distributions.py

Lines changed: 1 addition & 1 deletion
@@ -349,7 +349,7 @@ def orderedlogistic_logpdf(value, eta, cutpoints):
     ps = np.array([invlogit(eta - cc) - invlogit(eta - cc1)
                    for cc, cc1 in zip(c[:-1], c[1:])])
     p = ps[value]
-    return np.where(np.all(ps > 0), np.log(p), -np.inf)
+    return np.where(np.all(ps >= 0), np.log(p), -np.inf)
 
 class Simplex:
     def __init__(self, n):
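
For context, a self-contained sketch of the reference density the test relies on; the `invlogit` helper and the padding of the cutpoints with ±inf are assumptions filled in for illustration rather than taken from the diff:

```python
import numpy as np

def invlogit(x):
    # logistic sigmoid
    return 1.0 / (1.0 + np.exp(-x))

def orderedlogistic_logpdf(value, eta, cutpoints):
    # Pad the cutpoints so every category gets an interval on the logit scale.
    c = np.concatenate(([-np.inf], cutpoints, [np.inf]))
    ps = np.array([invlogit(eta - cc) - invlogit(eta - cc1)
                   for cc, cc1 in zip(c[:-1], c[1:])])
    p = ps[value]
    # ps >= 0 (the change above): zero-probability categories are allowed and
    # simply give log(0) = -inf; only negative probabilities invalidate the logp.
    return np.where(np.all(ps >= 0), np.log(p), -np.inf)

print(orderedlogistic_logpdf(1, eta=0.0, cutpoints=np.array([-1.0, 1.0])))  # ~ log(0.462)
```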
