Skip to content

Commit b0c58b4

Browse files
committed
Lint fixes
1 parent 50187f3 commit b0c58b4

File tree

2 files changed

+25
-16
lines changed

2 files changed

+25
-16
lines changed

pymc_experimental/model/marginal_model.py

+17-15
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import warnings
22
from typing import Sequence, Tuple, Union
33

4-
from arviz import dict_to_dataset
54
import numpy as np
6-
import pytensor.tensor as pt
75
import pymc
6+
import pytensor.tensor as pt
7+
from arviz import dict_to_dataset
88
from pymc import SymbolicRandomVariable
99
from pymc.backends.arviz import coords_and_dims_for_inferencedata
1010
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
@@ -13,7 +13,7 @@
1313
from pymc.logprob.basic import conditional_logp
1414
from pymc.logprob.transforms import IntervalTransform
1515
from pymc.model import Model
16-
from pymc.pytensorf import constant_fold, compile_pymc, inputvars
16+
from pymc.pytensorf import compile_pymc, constant_fold, inputvars
1717
from pymc.util import dataset_to_point_list, treedict
1818
from pytensor import Mode
1919
from pytensor.compile import SharedVariable
@@ -248,7 +248,9 @@ def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorV
248248
# Raise errors and warnings immediately
249249
self.clone()._marginalize(user_warnings=True)
250250

251-
def unmunmarginalize(self, idata, var_names=None, include_samples=False, extend_inferencedata=True):
251+
def unmunmarginalize(
252+
self, idata, var_names=None, include_samples=False, extend_inferencedata=True
253+
):
252254
"""Computes log-likelihoods of marginalized variables conditioned on parameters
253255
of the model given InferenceData with posterior group
254256
@@ -276,7 +278,9 @@ def unmunmarginalize(self, idata, var_names=None, include_samples=False, extend_
276278
posterior = idata.posterior
277279

278280
# Remove Deterministics
279-
posterior_values = posterior[[rv.name for rv in mm.free_RVs if rv not in self.marginalized_rvs]]
281+
posterior_values = posterior[
282+
[rv.name for rv in mm.free_RVs if rv not in self.marginalized_rvs]
283+
]
280284

281285
sample_dims = ("chain", "draw")
282286
posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)
@@ -306,10 +310,11 @@ def unmunmarginalize(self, idata, var_names=None, include_samples=False, extend_
306310

307311
# TODO: Handle constants
308312
# TODO: Handle transformed variables
309-
joint_logp_op = OpFromGraph([marginalized_value] + other_values, [joint_logp], inline=True)
313+
joint_logp_op = OpFromGraph(
314+
[marginalized_value] + other_values, [joint_logp], inline=True
315+
)
310316
joint_logps = [
311-
joint_logp_op(rv_domain_tensor[i], *other_values)
312-
for i in range(len(rv_domain))
317+
joint_logp_op(rv_domain_tensor[i], *other_values) for i in range(len(rv_domain))
313318
]
314319

315320
rv_loglike_fn = None
@@ -333,20 +338,17 @@ def unmunmarginalize(self, idata, var_names=None, include_samples=False, extend_
333338
logps, samples = zip(*logvs)
334339
logps = np.array(logps)
335340
rv_dict[rv.name] = np.reshape(
336-
samples,
337-
tuple(len(coord) for coord in stacked_dims.values())
341+
samples, tuple(len(coord) for coord in stacked_dims.values())
338342
)
339343
rv_dims_dict[rv.name] = sample_dims
340344
rv_dict["lp_" + rv.name] = np.reshape(
341-
logps,
342-
tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:]
345+
logps, tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:]
343346
)
344347
rv_dims_dict["lp_" + rv.name] = sample_dims + ("lp_" + rv.name + "_dims",)
345348
else:
346349
logps = np.array(logvs)
347350
rv_dict["lp_" + rv.name] = np.reshape(
348-
logps,
349-
tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:]
351+
logps, tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:]
350352
)
351353
rv_dims_dict["lp_" + rv.name] = sample_dims + ("lp_" + rv.name + "_dims",)
352354

@@ -361,7 +363,7 @@ def unmunmarginalize(self, idata, var_names=None, include_samples=False, extend_
361363
)
362364

363365
if extend_inferencedata:
364-
rv_dict = {k:(rv_dims_dict[k], v) for (k,v) in rv_dict.items()}
366+
rv_dict = {k: (rv_dims_dict[k], v) for (k, v) in rv_dict.items()}
365367
idata = idata.posterior.assign(**rv_dict)
366368
return idata
367369
else:

pymc_experimental/tests/model/test_marginal_model.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,12 @@ def test_marginalized_bernoulli_logp():
5252

5353
idx = pm.Bernoulli.dist(0.7, name="idx")
5454
y = pm.Normal.dist(mu=mu[idx], sigma=1.0, name="y")
55-
marginal_rv_node = FiniteDiscreteMarginalRV([mu], [idx, y], ndim_supp=None, n_updates=0,)(
55+
marginal_rv_node = FiniteDiscreteMarginalRV(
56+
[mu],
57+
[idx, y],
58+
ndim_supp=None,
59+
n_updates=0,
60+
)(
5661
mu
5762
)[0].owner
5863

@@ -250,6 +255,7 @@ def test_marginalized_change_point_model_sampling(disaster_model):
250255
rtol=1e-2,
251256
)
252257

258+
253259
@pytest.mark.slow
254260
@pytest.mark.filterwarnings("error")
255261
def test_unmarginalized_basic(disaster_model):
@@ -269,6 +275,7 @@ def test_unmarginalized_basic(disaster_model):
269275
assert idata.switchpoint.shape == idata.early_mean.shape
270276
assert idata.lp_switchpoint.shape == idata.switchpoint.shape + (len(years),)
271277

278+
272279
@pytest.mark.filterwarnings("error")
273280
def test_not_supported_marginalized():
274281
"""Marginalized graphs with non-Elemwise Operations are not supported as they

0 commit comments

Comments
 (0)