
Commit 196fb9d

Remove logcdfpt and simplify logp and logcdf helpers
These helpers now only retrieve the RV's logp and logcdf terms and no longer attempt to replace any potential RVs in their arguments. A ValueError is raised when the provided value is not compatible with the original RV's dimensions.
1 parent 7840581 commit 196fb9d
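
As a rough sketch of the simplified helper API this commit introduces (assuming a post-commit pymc install; `Normal` and the concrete values below are only illustrative, not taken from the diff):

import numpy as np
import pymc as pm
from pymc.distributions.logprob import logp, logcdf

# Build a standalone RV and request its logp/logcdf graphs directly.
# No value-variable mapping or RV replacement is involved anymore.
rv = pm.Normal.dist(mu=0.0, sigma=1.0, size=3)
logp_graph = logp(rv, np.zeros(3))
logcdf_graph = logcdf(rv, np.zeros(3))
print(logp_graph.eval(), logcdf_graph.eval())

# Per the commit message, a value that does not match the RV's dimensions
# is expected to raise a ValueError instead of being silently accepted.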

9 files changed: +75 -167 lines changed

pymc/distributions/__init__.py

Lines changed: 0 additions & 3 deletions
@@ -13,10 +13,8 @@
 # limitations under the License.

 from pymc.distributions.logprob import (  # isort:skip
-    _logcdf,
     logcdf,
     logp,
-    logcdfpt,
     logp_transform,
     logpt,
     logpt_sum,
@@ -195,6 +193,5 @@
     "logp",
     "logp_transform",
     "logcdf",
-    "_logcdf",
     "logpt_sum",
 ]

pymc/distributions/discrete.py

Lines changed: 10 additions & 9 deletions
@@ -14,7 +14,6 @@
 import aesara.tensor as at
 import numpy as np

-from aeppl.logprob import _logprob
 from aesara.tensor.random.basic import (
     RandomVariable,
     bernoulli,
@@ -42,7 +41,7 @@
     normal_lcdf,
 )
 from pymc.distributions.distribution import Discrete
-from pymc.distributions.logprob import _logcdf
+from pymc.distributions.logprob import logcdf, logp
 from pymc.distributions.shape_utils import rv_size_is_none
 from pymc.math import sigmoid

@@ -754,7 +753,7 @@ def logp(value, n, p):
         )

         # Return Poisson when alpha gets very large.
-        return at.switch(at.gt(alpha, 1e10), Poisson.logp(value, mu), negbinom)
+        return at.switch(at.gt(alpha, 1e10), logp(Poisson.dist(mu=mu), value), negbinom)

     def logcdf(value, n, p):
         """
@@ -1371,7 +1370,7 @@ def logp(value, psi, theta):

         logp_val = at.switch(
             at.gt(value, 0),
-            at.log(psi) + _logprob(poisson, [value], None, None, None, theta),
+            at.log(psi) + logp(Poisson.dist(mu=theta), value),
             at.logaddexp(at.log1p(-psi), at.log(psi) - theta),
         )

@@ -1402,7 +1401,7 @@ def logcdf(value, psi, theta):
         return bound(
             at.logaddexp(
                 at.log1p(-psi),
-                at.log(psi) + _logcdf(poisson, value, {}, theta),
+                at.log(psi) + logcdf(Poisson.dist(mu=theta), value),
             ),
             0 <= value,
             0 <= psi,
@@ -1510,7 +1509,7 @@ def logp(value, psi, n, p):

         logp_val = at.switch(
             at.gt(value, 0),
-            at.log(psi) + _logprob(binomial, [value], None, None, None, n, p),
+            at.log(psi) + logp(Binomial.dist(n=n, p=p), value),
             at.logaddexp(at.log1p(-psi), at.log(psi) + n * at.log1p(-p)),
         )

@@ -1543,7 +1542,7 @@ def logcdf(value, psi, n, p):
         return bound(
             at.logaddexp(
                 at.log1p(-psi),
-                at.log(psi) + _logcdf(binomial, value, {}, n, p),
+                at.log(psi) + logcdf(Binomial.dist(n=n, p=p), value),
             ),
             0 <= value,
             value <= n,
@@ -1669,7 +1668,7 @@ def logp(value, psi, n, p):
         return bound(
             at.switch(
                 at.gt(value, 0),
-                at.log(psi) + _logprob(nbinom, [value], None, None, None, n, p),
+                at.log(psi) + logp(NegativeBinomial.dist(n=n, p=p), value),
                 at.logaddexp(at.log1p(-psi), at.log(psi) + n * at.log(p)),
             ),
             0 <= value,
@@ -1696,7 +1695,9 @@ def logcdf(value, psi, n, p):
         TensorVariable
         """
         return bound(
-            at.logaddexp(at.log1p(-psi), at.log(psi) + _logcdf(nbinom, value, {}, n, p)),
+            at.logaddexp(
+                at.log1p(-psi), at.log(psi) + logcdf(NegativeBinomial.dist(n=n, p=p), value)
+            ),
             0 <= value,
             0 <= psi,
             psi <= 1,
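
Note on the pattern used throughout this file: instead of calling aeppl's `_logprob`/`_logcdf` dispatchers with raw node inputs, the zero-inflated terms now build a temporary base distribution and pass it through the new `logp`/`logcdf` helpers. A minimal sketch of that composition (the parameter values are arbitrary illustrations, not from the diff):

import aesara.tensor as at
import pymc as pm
from pymc.distributions.logprob import logp

# ZeroInflatedPoisson-style term: a point mass at zero mixed with a Poisson
# branch, whose logp comes from a temporary Poisson dist.
psi, theta, value = 0.7, 3.0, 2
zip_logp = at.switch(
    at.gt(value, 0),
    at.log(psi) + logp(pm.Poisson.dist(mu=theta), value),
    at.logaddexp(at.log1p(-psi), at.log(psi) - theta),
)
print(zip_logp.eval())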

pymc/distributions/distribution.py

Lines changed: 8 additions & 9 deletions
@@ -23,14 +23,13 @@

 import aesara

-from aeppl.logprob import _logprob
+from aeppl.logprob import _logcdf, _logprob
 from aesara.tensor.basic import as_tensor_variable
 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.random.var import RandomStateSharedVariable
 from aesara.tensor.var import TensorVariable

 from pymc.aesaraf import change_rv_size
-from pymc.distributions import _logcdf
 from pymc.distributions.shape_utils import (
     Dims,
     Shape,
@@ -98,18 +97,18 @@ def _random(*args, **kwargs):
             if class_logp:

                 @_logprob.register(rv_type)
-                def logp(op, value_var_list, *dist_params, **kwargs):
-                    _dist_params = dist_params[3:]
-                    value_var = value_var_list[0]
-                    return class_logp(value_var, *_dist_params)
+                def logp(op, values, *dist_params, **kwargs):
+                    dist_params = dist_params[3:]
+                    (value,) = values
+                    return class_logp(value, *dist_params)

             class_logcdf = clsdict.get("logcdf")
             if class_logcdf:

                 @_logcdf.register(rv_type)
-                def logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
-                    value_var = rvs_to_values.get(var, var)
-                    return class_logcdf(value_var, *dist_params, **kwargs)
+                def logcdf(op, value, *dist_params, **kwargs):
+                    dist_params = dist_params[3:]
+                    return class_logcdf(value, *dist_params)

             class_initval = clsdict.get("get_moment")
             if class_initval:
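
For context on the `dist_params = dist_params[3:]` lines above: aeppl invokes the registered logp/logcdf with all inputs of the RV node, and a `RandomVariable` node's inputs are `(rng, size, dtype, *distribution parameters)`, so the first three entries are stripped before forwarding to the class method. A small sketch of that node layout (using `Normal` purely as an example):

import pymc as pm

rv = pm.Normal.dist(mu=1.0, sigma=2.0)
# The node inputs carry the rng state, size, and dtype code before the
# actual distribution parameters.
rng, size, dtype, mu, sigma = rv.owner.inputs
print(mu.eval(), sigma.eval())  # 1.0 2.0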

pymc/distributions/logprob.py

Lines changed: 12 additions & 119 deletions
@@ -21,6 +21,8 @@
 import numpy as np

 from aeppl import factorized_joint_logprob
+from aeppl.logprob import logcdf as logcdf_aeppl
+from aeppl.logprob import logprob as logp_aeppl
 from aeppl.transforms import TransformValuesOpt
 from aesara import config
 from aesara.graph.basic import graph_inputs, io_toposort
@@ -36,7 +38,7 @@
 )
 from aesara.tensor.var import TensorVariable

-from pymc.aesaraf import extract_rv_and_value_vars, floatX, rvs_to_value_vars
+from pymc.aesaraf import floatX


 @singledispatch
@@ -260,128 +262,18 @@ def logpt(
     return logp_var


-def logcdfpt(
-    var: TensorVariable,
-    rv_values: Optional[Union[TensorVariable, Dict[TensorVariable, TensorVariable]]] = None,
-    *,
-    scaling: bool = True,
-    sum: bool = True,
-    **kwargs,
-) -> TensorVariable:
-    """Create a measure-space (i.e. log-cdf) graph for a random variable at a given point.
-
-    Parameters
-    ==========
-    var
-        The `RandomVariable` output that determines the log-likelihood graph.
-    rv_values
-        A variable, or ``dict`` of variables, that represents the value of
-        `var` in its log-likelihood. If no `rv_value` is provided,
-        ``var.tag.value_var`` will be checked and, when available, used.
-    jacobian
-        Whether or not to include the Jacobian term.
-    scaling
-        A scaling term to apply to the generated log-likelihood graph.
-    transformed
-        Apply transforms.
-    sum
-        Sum the log-likelihood.
-
-    """
-    if not isinstance(rv_values, Mapping):
-        rv_values = {var: rv_values} if rv_values is not None else {}
-
-    rv_var, rv_value_var = extract_rv_and_value_vars(var)
-
-    rv_value = rv_values.get(rv_var, rv_value_var)
-
-    if rv_var is not None and rv_value is None:
-        raise ValueError(f"No value variable specified or associated with {rv_var}")
-
-    if rv_value is not None:
-        rv_value = at.as_tensor(rv_value)
-
-        if rv_var is not None:
-            # Make sure that the value is compatible with the random variable
-            rv_value = rv_var.type.filter_variable(rv_value.astype(rv_var.dtype))
-
-        if rv_value_var is None:
-            rv_value_var = rv_value
-
-    rv_node = rv_var.owner
-
-    rng, size, dtype, *dist_params = rv_node.inputs
-
-    # Here, we plug the actual random variable into the log-likelihood graph,
-    # because we want a log-likelihood graph that only contains
-    # random variables. This is important, because a random variable's
-    # parameters can contain random variables themselves.
-    # Ultimately, with a graph containing only random variables and
-    # "deterministics", we can simply replace all the random variables with
-    # their value variables and be done.
-    tmp_rv_values = rv_values.copy()
-    tmp_rv_values[rv_var] = rv_var
-
-    logp_var = _logcdf(rv_node.op, rv_var, tmp_rv_values, *dist_params, **kwargs)
-
-    transform = getattr(rv_value_var.tag, "transform", None) if rv_value_var else None
-
-    # Replace random variables with their value variables
-    replacements = rv_values.copy()
-    replacements.update({rv_var: rv_value, rv_value_var: rv_value})
-
-    (logp_var,), _ = rvs_to_value_vars(
-        (logp_var,),
-        apply_transforms=False,
-        initial_replacements=replacements,
-    )
-
-    if sum:
-        logp_var = at.sum(logp_var)
-
-    if scaling:
-        logp_var *= _get_scaling(
-            getattr(rv_var.tag, "total_size", None), rv_value.shape, rv_value.ndim
-        )
-
-    # Recompute test values for the changes introduced by the replacements
-    # above.
-    if config.compute_test_value != "off":
-        for node in io_toposort(graph_inputs((logp_var,)), (logp_var,)):
-            compute_test_value(node)
-
-    if rv_var.name is not None:
-        logp_var.name = f"__logp_{rv_var.name}"
-
-    return logp_var
-
-
-def logp(var, rv_values, **kwargs):
-    """Create a log-probability graph."""
-
-    # Attach the value_var to the tag of var when it does not have one
-    if not hasattr(var.tag, "value_var"):
-        if isinstance(rv_values, Mapping):
-            value_var = rv_values[var]
-        else:
-            value_var = rv_values
-        var.tag.value_var = at.as_tensor_variable(value_var, dtype=var.dtype)
-
-    return logpt(var, rv_values, **kwargs)
-
-
-def logcdf(var, rv_values, **kwargs):
-    """Create a log-CDF graph."""
-
-    # Attach the value_var to the tag of var when it does not have one
-    if not hasattr(var.tag, "value_var"):
-        if isinstance(rv_values, Mapping):
-            value_var = rv_values[var]
-        else:
-            value_var = rv_values
-        var.tag.value_var = at.as_tensor_variable(value_var, dtype=var.dtype)
-
-    return logcdfpt(var, rv_values, **kwargs)
+def logp(rv, value):
+    """Return the log-probability graph of a Random Variable"""
+
+    value = at.as_tensor_variable(value, dtype=rv.dtype)
+    return logp_aeppl(rv, value)
+
+
+def logcdf(rv, value):
+    """Return the log-cdf graph of a Random Variable"""
+
+    value = at.as_tensor_variable(value, dtype=rv.dtype)
+    return logcdf_aeppl(rv, value)


 @singledispatch
@@ -402,4 +294,5 @@ def logpt_sum(*args, **kwargs):
     Subclasses can use this to improve the speed of logp evaluations
     if only the sum of the logp values is needed.
     """
+    # TODO: Deprecate this
     return logpt(*args, sum=True, **kwargs)
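
One consequence of the simplification, stated in the commit message, is that `logp`/`logcdf` no longer replace random variables that appear in a distribution's parameters. A minimal sketch of what that means (the HalfNormal/Normal pair is just an illustration):

import pymc as pm
from aesara.graph.basic import ancestors
from pymc.distributions.logprob import logp

sigma = pm.HalfNormal.dist(1.0)
x = pm.Normal.dist(mu=0.0, sigma=sigma)

# The returned graph still depends on the `sigma` random variable; these
# helpers no longer swap it for a value variable.
x_logp = logp(x, 0.0)
assert sigma in ancestors([x_logp])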

pymc/distributions/multivariate.py

Lines changed: 1 addition & 1 deletion
@@ -2102,7 +2102,7 @@ def logp(value, mu, W, alpha, tau):
         TensorVariable
         """

-        sparse = isinstance(W, aesara.sparse.SparseConstant)
+        sparse = isinstance(W, (aesara.sparse.SparseConstant, aesara.sparse.SparseVariable))

         if sparse:
             D = sp_sum(W, axis=0)
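
The broadened isinstance check matters because `W` may also arrive as a symbolic sparse variable rather than a constant; previously such inputs fell through to the dense branch. A small sketch of the two cases (assuming scipy is available; the variables below are purely illustrative):

import aesara.sparse
import scipy.sparse as sps

W_const = aesara.sparse.as_sparse_variable(sps.eye(3, format="csr"))  # SparseConstant
W_sym = aesara.sparse.csr_matrix("W")  # symbolic SparseVariable

for W in (W_const, W_sym):
    # Both now satisfy the check used in CAR.logp above.
    print(isinstance(W, (aesara.sparse.SparseConstant, aesara.sparse.SparseVariable)))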
