|
103 | 103 | logpt,
|
104 | 104 | logpt_sum,
|
105 | 105 | )
|
106 |
| -from pymc3.math import kronecker, logsumexp |
| 106 | +from pymc3.math import kronecker |
107 | 107 | from pymc3.model import Deterministic, Model, Point
|
108 | 108 | from pymc3.tests.helpers import select_by_precision
|
109 | 109 | from pymc3.vartypes import continuous_types
|
@@ -858,24 +858,33 @@ def check_selfconsistency_discrete_logcdf(
|
858 | 858 | """
|
859 | 859 | Check that logcdf of discrete distributions matches sum of logps up to value
|
860 | 860 | """
|
| 861 | + # This test only works for scalar random variables |
| 862 | + assert distribution.rv_op.ndim_supp == 0 |
| 863 | + |
861 | 864 | domains = paramdomains.copy()
|
862 | 865 | domains["value"] = domain
|
863 | 866 | if decimal is None:
|
864 | 867 | decimal = select_by_precision(float64=6, float32=3)
|
| 868 | + |
| 869 | + model, param_vars = build_model(distribution, domain, paramdomains) |
| 870 | + dist_logcdf = model.fastfn(logpt(model["value"], cdf=True)) |
| 871 | + dist_logp = model.fastfn(logpt(model["value"])) |
| 872 | + |
865 | 873 | for pt in product(domains, n_samples=n_samples):
|
866 | 874 | params = dict(pt)
|
867 | 875 | if skip_params_fn(params):
|
868 | 876 | continue
|
869 | 877 | value = params.pop("value")
|
870 | 878 | values = np.arange(domain.lower, value + 1)
|
871 |
| - dist = distribution.dist(**params) |
872 |
| - # This only works for scalar random variables |
873 |
| - assert dist.owner.op.ndim_supp == 0 |
874 |
| - values_dist = change_rv_size(dist, values.shape) |
| 879 | + |
| 880 | + # Update shared parameter variables in logp/logcdf function |
| 881 | + for param_name, param_value in params.items(): |
| 882 | + param_vars[param_name].set_value(param_value) |
| 883 | + |
875 | 884 | with aesara.config.change_flags(mode=Mode("py")):
|
876 | 885 | assert_almost_equal(
|
877 |
| - logcdf(dist, value).eval(), |
878 |
| - logsumexp(logpt(values_dist, values), keepdims=False).eval(), |
| 886 | + dist_logcdf({"value": value}), |
| 887 | + scipy.special.logsumexp([dist_logp({"value": value}) for value in values]), |
879 | 888 | decimal=decimal,
|
880 | 889 | err_msg=str(pt),
|
881 | 890 | )
|
|
0 commit comments