Skip to content

Commit 81d18af

Browse files
committed
Handle multivariate transforms of univariate distributions correctly
Details: * Fix broadcasting bug in the logp of the univariate Ordered and SumTo1 transforms, and add an explicit check when building the graph * Raise if a univariate transform is applied to a multivariate distribution * Checks and logp reduction are applied even when the jacobian is not used
1 parent d659848 commit 81d18af

File tree

6 files changed

+185
-244
lines changed

6 files changed

+185
-244
lines changed

docs/source/api/distributions/transforms.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Transform instances are the entities that should be used in the
1919
logodds
2020
simplex
2121
sum_to_1
22+
ordered
2223

2324

2425
Specific Transform Classes

pymc/distributions/discrete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1259,7 +1259,7 @@ class OrderedLogistic:
12591259
# Ordered logistic regression
12601260
with pm.Model() as model:
12611261
cutpoints = pm.Normal("cutpoints", mu=[-1,1], sigma=10, shape=2,
1262-
transform=pm.distributions.transforms.univariate_ordered)
1262+
transform=pm.distributions.transforms.ordered)
12631263
y_ = pm.OrderedLogistic("y", cutpoints=cutpoints, eta=x, observed=y)
12641264
idata = pm.sample()
12651265

pymc/distributions/mixture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ class NormalMixture:
539539
mu=data.mean(),
540540
sigma=10,
541541
shape=n_components,
542-
transform=pm.distributions.transforms.univariate_ordered,
542+
transform=pm.distributions.transforms.ordered,
543543
initval=[1, 2, 3],
544544
)
545545
σ = pm.HalfNormal("σ", sigma=10, shape=n_components)

pymc/distributions/transforms.py

Lines changed: 29 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
15+
1416
from functools import singledispatch
1517

1618
import numpy as np
@@ -39,19 +41,28 @@
3941
"logodds",
4042
"Interval",
4143
"log_exp_m1",
42-
"univariate_ordered",
43-
"multivariate_ordered",
44+
"ordered",
4445
"log",
4546
"sum_to_1",
46-
"univariate_sum_to_1",
47-
"multivariate_sum_to_1",
4847
"circular",
4948
"CholeskyCovPacked",
5049
"Chain",
5150
"ZeroSumTransform",
5251
]
5352

5453

54+
def __getattr__(name):
    """Resolve deprecated transform aliases on attribute lookup.

    Maps the removed ``univariate_*`` / ``multivariate_*`` names onto their
    single replacements, emitting a ``FutureWarning`` each time one is used.

    Parameters
    ----------
    name : str
        Attribute name being looked up.

    Returns
    -------
    The ``ordered`` or ``sum_to_1`` transform instance, depending on which
    deprecated alias was requested.

    Raises
    ------
    AttributeError
        If ``name`` is not one of the deprecated aliases.
    """
    if name in ("univariate_ordered", "multivariate_ordered"):
        warnings.warn(f"{name} has been deprecated, use ordered instead.", FutureWarning)
        return ordered

    # Each alias must be its own tuple element: a single parenthesised string
    # would turn `in` into a substring test and match names such as "sum" or "".
    if name in ("univariate_sum_to_1", "multivariate_sum_to_1"):
        warnings.warn(f"{name} has been deprecated, use sum_to_1 instead.", FutureWarning)
        return sum_to_1

    raise AttributeError(f"module {__name__} has no attribute {name}")
64+
65+
5566
@singledispatch
5667
def _default_transform(op: Op, rv: TensorVariable):
5768
"""Return default transform for a given Distribution `Op`"""
@@ -79,13 +90,9 @@ def log_jac_det(self, value, *inputs):
7990
class Ordered(RVTransform):
8091
name = "ordered"
8192

82-
def __init__(self, ndim_supp=0):
83-
if ndim_supp > 1:
84-
raise ValueError(
85-
f"For Ordered transformation number of core dimensions"
86-
f"(ndim_supp) must not exceed 1 but is {ndim_supp}"
87-
)
88-
self.ndim_supp = ndim_supp
93+
def __init__(self, ndim_supp=None):
94+
if ndim_supp is not None:
95+
warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning)
8996

9097
def backward(self, value, *inputs):
9198
x = pt.zeros(value.shape)
@@ -100,10 +107,7 @@ def forward(self, value, *inputs):
100107
return y
101108

102109
def log_jac_det(self, value, *inputs):
103-
if self.ndim_supp == 0:
104-
return pt.sum(value[..., 1:], axis=-1, keepdims=True)
105-
else:
106-
return pt.sum(value[..., 1:], axis=-1)
110+
return pt.sum(value[..., 1:], axis=-1)
107111

108112

109113
class SumTo1(RVTransform):
@@ -114,13 +118,9 @@ class SumTo1(RVTransform):
114118

115119
name = "sumto1"
116120

117-
def __init__(self, ndim_supp=0):
118-
if ndim_supp > 1:
119-
raise ValueError(
120-
f"For SumTo1 transformation number of core dimensions"
121-
f"(ndim_supp) must not exceed 1 but is {ndim_supp}"
122-
)
123-
self.ndim_supp = ndim_supp
121+
def __init__(self, ndim_supp=None):
122+
if ndim_supp is not None:
123+
warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning)
124124

125125
def backward(self, value, *inputs):
126126
remaining = 1 - pt.sum(value[..., :], axis=-1, keepdims=True)
@@ -131,10 +131,7 @@ def forward(self, value, *inputs):
131131

132132
def log_jac_det(self, value, *inputs):
133133
y = pt.zeros(value.shape)
134-
if self.ndim_supp == 0:
135-
return pt.sum(y, axis=-1, keepdims=True)
136-
else:
137-
return pt.sum(y, axis=-1)
134+
return pt.sum(y, axis=-1)
138135

139136

140137
class CholeskyCovPacked(RVTransform):
@@ -359,38 +356,21 @@ def extend_axis_rev(array, axis):
359356
Instantiation of :class:`pymc.distributions.transforms.LogExpM1`
360357
for use in the ``transform`` argument of a random variable."""
361358

362-
univariate_ordered = Ordered(ndim_supp=0)
363-
univariate_ordered.__doc__ = """
359+
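# NOTE(review): this was labelled "Deprecated", but the deprecation warnings point users *to* `ordered` — the label looks stale; confirm.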
360+
ordered = Ordered()
361+
ordered.__doc__ = """
364362
Instantiation of :class:`pymc.distributions.transforms.Ordered`
365-
for use in the ``transform`` argument of a univariate random variable."""
366-
367-
multivariate_ordered = Ordered(ndim_supp=1)
368-
multivariate_ordered.__doc__ = """
369-
Instantiation of :class:`pymc.distributions.transforms.Ordered`
370-
for use in the ``transform`` argument of a multivariate random variable."""
363+
for use in the ``transform`` argument of a random variable."""
371364

372365
log = LogTransform()
373366
log.__doc__ = """
374367
Instantiation of :class:`pymc.logprob.transforms.LogTransform`
375368
for use in the ``transform`` argument of a random variable."""
376369

377-
univariate_sum_to_1 = SumTo1(ndim_supp=0)
378-
univariate_sum_to_1.__doc__ = """
379-
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
380-
for use in the ``transform`` argument of a univariate random variable."""
381-
382-
multivariate_sum_to_1 = SumTo1(ndim_supp=1)
383-
multivariate_sum_to_1.__doc__ = """
384-
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
385-
for use in the ``transform`` argument of a multivariate random variable."""
386-
387-
# backwards compatibility
388-
sum_to_1 = SumTo1(ndim_supp=1)
370+
sum_to_1 = SumTo1()
389371
sum_to_1.__doc__ = """
390372
Instantiation of :class:`pymc.distributions.transforms.SumTo1`
391-
for use in the ``transform`` argument of a random variable.
392-
This instantiation is for backwards compatibility only.
393-
Please use `univariate_sum_to_1` or `multivariate_sum_to_1` instead."""
373+
for use in the ``transform`` argument of a random variable."""
394374

395375
circular = CircularTransform()
396376
circular.__doc__ = """

pymc/logprob/transforms.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,9 @@ def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable:
195195
phi_inv = self.backward(value, *inputs)
196196
return pt.log(pt.abs(pt.nlinalg.det(pt.atleast_2d(jacobian(phi_inv, [value])[0]))))
197197

198+
def __str__(self):
    """Return the name of the instance's class as its string form."""
    # The class name is already a str; no f-string wrapper is needed.
    return self.__class__.__name__
200+
198201

199202
@node_rewriter(tracks=None)
200203
def transform_values(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
@@ -1219,22 +1222,46 @@ def transformed_logprob(op, values, *inputs, use_jacobian=True, **kwargs):
12191222
if not isinstance(logprobs, Sequence):
12201223
logprobs = [logprobs]
12211224

1222-
if use_jacobian:
1223-
assert len(values) == len(logprobs) == len(op.transforms)
1224-
logprobs_jac = []
1225-
for value, transform, logp in zip(values, op.transforms, logprobs):
1226-
if transform is None:
1227-
logprobs_jac.append(logp)
1228-
continue
1229-
assert isinstance(value.owner.op, TransformedVariable)
1230-
original_forward_value = value.owner.inputs[1]
1231-
jacobian = transform.log_jac_det(original_forward_value, *inputs).copy()
1225+
# Handle jacobian
1226+
assert len(values) == len(logprobs) == len(op.transforms)
1227+
logprobs_jac = []
1228+
for value, transform, logp in zip(values, op.transforms, logprobs):
1229+
if transform is None:
1230+
logprobs_jac.append(logp)
1231+
continue
1232+
1233+
assert isinstance(value.owner.op, TransformedVariable)
1234+
original_forward_value = value.owner.inputs[1]
1235+
log_jac_det = transform.log_jac_det(original_forward_value, *inputs).copy()
1236+
# The jacobian determinant has fewer dims than the logp
1237+
# when a multivariate transform (like Simplex or Ordered) is applied to univariate distributions.
1238+
# In this case we have to reduce the last logp dimensions, as they are no longer independent
1239+
if log_jac_det.ndim < logp.ndim:
1240+
diff_ndims = logp.ndim - log_jac_det.ndim
1241+
logp = logp.sum(axis=np.arange(-diff_ndims, 0))
1242+
# This case is sometimes, but not always, trivial to accommodate depending on the "space rank" of the
1243+
# multivariate distribution. See https://proceedings.mlr.press/v130/radul21a.html
1244+
elif log_jac_det.ndim > logp.ndim:
1245+
raise NotImplementedError(
1246+
f"Univariate transform {transform} cannot be applied to multivariate {rv_op}"
1247+
)
1248+
else:
1249+
# Check there is no broadcasting between logp and jacobian
1250+
if logp.type.broadcastable != log_jac_det.type.broadcastable:
1251+
raise ValueError(
1252+
f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
1253+
"There is a bug in the implementation of either one."
1254+
)
1255+
1256+
if use_jacobian:
12321257
if value.name:
1233-
jacobian.name = f"{value.name}_jacobian"
1234-
logprobs_jac.append(logp + jacobian)
1235-
logprobs = logprobs_jac
1258+
log_jac_det.name = f"{value.name}_jacobian"
1259+
logprobs_jac.append(logp + log_jac_det)
1260+
else:
1261+
# We still want to use the reduced logp, even though the jacobian isn't included
1262+
logprobs_jac.append(logp)
12361263

1237-
return logprobs
1264+
return logprobs_jac
12381265

12391266
new_op = copy(rv_op)
12401267
new_op.__class__ = new_op_type

0 commit comments

Comments
 (0)