Skip to content

Commit 0d5203c

Browse files
lucianopaz authored and aloctavodia committed
Fix Categorical.logp with take_along_axis (#3572)
* Added failing tests
* Fixed Categorical.logp using take_along_axis. Added tests for the theanof
* Added release notes
* Account for ndim mismatch in Categorical logp
1 parent 8b3ecf9 commit 0d5203c

File tree

5 files changed

+330
-11
lines changed

5 files changed

+330
-11
lines changed

RELEASE-NOTES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
### Maintenance
1414
- Moved math operations out of `Rice`, `TruncatedNormal`, `Triangular` and `ZeroInflatedNegativeBinomial` `random` methods. Math operations on values returned by `draw_values` might not broadcast well, and all the `size` aware broadcasting is left to `generate_samples`. Fixes [#3481](https://github.com/pymc-devs/pymc3/issues/3481) and [#3508](https://github.com/pymc-devs/pymc3/issues/3508)
1515
- Parallelization of population steppers (`DEMetropolis`) is now set via the `cores` argument. ([#3559](https://github.com/pymc-devs/pymc3/pull/3559))
16+
- Fixed a bug in `Categorical.logp`. In the case of multidimensional `p`'s, the indexing was done incorrectly, leading to incorrectly shaped tensors that consumed `O(n**2)` memory instead of `O(n)`. This fixes issue [#3535](https://github.com/pymc-devs/pymc3/issues/3535)
17+
- Fixed a defect in `OrderedLogistic.__init__` that unnecessarily increased the dimensionality of the underlying `p`. Related to issue [#3535](https://github.com/pymc-devs/pymc3/issues/3535) but was not the true cause of it.
1618
- SMC: stabilize covariance matrix [3573](https://github.com/pymc-devs/pymc3/pull/3573)
1719
- SMC is no longer a step method of `pm.sample`; it should now be called using `pm.sample_smc` [3579](https://github.com/pymc-devs/pymc3/pull/3579)
1820
- Now uses `multiprocessing` rather than `psutil` to count CPUs, which results in reliable core counts on Chromebooks.

pymc3/distributions/discrete.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .distribution import Discrete, draw_values, generate_samples
99
from .shape_utils import broadcast_distribution_samples
1010
from pymc3.math import tround, sigmoid, logaddexp, logit, log1pexp
11-
from ..theanof import floatX, intX
11+
from ..theanof import floatX, intX, take_along_axis
1212

1313

1414
__all__ = ['Binomial', 'BetaBinomial', 'Bernoulli', 'DiscreteWeibull',
@@ -997,8 +997,21 @@ def logp(self, value):
997997
p = p_ / tt.sum(p_, axis=-1, keepdims=True)
998998

999999
if p.ndim > 1:
1000+
if p.ndim > value_clip.ndim:
1001+
value_clip = tt.shape_padleft(
1002+
value_clip, p_.ndim - value_clip.ndim
1003+
)
1004+
elif p.ndim < value_clip.ndim:
1005+
p = tt.shape_padleft(
1006+
p, value_clip.ndim - p_.ndim
1007+
)
10001008
pattern = (p.ndim - 1,) + tuple(range(p.ndim - 1))
1001-
a = tt.log(p.dimshuffle(pattern)[value_clip])
1009+
a = tt.log(
1010+
take_along_axis(
1011+
p.dimshuffle(pattern),
1012+
value_clip,
1013+
)
1014+
)
10021015
else:
10031016
a = tt.log(p[value_clip])
10041017

@@ -1571,13 +1584,13 @@ def __init__(self, eta, cutpoints, *args, **kwargs):
15711584
self.eta = tt.as_tensor_variable(floatX(eta))
15721585
self.cutpoints = tt.as_tensor_variable(cutpoints)
15731586

1574-
pa = sigmoid(tt.shape_padleft(self.cutpoints) - tt.shape_padright(self.eta))
1587+
pa = sigmoid(self.cutpoints - tt.shape_padright(self.eta))
15751588
p_cum = tt.concatenate([
1576-
tt.zeros_like(tt.shape_padright(pa[:, 0])),
1589+
tt.zeros_like(tt.shape_padright(pa[..., 0])),
15771590
pa,
1578-
tt.ones_like(tt.shape_padright(pa[:, 0]))
1591+
tt.ones_like(tt.shape_padright(pa[..., 0]))
15791592
], axis=-1)
1580-
p = p_cum[:, 1:] - p_cum[:, :-1]
1593+
p = p_cum[..., 1:] - p_cum[..., :-1]
15811594

15821595
super().__init__(p=p, *args, **kwargs)
15831596

pymc3/tests/test_distributions.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,3 +1335,35 @@ def test_discrete_trafo():
13351335
with pytest.raises(ValueError) as err:
13361336
Binomial('a', n=5, p=0.5, transform='log')
13371337
err.match('Transformations for discrete distributions')
1338+
1339+
1340+
@pytest.mark.parametrize("shape", [tuple(), (1,), (3, 1), (3, 2)], ids=str)
1341+
def test_orderedlogistic_dimensions(shape):
1342+
# Test for issue #3535
1343+
loge = np.log10(np.exp(1))
1344+
size = 7
1345+
p = np.ones(shape + (10,)) / 10
1346+
cutpoints = np.tile(logit(np.linspace(0, 1, 11)[1:-1]), shape + (1,))
1347+
obs = np.random.randint(0, 1, size=(size,) + shape)
1348+
with Model():
1349+
ol = OrderedLogistic(
1350+
"ol",
1351+
eta=np.zeros(shape),
1352+
cutpoints=cutpoints,
1353+
shape=shape,
1354+
observed=obs
1355+
)
1356+
c = Categorical(
1357+
"c",
1358+
p=p,
1359+
shape=shape,
1360+
observed=obs
1361+
)
1362+
ologp = ol.logp({"ol": 1}) * loge
1363+
clogp = c.logp({"c": 1}) * loge
1364+
expected = -np.prod((size,) + shape)
1365+
1366+
assert c.distribution.p.ndim == (len(shape) + 1)
1367+
assert np.allclose(clogp, expected)
1368+
assert ol.distribution.p.ndim == (len(shape) + 1)
1369+
assert np.allclose(ologp, expected)

pymc3/tests/test_theanof.py

Lines changed: 217 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,61 @@
11
import collections
2-
32
import pytest
4-
from theano import theano
3+
from itertools import product
4+
5+
from theano import theano, tensor as tt
6+
import numpy as np
7+
8+
from pymc3.theanof import set_theano_conf, take_along_axis, _conversion_map
9+
from pymc3.vartypes import int_types
10+
11+
12+
FLOATX = str(theano.config.floatX)
13+
INTX = str(_conversion_map[FLOATX])
14+
15+
16+
def _make_along_axis_idx(arr_shape, indices, axis):
17+
# compute dimensions to iterate over
18+
if str(indices.dtype) not in int_types:
19+
raise IndexError('`indices` must be an integer array')
20+
shape_ones = (1,) * indices.ndim
21+
dest_dims = list(range(axis)) + [None] + list(range(axis+1, indices.ndim))
522

6-
from pymc3.theanof import set_theano_conf
23+
# build a fancy index, consisting of orthogonal aranges, with the
24+
# requested index inserted at the right location
25+
fancy_index = []
26+
for dim, n in zip(dest_dims, arr_shape):
27+
if dim is None:
28+
fancy_index.append(indices)
29+
else:
30+
ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim+1:]
31+
fancy_index.append(np.arange(n).reshape(ind_shape))
32+
33+
return tuple(fancy_index)
34+
35+
36+
if hasattr(np, "take_along_axis"):
37+
np_take_along_axis = np.take_along_axis
38+
else:
39+
def np_take_along_axis(arr, indices, axis):
40+
if arr.shape[axis] <= 32:
41+
# We can safely test with numpy's choose
42+
arr = np.moveaxis(arr, axis, 0)
43+
indices = np.moveaxis(indices, axis, 0)
44+
out = np.choose(indices, arr)
45+
return np.moveaxis(out, 0, axis)
46+
else:
47+
# numpy's choose cannot handle such a large axis so we
48+
# just use the implementation of take_along_axis. This is kind of
49+
# cheating because our implementation is the same as the one below
50+
if axis < 0:
51+
_axis = arr.ndim + axis
52+
else:
53+
_axis = axis
54+
if _axis < 0 or _axis >= arr.ndim:
55+
raise ValueError(
56+
"Supplied axis {} is out of bounds".format(axis)
57+
)
58+
return arr[_make_along_axis_idx(arr.shape, indices, _axis)]
759

860

961
class TestSetTheanoConfig:
@@ -26,3 +78,165 @@ def test_restore(self):
2678
assert conf == {'compute_test_value': 'off'}
2779
conf = set_theano_conf(conf)
2880
assert conf == {'compute_test_value': 'raise'}
81+
82+
83+
class TestTakeAlongAxis():
84+
def setup_class(self):
85+
self.inputs_buffer = dict()
86+
self.output_buffer = dict()
87+
self.func_buffer = dict()
88+
89+
def _input_tensors(self, shape):
90+
ndim = len(shape)
91+
arr = tt.TensorType(FLOATX, [False] * ndim)("arr")
92+
indices = tt.TensorType(INTX, [False] * ndim)("indices")
93+
arr.tag.test_value = np.zeros(shape, dtype=FLOATX)
94+
indices.tag.test_value = np.zeros(shape, dtype=INTX)
95+
return arr, indices
96+
97+
def get_input_tensors(self, shape):
98+
ndim = len(shape)
99+
try:
100+
return self.inputs_buffer[ndim]
101+
except KeyError:
102+
arr, indices = self._input_tensors(shape)
103+
self.inputs_buffer[ndim] = arr, indices
104+
return arr, indices
105+
106+
def _output_tensor(self, arr, indices, axis):
107+
return take_along_axis(arr, indices, axis)
108+
109+
def get_output_tensors(self, shape, axis):
110+
ndim = len(shape)
111+
try:
112+
return self.output_buffer[(ndim, axis)]
113+
except KeyError:
114+
arr, indices = self.get_input_tensors(shape)
115+
out = self._output_tensor(arr, indices, axis)
116+
self.output_buffer[(ndim, axis)] = out
117+
return out
118+
119+
def _function(self, arr, indices, out):
120+
return theano.function([arr, indices], [out])
121+
122+
def get_function(self, shape, axis):
123+
ndim = len(shape)
124+
try:
125+
return self.func_buffer[(ndim, axis)]
126+
except KeyError:
127+
arr, indices = self.get_input_tensors(shape)
128+
out = self.get_output_tensors(shape, axis)
129+
func = self._function(arr, indices, out)
130+
self.func_buffer[(ndim, axis)] = func
131+
return func
132+
133+
@staticmethod
134+
def get_input_values(shape, axis, samples):
135+
arr = np.random.randn(*shape).astype(FLOATX)
136+
size = list(shape)
137+
size[axis] = samples
138+
size = tuple(size)
139+
indices = np.random.randint(
140+
low=0, high=shape[axis], size=size, dtype=INTX
141+
)
142+
return arr, indices
143+
144+
@pytest.mark.parametrize(
145+
["shape", "axis", "samples"],
146+
product(
147+
[
148+
(1,),
149+
(3,),
150+
(3, 1),
151+
(3, 2),
152+
(1, 1),
153+
(1, 2),
154+
(40, 40), # choose fails here
155+
(5, 1, 1),
156+
(5, 1, 2),
157+
(5, 3, 1),
158+
(5, 3, 2),
159+
],
160+
[0, -1],
161+
[1, 10],
162+
),
163+
ids=str,
164+
)
165+
def test_take_along_axis(self, shape, axis, samples):
166+
arr, indices = self.get_input_values(shape, axis, samples)
167+
func = self.get_function(shape, axis)
168+
assert np.allclose(
169+
np_take_along_axis(arr, indices, axis=axis),
170+
func(arr, indices)[0]
171+
)
172+
173+
@pytest.mark.parametrize(
174+
["shape", "axis", "samples"],
175+
product(
176+
[
177+
(1,),
178+
(3,),
179+
(3, 1),
180+
(3, 2),
181+
(1, 1),
182+
(1, 2),
183+
(40, 40), # choose fails here
184+
(5, 1, 1),
185+
(5, 1, 2),
186+
(5, 3, 1),
187+
(5, 3, 2),
188+
],
189+
[0, -1],
190+
[1, 10],
191+
),
192+
ids=str,
193+
)
194+
def test_take_along_axis_grad(self, shape, axis, samples):
195+
if axis < 0:
196+
_axis = len(shape) + axis
197+
else:
198+
_axis = axis
199+
# Setup the theano function
200+
t_arr, t_indices = self.get_input_tensors(shape)
201+
t_out2 = theano.grad(
202+
tt.sum(self._output_tensor(t_arr**2, t_indices, axis)),
203+
t_arr,
204+
)
205+
func = theano.function([t_arr, t_indices], [t_out2])
206+
207+
# Test that the gradient gives the same output as what is expected
208+
arr, indices = self.get_input_values(shape, axis, samples)
209+
expected_grad = np.zeros_like(arr)
210+
slicer = [slice(None)] * len(shape)
211+
for i in range(indices.shape[axis]):
212+
slicer[axis] = i
213+
inds = indices[slicer].reshape(
214+
shape[:_axis] + (1,) + shape[_axis + 1:]
215+
)
216+
inds = _make_along_axis_idx(shape, inds, _axis)
217+
expected_grad[inds] += 1
218+
expected_grad *= 2 * arr
219+
out = func(arr, indices)[0]
220+
assert np.allclose(out, expected_grad)
221+
222+
@pytest.mark.parametrize("axis", [-4, 4], ids=str)
223+
def test_axis_failure(self, axis):
224+
arr, indices = self.get_input_tensors((3, 1))
225+
with pytest.raises(ValueError):
226+
take_along_axis(arr, indices, axis=axis)
227+
228+
def test_ndim_failure(self):
229+
arr = tt.TensorType(FLOATX, [False] * 3)("arr")
230+
indices = tt.TensorType(INTX, [False] * 2)("indices")
231+
arr.tag.test_value = np.zeros((1,) * arr.ndim, dtype=FLOATX)
232+
indices.tag.test_value = np.zeros((1,) * indices.ndim, dtype=INTX)
233+
with pytest.raises(ValueError):
234+
take_along_axis(arr, indices)
235+
236+
def test_dtype_failure(self):
237+
arr = tt.TensorType(FLOATX, [False] * 3)("arr")
238+
indices = tt.TensorType(FLOATX, [False] * 3)("indices")
239+
arr.tag.test_value = np.zeros((1,) * arr.ndim, dtype=FLOATX)
240+
indices.tag.test_value = np.zeros((1,) * indices.ndim, dtype=FLOATX)
241+
with pytest.raises(IndexError):
242+
take_along_axis(arr, indices)

0 commit comments

Comments
 (0)