Skip to content

Commit 6c8a2b3

Browse files
Reinstate log-likelihood transforms
1 parent 34a66c7 commit 6c8a2b3

File tree

7 files changed

+259
-261
lines changed

7 files changed

+259
-261
lines changed

pymc3/backends/base.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,18 @@ def __init__(self, name, model=None, vars=None, test_point=None):
6161
model = modelcontext(model)
6262
self.model = model
6363
if vars is None:
64-
vars = [getattr(v.tag, "value_var", v) for v in model.unobserved_RVs]
64+
vars = []
65+
for v in model.unobserved_RVs:
66+
var = getattr(v.tag, "value_var", v)
67+
transform = getattr(var.tag, "transform", None)
68+
if transform:
69+
# We need to create and add an un-transformed version of
70+
# each transformed variable
71+
untrans_var = transform.backward(var)
72+
untrans_var.name = v.name
73+
vars.append(untrans_var)
74+
vars.append(var)
75+
6576
self.vars = vars
6677
self.varnames = [var.name for var in vars]
6778
self.fn = model.fastfn(vars)

pymc3/distributions/__init__.py

Lines changed: 125 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from functools import singledispatch
15+
from itertools import chain
1516
from typing import Generator, List, Optional, Tuple, Union
1617

1718
import aesara.tensor as aet
@@ -31,6 +32,11 @@
3132
]
3233

3334

35+
@singledispatch
36+
def logp_transform(op, inputs):
37+
return None
38+
39+
3440
def _get_scaling(total_size, shape, ndim):
3541
"""
3642
Gets scaling constant for logp
@@ -135,7 +141,6 @@ def change_rv_size(
135141

136142
def rv_log_likelihood_args(
137143
rv_var: TensorVariable,
138-
rv_value: Optional[TensorVariable] = None,
139144
transformed: Optional[bool] = True,
140145
) -> Tuple[TensorVariable, TensorVariable]:
141146
"""Get a `RandomVariable` and its corresponding log-likelihood `TensorVariable` value.
@@ -146,38 +151,24 @@ def rv_log_likelihood_args(
146151
A variable corresponding to a `RandomVariable`, whether directly or
147152
indirectly (e.g. an observed variable that's the output of an
148153
`Observed` `Op`).
149-
rv_value
150-
The measure-space input `TensorVariable` (i.e. "input" to a
151-
log-likelihood).
152154
transformed
153155
When ``True``, return the transformed value var.
154156
155157
Returns
156158
=======
157159
The first value in the tuple is the `RandomVariable`, and the second is the
158-
measure-space variable that corresponds with the latter. The first is used
159-
to determine the log likelihood graph and the second is the "input"
160-
parameter to that graph. In the case of an observed `RandomVariable`, the
161-
"input" is actual data; in all other cases, it's just another
162-
`TensorVariable`.
160+
measure-space variable that corresponds with the latter (i.e. the "value"
161+
variable).
163162
164163
"""
165164

166-
if rv_value is None:
167-
if rv_var.owner and isinstance(rv_var.owner.op, Observed):
168-
rv_var, rv_value = rv_var.owner.inputs
169-
elif hasattr(rv_var.tag, "value_var"):
170-
rv_value = rv_var.tag.value_var
171-
else:
172-
return rv_var, None
173-
174-
rv_value = aet.as_tensor_variable(rv_value)
175-
176-
transform = getattr(rv_value.tag, "transform", None)
177-
if transformed and transform:
178-
rv_value = transform.forward(rv_value)
179-
180-
return rv_var, rv_value
165+
if rv_var.owner and isinstance(rv_var.owner.op, Observed):
166+
return tuple(rv_var.owner.inputs)
167+
elif hasattr(rv_var.tag, "value_var"):
168+
rv_value = rv_var.tag.value_var
169+
return rv_var, rv_value
170+
else:
171+
return rv_var, None
181172

182173

183174
def rv_ancestors(graphs: List[TensorVariable]) -> Generator[TensorVariable, None, None]:
@@ -197,22 +188,53 @@ def strip_observed(x: TensorVariable) -> TensorVariable:
197188
return x
198189

199190

200-
def sample_to_measure_vars(graphs: List[TensorVariable]) -> List[TensorVariable]:
201-
"""Replace `RandomVariable` terms in graphs with their measure-space counterparts."""
191+
def sample_to_measure_vars(
192+
graphs: List[TensorVariable],
193+
) -> Tuple[List[TensorVariable], List[TensorVariable]]:
194+
"""Replace sample-space variables in graphs with their measure-space counterparts.
195+
196+
Sample-space variables are `TensorVariable` outputs of `RandomVariable`
197+
`Op`s. Measure-space variables are `TensorVariable`s that correspond to
198+
the value of a sample-space variable in a likelihood function (e.g. ``x``
199+
in ``p(X = x)``, where ``X`` is the corresponding sample-space variable).
200+
(``x`` is also the variable found in ``rv_var.tag.value_var``, so this
201+
function could also be called ``sample_to_value_vars``.)
202+
203+
Parameters
204+
==========
205+
graphs
206+
The graphs in which random variables are to be replaced by their
207+
measure variables.
208+
209+
Returns
210+
=======
211+
Tuple containing the transformed graphs and a ``dict`` of the replacements
212+
that were made.
213+
"""
202214
replace = {}
203-
for anc in rv_ancestors(graphs):
204-
measure_var = getattr(anc.tag, "value_var", None)
205-
if measure_var is not None:
206-
replace[anc] = measure_var
215+
for anc in chain(rv_ancestors(graphs), graphs):
216+
217+
if not (anc.owner and isinstance(anc.owner.op, RandomVariable)):
218+
continue
219+
220+
_, value_var = rv_log_likelihood_args(anc)
221+
222+
if value_var is not None:
223+
replace[anc] = value_var
224+
225+
if replace:
226+
measure_graphs = clone_replace(graphs, replace=replace)
227+
else:
228+
measure_graphs = graphs
207229

208-
dist_params = clone_replace(graphs, replace=replace)
209-
return dist_params
230+
return measure_graphs, replace
210231

211232

212233
def logpt(
213234
rv_var: TensorVariable,
214235
rv_value: Optional[TensorVariable] = None,
215-
jacobian: bool = True,
236+
jacobian: Optional[bool] = True,
237+
transformed: Optional[bool] = True,
216238
scaling: Optional[bool] = True,
217239
**kwargs,
218240
) -> TensorVariable:
@@ -228,29 +250,40 @@ def logpt(
228250
rv_var
229251
The `RandomVariable` output that determines the log-likelihood graph.
230252
rv_value
231-
The input variable for the log-likelihood graph.
253+
The input variable for the log-likelihood graph. If `rv_value` is
254+
a transformed variable, its transformations will be applied.
255+
If no value is provided, `rv_var.tag.value_var` will be checked and,
256+
when available, used.
232257
jacobian
233258
Whether or not to include the Jacobian term.
259+
transformed
260+
Return the transformed version of the log-likelihood graph.
234261
scaling
235262
A scaling term to apply to the generated log-likelihood graph.
236263
237264
"""
238265

239-
rv_var, rv_value = rv_log_likelihood_args(rv_var, rv_value)
266+
rv_var, rv_value_var = rv_log_likelihood_args(rv_var)
267+
268+
if rv_value is None:
269+
rv_value = rv_value_var
270+
else:
271+
rv_value = aet.as_tensor(rv_value)
272+
240273
rv_node = rv_var.owner
241274

242275
if not rv_node:
243276
raise TypeError("rv_var must be the output of a RandomVariable Op")
244277

245278
if not isinstance(rv_node.op, RandomVariable):
246279

280+
# This will probably need another generic function...
247281
if isinstance(rv_node.op, (Subtensor, AdvancedSubtensor, AdvancedSubtensor1)):
248282

249283
raise NotImplementedError("Missing value support is incomplete")
250284

251285
# "Flatten" and sum an array of indexed RVs' log-likelihoods
252286
rv_var, missing_values = rv_node.inputs
253-
rv_value = rv_var.tag.value_var
254287

255288
missing_values = missing_values.data
256289
logp_var = aet.sum(
@@ -268,28 +301,40 @@ def logpt(
268301

269302
return aet.zeros_like(rv_var)
270303

304+
if rv_value_var is None:
305+
raise NotImplementedError(f"The log-likelihood for {rv_var} is undefined")
306+
307+
# This case should be reached when `rv_var` is either the result of an
308+
# `Observed` or a `RandomVariable` `Op`
271309
rng, size, dtype, *dist_params = rv_node.inputs
272310

273-
dist_params = sample_to_measure_vars(dist_params)
311+
dist_params, replacements = sample_to_measure_vars(dist_params)
274312

275-
if jacobian:
276-
logp_var = _logp(rv_node.op, rv_value, *dist_params, **kwargs)
277-
else:
278-
logp_var = _logp_nojac(rv_node.op, rv_value, *dist_params, **kwargs)
313+
logp_var = _logp(rv_node.op, rv_value_var, *dist_params, **kwargs)
279314

280-
# Replace `RandomVariable` ancestors with their corresponding
281-
# log-likelihood input variables
282-
lik_replacements = [
283-
(v, v.tag.value_var)
284-
for v in ancestors([logp_var])
285-
if v.owner and isinstance(v.owner.op, RandomVariable) and getattr(v.tag, "value_var", None)
286-
]
315+
# If any of the measure vars are transformed measure-space variables
316+
# (signified by having a `transform` value in their tags), then we apply
317+
# the their transforms and add their Jacobians (when enabled)
318+
if transformed:
319+
logp_var = transform_logp(
320+
logp_var,
321+
tuple(replacements.values()) + (rv_value_var,),
322+
)
323+
324+
transform = getattr(rv_value_var.tag, "transform", None)
325+
326+
if transform and jacobian:
327+
transformed_jacobian = transform.jacobian_det(rv_value_var)
328+
if transformed_jacobian:
329+
if logp_var.ndim > transformed_jacobian.ndim:
330+
logp_var = logp_var.sum(axis=-1)
331+
logp_var += transformed_jacobian
287332

288-
(logp_var,) = clone_replace([logp_var], replace=lik_replacements)
333+
(logp_var,) = clone_replace([logp_var], replace={rv_value_var: rv_value})
289334

290335
if scaling:
291336
logp_var *= _get_scaling(
292-
getattr(rv_var.tag, "total_size", None), rv_value.shape, rv_value.ndim
337+
getattr(rv_var.tag, "total_size", None), rv_value_var.shape, rv_value_var.ndim
293338
)
294339

295340
if rv_var.name is not None:
@@ -298,6 +343,25 @@ def logpt(
298343
return logp_var
299344

300345

346+
def transform_logp(logp_var: TensorVariable, inputs: List[TensorVariable]) -> TensorVariable:
347+
"""Transform the inputs of a log-likelihood graph."""
348+
trans_replacements = {}
349+
for measure_var in inputs:
350+
351+
transform = getattr(measure_var.tag, "transform", None)
352+
353+
if transform is None:
354+
continue
355+
356+
trans_rv_value = transform.backward(measure_var)
357+
trans_replacements[measure_var] = trans_rv_value
358+
359+
if trans_replacements:
360+
(logp_var,) = clone_replace([logp_var], trans_replacements)
361+
362+
return logp_var
363+
364+
301365
@singledispatch
302366
def _logp(op, value, *dist_params, **kwargs):
303367
"""Create a log-likelihood graph.
@@ -310,20 +374,27 @@ def _logp(op, value, *dist_params, **kwargs):
310374
return aet.zeros_like(value)
311375

312376

313-
def logcdf(rv_var, rv_value, **kwargs):
377+
def logcdf(rv_var, rv_value, transformed=True, jacobian=True, **kwargs):
314378
"""Create a log-CDF graph."""
315379

316-
rv_var, rv_value = rv_log_likelihood_args(rv_var, rv_value)
380+
rv_var, rv_value = rv_log_likelihood_args(rv_var)
317381
rv_node = rv_var.owner
318382

319383
if not rv_node:
320384
raise TypeError()
321385

322386
rng, size, dtype, *dist_params = rv_node.inputs
323387

324-
dist_params = sample_to_measure_vars(dist_params)
388+
dist_params, replacements = sample_to_measure_vars(dist_params)
389+
390+
logp_var = _logcdf(rv_node.op, rv_value, *dist_params, **kwargs)
325391

326-
return _logcdf(rv_node.op, rv_value, *dist_params, **kwargs)
392+
if transformed:
393+
logp_var = transform_logp(
394+
logp_var, tuple(replacements.values()) + (rv_value,), jacobian=jacobian
395+
)
396+
397+
return logp_var
327398

328399

329400
@singledispatch
@@ -338,38 +409,6 @@ def _logcdf(op, value, *args, **kwargs):
338409
raise NotImplementedError()
339410

340411

341-
def logp_nojac(rv_var, rv_value=None, **kwargs):
342-
"""Create a graph of the log-likelihood that doesn't include the Jacobian."""
343-
344-
rv_var, rv_value = rv_log_likelihood_args(rv_var, rv_value)
345-
rv_node = rv_var.owner
346-
347-
if not rv_node:
348-
raise TypeError()
349-
350-
rng, size, dtype, *dist_params = rv_node.inputs
351-
352-
dist_params = sample_to_measure_vars(dist_params)
353-
354-
return _logp_nojac(rv_node.op, rv_value, **kwargs)
355-
356-
357-
@singledispatch
358-
def _logp_nojac(op, value, *args, **kwargs):
359-
"""Return the logp, but do not include a jacobian term for transforms.
360-
361-
If we use different parametrizations for the same distribution, we
362-
need to add the determinant of the jacobian of the transformation
363-
to make sure the densities still describe the same distribution.
364-
However, MAP estimates are not invariant with respect to the
365-
parameterization, we need to exclude the jacobian terms in this case.
366-
367-
This function should be overwritten in base classes for transformed
368-
distributions.
369-
"""
370-
return logpt(op, value, *args, **kwargs)
371-
372-
373412
def logpt_sum(rv_var: TensorVariable, rv_value: Optional[TensorVariable] = None, **kwargs):
374413
"""Return the sum of the logp values for the given observations.
375414

0 commit comments

Comments
 (0)