Skip to content

Commit 0a172c8

Browse files
kc611ricardoV94
authored andcommitted
Added aeppl based log-likelihood graph generation and aeppl based transforms
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 3ab1c00 commit 0a172c8

40 files changed

+595
-1095
lines changed

conda-envs/environment-dev-py37.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ channels:
44
- conda-forge
55
- defaults
66
dependencies:
7+
- aeppl>=0.0.13
78
- aesara>=2.2.2
89
- arviz>=0.11.4
910
- cachetools>=4.2.1

conda-envs/environment-dev-py38.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ channels:
44
- conda-forge
55
- defaults
66
dependencies:
7+
- aeppl>=0.0.13
78
- aesara>=2.2.2
89
- arviz>=0.11.4
910
- cachetools>=4.2.1

conda-envs/environment-dev-py39.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ channels:
44
- conda-forge
55
- defaults
66
dependencies:
7+
- aeppl>=0.0.13
78
- aesara>=2.2.2
89
- arviz>=0.11.4
910
- cachetools>=4.2.1

conda-envs/environment-test-py37.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ channels:
44
- conda-forge
55
- defaults
66
dependencies:
7+
- aeppl>=0.0.13
78
- aesara>=2.2.2
89
- arviz>=0.11.4
910
- cachetools>=4.2.1

conda-envs/environment-test-py38.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ channels:
44
- conda-forge
55
- defaults
66
dependencies:
7+
- aeppl>=0.0.13
78
- aesara>=2.2.2
89
- arviz>=0.11.4
910
- cachetools>=4.2.1

conda-envs/environment-test-py39.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ channels:
44
- conda-forge
55
- defaults
66
dependencies:
7+
- aeppl>=0.0.13
78
- aesara>=2.2.2
89
- arviz>=0.11.4
910
- cachetools

conda-envs/windows-environment-dev-py38.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ channels:
44
- defaults
55
dependencies:
66
# base dependencies (see install guide for Windows)
7+
- aeppl>=0.0.13
78
- aesara>=2.2.2
89
- arviz>=0.11.4
910
- cachetools>=4.2.1

conda-envs/windows-environment-test-py38.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ channels:
44
- defaults
55
dependencies:
66
# base dependencies (see install guide for Windows)
7+
- aeppl>=0.0.13
78
- aesara>=2.2.2
89
- arviz>=0.11.2
910
- cachetools

pymc/aesaraf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def transform_replacements(var, replacements):
377377
# potential replacements
378378
return [rv_value_var]
379379

380-
trans_rv_value = transform.backward(rv_var, rv_value_var)
380+
trans_rv_value = transform.backward(rv_value_var, *rv_var.owner.inputs)
381381
replacements[var] = trans_rv_value
382382

383383
# Walk the transformed variable and make replacements

pymc/bart/bart.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import aesara.tensor as at
1516
import numpy as np
1617

18+
from aeppl.logprob import _logprob
1719
from aesara.tensor.random.op import RandomVariable, default_shape_from_params
1820
from pandas import DataFrame, Series
1921

@@ -146,6 +148,20 @@ def __new__(
146148
def dist(cls, *params, **kwargs):
147149
return super().dist(params, **kwargs)
148150

151+
def logp(x, *inputs):
152+
"""Calculate log probability.
153+
154+
Parameters
155+
----------
156+
x: numeric, TensorVariable
157+
Value for which log-probability is calculated.
158+
159+
Returns
160+
-------
161+
TensorVariable
162+
"""
163+
return at.zeros_like(x)
164+
149165

150166
def preprocess_XY(X, Y):
151167
if isinstance(Y, (Series, DataFrame)):
@@ -156,3 +172,10 @@ def preprocess_XY(X, Y):
156172
Y = Y.astype(float)
157173
X = X.astype(float)
158174
return X, Y
175+
176+
177+
@_logprob.register(BARTRV)
178+
def logp(op, value_var, *dist_params, **kwargs):
179+
_dist_params = dist_params[3:]
180+
value_var = value_var[0]
181+
return BART.logp(value_var, *_dist_params)

pymc/distributions/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
from pymc.distributions.logprob import ( # isort:skip
1616
_logcdf,
17-
_logp,
1817
logcdf,
1918
logp,
19+
logcdfpt,
2020
logp_transform,
2121
logpt,
2222
logpt_sum,
@@ -193,7 +193,6 @@
193193
"PolyaGamma",
194194
"logpt",
195195
"logp",
196-
"_logp",
197196
"logp_transform",
198197
"logcdf",
199198
"_logcdf",

pymc/distributions/bound.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414

1515
import numpy as np
1616

17+
from aeppl.logprob import logprob
1718
from aesara.tensor import as_tensor_variable
1819
from aesara.tensor.random.op import RandomVariable
1920
from aesara.tensor.var import TensorVariable
2021

2122
from pymc.aesaraf import floatX, intX
22-
from pymc.distributions import _logp
2323
from pymc.distributions.continuous import BoundedContinuous
2424
from pymc.distributions.dist_math import bound
2525
from pymc.distributions.distribution import Continuous, Discrete
@@ -46,7 +46,7 @@ def rng_fn(cls, rng, distribution, lower, upper, size):
4646

4747
class _ContinuousBounded(BoundedContinuous):
4848
rv_op = boundrv
49-
bound_args_indices = [1, 2]
49+
bound_args_indices = [4, 5]
5050

5151
def logp(value, distribution, lower, upper):
5252
"""
@@ -67,7 +67,7 @@ def logp(value, distribution, lower, upper):
6767
-------
6868
TensorVariable
6969
"""
70-
logp = _logp(distribution.owner.op, value, {}, *distribution.owner.inputs[3:])
70+
logp = logprob(distribution, value)
7171
return bound(logp, (value >= lower), (value <= upper))
7272

7373

@@ -107,7 +107,7 @@ def logp(value, distribution, lower, upper):
107107
-------
108108
TensorVariable
109109
"""
110-
logp = _logp(distribution.owner.op, value, {}, *distribution.owner.inputs[3:])
110+
logp = logprob(distribution, value)
111111
return bound(logp, (value >= lower), (value <= upper))
112112

113113

@@ -166,6 +166,7 @@ def __new__(
166166
raise ValueError("Given dims do not exist in model coordinates.")
167167

168168
lower, upper, initval = cls._set_values(lower, upper, size, shape, initval)
169+
distribution.tag.ignore_logprob = True
169170

170171
if isinstance(distribution.owner.op, Continuous):
171172
res = _ContinuousBounded(
@@ -200,7 +201,7 @@ def dist(
200201

201202
cls._argument_checks(distribution, **kwargs)
202203
lower, upper, initval = cls._set_values(lower, upper, size, shape, initval=None)
203-
204+
distribution.tag.ignore_logprob = True
204205
if isinstance(distribution.owner.op, Continuous):
205206
res = _ContinuousBounded.dist(
206207
[distribution, lower, upper],

0 commit comments

Comments
 (0)