Skip to content

Commit d7bfa6d

Browse files
brandonwillardtwiecki
authored andcommitted
Make logpt work correctly for nested models and transforms
1 parent 2ffb51b commit d7bfa6d

File tree

8 files changed

+268
-236
lines changed

8 files changed

+268
-236
lines changed

pymc3/distributions/__init__.py

Lines changed: 157 additions & 130 deletions
Large diffs are not rendered by default.

pymc3/distributions/distribution.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from aesara.tensor.random.op import RandomVariable
2828

29-
from pymc3.distributions import _logcdf, _logp, logp_transform
29+
from pymc3.distributions import _logcdf, _logp
3030

3131
if TYPE_CHECKING:
3232
from typing import Optional, Callable
@@ -111,12 +111,12 @@ def logp(op, value, *dist_params, **kwargs):
111111
def logcdf(op, value, *dist_params, **kwargs):
112112
return class_logcdf(value, *dist_params, **kwargs)
113113

114-
class_transform = clsdict.get("transform")
115-
if class_transform:
116-
117-
@logp_transform.register(rv_type)
118-
def transform(op, *args, **kwargs):
119-
return class_transform(*args, **kwargs)
114+
# class_transform = clsdict.get("transform")
115+
# if class_transform:
116+
#
117+
# @logp_transform.register(rv_type)
118+
# def transform(op, *args, **kwargs):
119+
# return class_transform(*args, **kwargs)
120120

121121
# Register the Aesara `RandomVariable` type as a subclass of this
122122
# `Distribution` type.
@@ -328,26 +328,17 @@ def _distr_parameters_for_repr(self):
328328
class Discrete(Distribution):
329329
"""Base class for discrete distributions"""
330330

331-
def __init__(self, shape=(), dtype=None, defaults=("mode",), *args, **kwargs):
332-
if dtype is None:
333-
if aesara.config.floatX == "float32":
334-
dtype = "int16"
335-
else:
336-
dtype = "int64"
337-
if dtype != "int16" and dtype != "int64":
338-
raise TypeError("Discrete classes expect dtype to be int16 or int64.")
331+
def __new__(cls, name, *args, **kwargs):
339332

340-
super().__init__(shape, dtype, defaults=defaults, *args, **kwargs)
333+
if kwargs.get("transform", None):
334+
raise ValueError("Transformations for discrete distributions")
335+
336+
return super().__new__(cls, name, *args, **kwargs)
341337

342338

343339
class Continuous(Distribution):
344340
"""Base class for continuous distributions"""
345341

346-
def __init__(self, shape=(), dtype=None, defaults=("median", "mean", "mode"), *args, **kwargs):
347-
if dtype is None:
348-
dtype = aesara.config.floatX
349-
super().__init__(shape, dtype, defaults=defaults, *args, **kwargs)
350-
351342

352343
class DensityDist(Distribution):
353344
"""Distribution based on a given log density function.

pymc3/distributions/multivariate.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ class Dirichlet(Continuous):
388388
rv_op = dirichlet
389389

390390
@classmethod
391-
def dist(cls, a, **kwargs):
391+
def dist(cls, a, transform=transforms.stick_breaking, **kwargs):
392392

393393
a = at.as_tensor_variable(a)
394394
# mean = a / at.sum(a)
@@ -419,15 +419,6 @@ def logp(value, a):
419419
broadcast_conditions=False,
420420
)
421421

422-
def transform(rv_var):
423-
424-
if rv_var.ndim == 1 or rv_var.broadcastable[-1]:
425-
# If this variable is just a bunch of scalars/degenerate
426-
# Dirichlets, we can't transform it
427-
return None
428-
429-
return transforms.stick_breaking
430-
431422
def _distr_parameters_for_repr(self):
432423
return ["a"]
433424

pymc3/distributions/transforms.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def jacobian_det(self, rv_var, rv_value):
218218
s = at.nnet.softplus(-rv_value)
219219
return at.log(b - a) - 2 * s - rv_value
220220
else:
221-
return rv_value
221+
return at.ones_like(rv_value)
222222

223223

224224
interval = Interval
@@ -286,6 +286,11 @@ class StickBreaking(Transform):
286286
name = "stickbreaking"
287287

288288
def forward(self, rv_var, rv_value):
289+
if rv_var.ndim == 1 or rv_var.broadcastable[-1]:
290+
# If this variable is just a bunch of scalars/degenerate
291+
# Dirichlets, we can't transform it
292+
return rv_value
293+
289294
x = rv_value.T
290295
n = x.shape[0]
291296
lx = at.log(x)
@@ -294,6 +299,11 @@ def forward(self, rv_var, rv_value):
294299
return floatX(y.T)
295300

296301
def backward(self, rv_var, rv_value):
302+
if rv_var.ndim == 1 or rv_var.broadcastable[-1]:
303+
# If this variable is just a bunch of scalars/degenerate
304+
# Dirichlets, we can't transform it
305+
return rv_value
306+
297307
y = rv_value.T
298308
y = at.concatenate([y, -at.sum(y, 0, keepdims=True)])
299309
# "softmax" with vector support and no deprication warning:
@@ -302,6 +312,11 @@ def backward(self, rv_var, rv_value):
302312
return floatX(x.T)
303313

304314
def jacobian_det(self, rv_var, rv_value):
315+
if rv_var.ndim == 1 or rv_var.broadcastable[-1]:
316+
# If this variable is just a bunch of scalars/degenerate
317+
# Dirichlets, we can't transform it
318+
return at.ones_like(rv_value)
319+
305320
y = rv_value.T
306321
Km1 = y.shape[0] + 1
307322
sy = at.sum(y, 0, keepdims=True)

pymc3/model.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,8 @@ def build_named_node_tree(graphs):
239239

240240
T = TypeVar("T", bound="ContextMeta")
241241

242+
no_transform_object = object()
243+
242244

243245
class ContextMeta(type):
244246
"""Functionality for objects that put themselves in a context using
@@ -1047,7 +1049,9 @@ def add_coords(self, coords):
10471049
else:
10481050
self.coords[name] = coords[name]
10491051

1050-
def register_rv(self, rv_var, name, data=None, total_size=None, dims=None, transform=None):
1052+
def register_rv(
1053+
self, rv_var, name, data=None, total_size=None, dims=None, transform=no_transform_object
1054+
):
10511055
"""Register an (un)observed random variable with the model.
10521056
10531057
Parameters
@@ -1104,24 +1108,21 @@ def register_rv(self, rv_var, name, data=None, total_size=None, dims=None, trans
11041108
if aesara.config.compute_test_value != "off":
11051109
value_var.tag.test_value = rv_var.tag.test_value
11061110

1107-
value_var.name = f"{rv_var.name}_value"
1111+
value_var.name = rv_var.name
11081112

11091113
rv_var.tag.value_var = value_var
11101114

11111115
# Make the value variable a transformed value variable,
11121116
# if there's an applicable transform
1113-
transform = transform or logp_transform(rv_var.owner.op)
1117+
if transform is no_transform_object:
1118+
transform = logp_transform(rv_var.owner.op)
11141119

11151120
if transform is not None:
11161121
value_var.tag.transform = transform
11171122
value_var.name = f"{value_var.name}_{transform.name}__"
11181123
if aesara.config.compute_test_value != "off":
11191124
value_var.tag.test_value = transform.forward(rv_var, value_var).tag.test_value
11201125

1121-
# The transformed variable needs to be a named variable in the
1122-
# model, too
1123-
self.named_vars[value_var.name] = value_var
1124-
11251126
self.add_random_variable(rv_var, dims)
11261127

11271128
return rv_var
@@ -1173,7 +1174,7 @@ def __getitem__(self, key):
11731174
except KeyError:
11741175
raise e
11751176

1176-
def makefn(self, outs, mode=None, transformed=True, *args, **kwargs):
1177+
def makefn(self, outs, mode=None, *args, **kwargs):
11771178
"""Compiles a Aesara function which returns ``outs`` and takes the variable
11781179
ancestors of ``outs`` as inputs.
11791180
@@ -1187,11 +1188,8 @@ def makefn(self, outs, mode=None, transformed=True, *args, **kwargs):
11871188
Compiled Aesara function
11881189
"""
11891190
with self:
1190-
vars = [
1191-
v if not transformed else getattr(v.tag, "transformed_var", v) for v in self.vars
1192-
]
11931191
return aesara.function(
1194-
vars,
1192+
self.vars,
11951193
outs,
11961194
allow_input_downcast=True,
11971195
on_unused_input="ignore",
@@ -1324,7 +1322,10 @@ def check_test_point(self, test_point=None, round_vals=2):
13241322

13251323
return Series(
13261324
{
1327-
rv.name: np.round(self.fn(logpt_sum(rv))(test_point), round_vals)
1325+
rv.name: np.round(
1326+
self.fn(logpt_sum(rv, getattr(rv.tag, "observations", None)))(test_point),
1327+
round_vals,
1328+
)
13281329
for rv in self.basic_RVs
13291330
},
13301331
name="Log-probability of test_point",

pymc3/sampling.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1723,7 +1723,9 @@ def sample_posterior_predictive(
17231723
return {}
17241724

17251725
if not hasattr(_trace, "varnames"):
1726-
inputs_and_names = [(i, i.name) for i in rv_ancestors(vars_to_sample)]
1726+
inputs_and_names = [
1727+
(rv, rv.name) for rv in rv_ancestors(vars_to_sample, walk_past_rvs=True)
1728+
]
17271729
inputs, input_names = zip(*inputs_and_names)
17281730
else:
17291731
input_names = _trace.varnames

pymc3/step_methods/gibbs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"""
2020
from warnings import warn
2121

22-
import aesara.tensor as aet
22+
import aesara.tensor as at
2323

2424
from aesara.graph.basic import graph_inputs
2525
from numpy import arange, array, cumsum, empty, exp, max, nested_iters, searchsorted
@@ -81,7 +81,7 @@ def elemwise_logp(model, var):
8181
v_logp = logpt(v)
8282
if var in graph_inputs([v_logp]):
8383
terms.append(v_logp)
84-
return model.fn(aet.add(*terms))
84+
return model.fn(at.add(*terms))
8585

8686

8787
def categorical(prob, shape):

0 commit comments

Comments
 (0)