Commit ce543da

Allow batched parameters in MvNormal and MvStudentT distributions
1 parent a3ec9a5 commit ce543da
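The change allows extra leading batch dimensions on the distribution parameters. A minimal usage sketch of what this enables (not part of the commit; shapes are illustrative):

import numpy as np
import pymc as pm

# A batch of 3 bivariate normals: mu has batch shape (3,) and cov is a
# stack of three 2x2 covariance matrices.
mu = np.zeros((3, 2))
cov = np.broadcast_to(np.eye(2), (3, 2, 2))
x = pm.MvNormal.dist(mu=mu, cov=cov)
print(pm.draw(x).shape)  # expected: (3, 2)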

File tree

4 files changed: +239 -181 lines changed


pymc/distributions/multivariate.py

Lines changed: 40 additions & 70 deletions
@@ -115,40 +115,37 @@ def simplex_cont_transform(op, rv):
 
 
 def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs):
-    if chol is not None and not lower:
-        chol = chol.T
-
     if len([i for i in [tau, cov, chol] if i is not None]) != 1:
         raise ValueError("Incompatible parameterization. Specify exactly one of tau, cov, or chol.")
 
     if cov is not None:
         cov = pt.as_tensor_variable(cov)
-        if cov.ndim != 2:
-            raise ValueError("cov must be two dimensional.")
+        if cov.ndim < 2:
+            raise ValueError("cov must be at least two dimensional.")
     elif tau is not None:
         tau = pt.as_tensor_variable(tau)
-        if tau.ndim != 2:
-            raise ValueError("tau must be two dimensional.")
-        # TODO: What's the correct order/approach (in the non-square case)?
-        # `pytensor.tensor.nlinalg.tensorinv`?
+        if tau.ndim < 2:
+            raise ValueError("tau must be at least two dimensional.")
         cov = matrix_inverse(tau)
     else:
-        # TODO: What's the correct order/approach (in the non-square case)?
         chol = pt.as_tensor_variable(chol)
-        if chol.ndim != 2:
-            raise ValueError("chol must be two dimensional.")
+        if chol.ndim < 2:
+            raise ValueError("chol must be at least two dimensional.")
+
+        if not lower:
+            chol = pt.swapaxes(chol, -1, -2)
 
         # tag as lower triangular to enable pytensor rewrites of chol(l.l') -> l
         chol.tag.lower_triangular = True
-        cov = chol.dot(chol.T)
+        cov = pt.matmul(chol, pt.swapaxes(chol, -1, -2))
 
     return cov
 
 
-def quaddist_parse(value, mu, cov, mat_type="cov"):
+def quaddist_chol(value, mu, cov):
     """Compute (x - mu).T @ Sigma^-1 @ (x - mu) and the logdet of Sigma."""
-    if value.ndim > 2 or value.ndim == 0:
-        raise ValueError("Invalid dimension for value: %s" % value.ndim)
+    if value.ndim == 0:
+        raise ValueError("Value can't be a scalar")
     if value.ndim == 1:
         onedim = True
         value = value[None, :]
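The key change in quaddist_matrix: chol.dot(chol.T) assumed a single 2d matrix, while pt.matmul combined with a swap of the last two axes broadcasts over any leading batch dimensions. A numpy analogue of the new expression (a sketch, not the commit's code):

import numpy as np

chol = np.tril(np.random.rand(3, 2, 2)) + np.eye(2)  # three lower-triangular factors
cov = np.matmul(chol, np.swapaxes(chol, -1, -2))     # batched chol @ chol.T
print(cov.shape)  # (3, 2, 2)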
@@ -157,42 +154,21 @@ def quaddist_parse(value, mu, cov, mat_type="cov"):
 
     delta = value - mu
     chol_cov = nan_lower_cholesky(cov)
-    if mat_type != "tau":
-        dist, logdet, ok = quaddist_chol(delta, chol_cov)
-    else:
-        dist, logdet, ok = quaddist_tau(delta, chol_cov)
-    if onedim:
-        return dist[0], logdet, ok
-
-    return dist, logdet, ok
-
 
-def quaddist_chol(delta, chol_mat):
-    diag = pt.diag(chol_mat)
+    diag = pt.diagonal(chol_cov, axis1=-2, axis2=-1)
     # Check if the covariance matrix is positive definite.
-    ok = pt.all(diag > 0)
+    ok = pt.all(diag > 0, axis=-1)
     # If not, replace the diagonal. We return -inf later, but
     # need to prevent solve_lower from throwing an exception.
-    chol_cov = pt.switch(ok, chol_mat, 1)
-
-    delta_trans = solve_lower(chol_cov, delta.T).T
+    chol_cov = pt.switch(ok[..., None, None], chol_cov, 1)
+    delta_trans = solve_lower(chol_cov, delta, b_ndim=1)
     quaddist = (delta_trans**2).sum(axis=-1)
-    logdet = pt.sum(pt.log(diag))
-    return quaddist, logdet, ok
-
-
-def quaddist_tau(delta, chol_mat):
-    diag = pt.nlinalg.diag(chol_mat)
-    # Check if the precision matrix is positive definite.
-    ok = pt.all(diag > 0)
-    # If not, replace the diagonal. We return -inf later, but
-    # need to prevent solve_lower from throwing an exception.
-    chol_tau = pt.switch(ok, chol_mat, 1)
+    logdet = pt.log(diag).sum(axis=-1)
 
-    delta_trans = pt.dot(delta, chol_tau)
-    quaddist = (delta_trans**2).sum(axis=-1)
-    logdet = -pt.sum(pt.log(diag))
-    return quaddist, logdet, ok
+    if onedim:
+        return quaddist[0], logdet, ok
+    else:
+        return quaddist, logdet, ok
 
 
 class MvNormal(Continuous):
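In the merged quaddist_chol, the axis=-1 reductions give one positive-definiteness flag, one quadratic form, and one log-determinant term per batch member. Since cov = L @ L.T, summing log diag(L) yields half of log|cov|; a numpy sketch checking that identity on a batch (shapes are illustrative):

import numpy as np

chol = np.tril(np.random.rand(3, 2, 2)) + np.eye(2)
cov = chol @ np.swapaxes(chol, -1, -2)
diag = np.diagonal(chol, axis1=-2, axis2=-1)  # shape (3, 2)
logdet = np.log(diag).sum(axis=-1)            # shape (3,), log|L| per matrix
# slogdet also broadcasts over leading dimensions: log|cov| = 2 * log|L|
np.testing.assert_allclose(2 * logdet, np.linalg.slogdet(cov)[1])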
@@ -266,10 +242,11 @@ def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs):
         mu = pt.as_tensor_variable(mu)
         cov = quaddist_matrix(cov, chol, tau, lower)
         # PyTensor is stricter about the shape of mu, than PyMC used to be
-        mu = pt.broadcast_arrays(mu, cov[..., -1])[0]
+        mu, _ = pt.broadcast_arrays(mu, cov[..., -1])
         return super().dist([mu, cov], **kwargs)
 
     def moment(rv, size, mu, cov):
+        # mu is broadcasted to the potential length of cov in `dist`
         moment = mu
         if not rv_size_is_none(size):
             moment_size = pt.concatenate([size, [mu.shape[-1]]])
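The broadcast against cov[..., -1] works because the last row of cov has exactly the batch-plus-event shape that mu must match. A numpy analogue (a sketch with illustrative shapes):

import numpy as np

mu = np.zeros(2)                                 # event shape only
cov = np.broadcast_to(np.eye(2), (3, 2, 2))      # batched covariance
mu_b, _ = np.broadcast_arrays(mu, cov[..., -1])  # cov[..., -1] has shape (3, 2)
print(mu_b.shape)  # (3, 2): mu picks up the batch dimensions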
@@ -290,7 +267,7 @@ def logp(value, mu, cov):
         -------
         TensorVariable
         """
-        quaddist, logdet, ok = quaddist_parse(value, mu, cov)
+        quaddist, logdet, ok = quaddist_chol(value, mu, cov)
         k = floatX(value.shape[-1])
         norm = -0.5 * k * pm.floatX(np.log(2 * np.pi))
         return check_parameters(
@@ -307,22 +284,6 @@ class MvStudentTRV(RandomVariable):
     dtype = "floatX"
     _print_name = ("MvStudentT", "\\operatorname{MvStudentT}")
 
-    def make_node(self, rng, size, dtype, nu, mu, cov):
-        nu = pt.as_tensor_variable(nu)
-        if not nu.ndim == 0:
-            raise ValueError("nu must be a scalar (ndim=0).")
-
-        return super().make_node(rng, size, dtype, nu, mu, cov)
-
-    def __call__(self, nu, mu=None, cov=None, size=None, **kwargs):
-        dtype = pytensor.config.floatX if self.dtype == "floatX" else self.dtype
-
-        if mu is None:
-            mu = np.array([0.0], dtype=dtype)
-        if cov is None:
-            cov = np.array([[1.0]], dtype=dtype)
-        return super().__call__(nu, mu, cov, size=size, **kwargs)
-
     def _supp_shape_from_params(self, dist_params, param_shapes=None):
         return supp_shape_from_ref_param_shape(
             ndim_supp=self.ndim_supp,
@@ -333,14 +294,21 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None):
 
     @classmethod
     def rng_fn(cls, rng, nu, mu, cov, size):
+        if size is None:
+            # When size is implicit, we need to broadcast parameters correctly,
+            # so that the MvNormal draws and the chisquare draws have the same number of batch dimensions.
+            # nu broadcasts mu and cov
+            if np.ndim(nu) > max(mu.ndim - 1, cov.ndim - 2):
+                _, mu, cov = broadcast_params((nu, mu, cov), ndims_params=cls.ndims_params)
+            # nu is broadcasted by either mu or cov
+            elif np.ndim(nu) < max(mu.ndim - 1, cov.ndim - 2):
+                nu, _, _ = broadcast_params((nu, mu, cov), ndims_params=cls.ndims_params)
+
         mv_samples = multivariate_normal.rng_fn(rng=rng, mean=np.zeros_like(mu), cov=cov, size=size)
 
         # Take chi2 draws and add an axis of length 1 to the right for correct broadcasting below
         chi2_samples = np.sqrt(rng.chisquare(nu, size=size) / nu)[..., None]
 
-        if size:
-            mu = np.broadcast_to(mu, size + (mu.shape[-1],))
-
         return (mv_samples / chi2_samples) + mu
 
 
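rng_fn uses the standard scale-mixture representation of the multivariate Student-T: divide MvNormal(0, cov) draws by sqrt(chi2(nu) / nu) and shift by mu. The same recipe in plain numpy (a sketch; sizes are illustrative):

import numpy as np

rng = np.random.default_rng(0)
nu, mu, cov = 5.0, np.zeros(2), np.eye(2)
z = rng.multivariate_normal(np.zeros_like(mu), cov, size=1000)  # MvNormal(0, cov) draws
chi2 = np.sqrt(rng.chisquare(nu, size=1000) / nu)[..., None]    # one scale per draw
draws = z / chi2 + mu                                           # MvStudentT(nu, mu, cov) draws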
@@ -390,7 +358,7 @@ class MvStudentT(Continuous):
     rv_op = mv_studentt
 
     @classmethod
-    def dist(cls, nu, Sigma=None, mu=None, scale=None, tau=None, chol=None, lower=True, **kwargs):
+    def dist(cls, nu, *, Sigma=None, mu, scale=None, tau=None, chol=None, lower=True, **kwargs):
         cov = kwargs.pop("cov", None)
         if cov is not None:
             warnings.warn(
@@ -407,11 +375,13 @@ def dist(cls, nu, Sigma=None, mu=None, scale=None, tau=None, chol=None, lower=Tr
         mu = pt.as_tensor_variable(floatX(mu))
         scale = quaddist_matrix(scale, chol, tau, lower)
         # PyTensor is stricter about the shape of mu, than PyMC used to be
-        mu = pt.broadcast_arrays(mu, scale[..., -1])[0]
+        mu, _ = pt.broadcast_arrays(mu, scale[..., -1])
 
         return super().dist([nu, mu, scale], **kwargs)
 
     def moment(rv, size, nu, mu, scale):
+        # mu is broadcasted to the potential length of scale in `dist`
+        mu, _ = pt.random.utils.broadcast_params([mu, nu], ndims_params=[1, 0])
         moment = mu
         if not rv_size_is_none(size):
             moment_size = pt.concatenate([size, [mu.shape[-1]]])
@@ -432,7 +402,7 @@ def logp(value, nu, mu, scale):
         -------
         TensorVariable
         """
-        quaddist, logdet, ok = quaddist_parse(value, mu, scale)
+        quaddist, logdet, ok = quaddist_chol(value, mu, scale)
         k = floatX(value.shape[-1])
 
         norm = gammaln((nu + k) / 2.0) - gammaln(nu / 2.0) - 0.5 * k * pt.log(nu * np.pi)
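With the new keyword-only signature of MvStudentT.dist, mu and the matrix parameters must be passed by name. A hedged usage sketch (values are illustrative):

import numpy as np
import pymc as pm

# mu is now keyword-only; passing it positionally would raise a TypeError.
y = pm.MvStudentT.dist(nu=4, mu=np.zeros(2), scale=np.eye(2))
print(pm.draw(y).shape)  # expected: (2,)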

tests/distributions/test_mixture.py

Lines changed: 12 additions & 15 deletions
@@ -333,21 +333,18 @@ def test_list_multivariate_components_deterministic_weights(self, weights, compo
         assert not repetitions
 
         # Test logp
-        # MvNormal logp is currently limited to 2d values
-        expectation = pytest.raises(ValueError) if mix_eval.ndim > 2 else does_not_raise()
-        with expectation:
-            mix_logp_eval = logp(mix, mix_eval).eval()
-            assert mix_logp_eval.shape == expected_shape[:-1]
-            bcast_weights = np.broadcast_to(weights, (*expected_shape[:-1], 2))
-            expected_logp = np.stack(
-                (
-                    logp(components[0], mix_eval).eval(),
-                    logp(components[1], mix_eval).eval(),
-                ),
-                axis=-1,
-            )[bcast_weights == 1]
-            expected_logp = expected_logp.reshape(expected_shape[:-1])
-            assert np.allclose(mix_logp_eval, expected_logp)
+        mix_logp_eval = logp(mix, mix_eval).eval()
+        assert mix_logp_eval.shape == expected_shape[:-1]
+        bcast_weights = np.broadcast_to(weights, (*expected_shape[:-1], 2))
+        expected_logp = np.stack(
+            (
+                logp(components[0], mix_eval).eval(),
+                logp(components[1], mix_eval).eval(),
+            ),
+            axis=-1,
+        )[bcast_weights == 1]
+        expected_logp = expected_logp.reshape(expected_shape[:-1])
+        assert np.allclose(mix_logp_eval, expected_logp)
 
     def test_component_choice_random(self):
         """Test that mixture choices change over evaluations"""
