Commit e15c065

linting

1 parent f0e7030 commit e15c065

2 files changed: +8 -14 lines

pymc3/quadratic_approximation.py

Lines changed: 5 additions & 7 deletions
@@ -14,19 +14,17 @@
 
 """Functions for Quadratic Approximation."""
 
+import arviz as az
 import numpy as np
 import scipy
-import arviz as az
 
-__all__ = [
-    "quadratic_approximation"
-]
+__all__ = ["quadratic_approximation"]
 
-from pymc3.tuning import find_MAP, find_hessian
+from pymc3.tuning import find_hessian, find_MAP
 
 
 def quadratic_approximation(vars, n_chains=2, n_samples=10_000):
-    """ Finds the quadratic approximation to the posterior, also known as the Laplace approximation.
+    """Finds the quadratic approximation to the posterior, also known as the Laplace approximation.
 
     NOTE: The quadratic approximation only works well for unimodal and roughly symmetrical posteriors of continuous variables.
     The usual MCMC convergence and mixing statistics (e.g. R-hat, ESS) will NOT tell you anything about how well this approximation fits your actual (unknown) posterior, indeed they'll always be extremely nice since all "chains" are sampling from exactly the same distribution, the posterior quadratic approximation.
@@ -64,6 +62,6 @@ def quadratic_approximation(vars, n_chains=2, n_samples=10_000):
     i = 0
     for v in vars:
         var_size = map[v.name].size
-        samples[v.name] = draws[:, :, i:i + var_size].squeeze()
+        samples[v.name] = draws[:, :, i : i + var_size].squeeze()
         i += var_size
     return az.convert_to_inference_data(samples), posterior
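
For context, a minimal sketch of how these pieces plausibly fit together: find_MAP gives the posterior mode, find_hessian its curvature, and a scipy multivariate normal centred at the mode is sampled and split back into per-variable arrays as in the loop above. Only the imports and the sampling loop come from this diff; the intermediate names (H, cov, mean) and the exact construction are assumptions, not necessarily the file's actual code.

import arviz as az
import numpy as np
import scipy.stats

from pymc3.tuning import find_hessian, find_MAP


def laplace_sketch(vars, n_chains=2, n_samples=10_000):
    # Mode and curvature of the log posterior (the two helpers imported above).
    map = find_MAP(vars=vars)
    H = find_hessian(map, vars=vars)
    cov = np.linalg.inv(H)  # assumption: covariance = inverse Hessian at the mode

    # Assumed layout: stack the MAP values of the requested variables into one mean vector.
    mean = np.concatenate([np.atleast_1d(map[v.name]).ravel() for v in vars])
    posterior = scipy.stats.multivariate_normal(mean=mean, cov=cov)

    # Draw an array of shape (n_chains, n_samples, dim) and split it per variable,
    # mirroring the loop shown in the diff above.
    draws = posterior.rvs(size=(n_chains, n_samples))
    samples = {}
    i = 0
    for v in vars:
        var_size = map[v.name].size
        samples[v.name] = draws[:, :, i : i + var_size].squeeze()
        i += var_size
    return az.convert_to_inference_data(samples), posterior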

pymc3/tests/test_quadratic_approximation.py

Lines changed: 3 additions & 7 deletions
@@ -1,6 +1,7 @@
+import arviz as az
 import numpy as np
+
 import pymc3 as pm
-import arviz as az
 
 from pymc3.tests.helpers import SeededTest
 
@@ -38,12 +39,7 @@ def test_hdi_contains_parameters_in_linear_regression():
     with pm.Model() as lm:
         weights = pm.Normal("weights", mu=0, sigma=1, shape=M)
         noise = pm.Exponential("noise", lam=1)
-        y_observed = pm.Normal(
-            "y_observed",
-            mu=X @ weights,
-            sigma=noise,
-            observed=y
-        )
+        y_observed = pm.Normal("y_observed", mu=X @ weights, sigma=noise, observed=y)
 
         idata, _ = pm.quadratic_approximation([weights, noise])
 
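A hedged usage sketch, following the test above: fit the same kind of small linear regression and inspect the approximation with arviz. The data generation (rng, N, true_weights) and the summary/HDI calls are illustrative additions, not part of the test file.

import arviz as az
import numpy as np
import pymc3 as pm

# Synthetic data for a small linear regression (illustrative values).
rng = np.random.default_rng(0)
N, M = 250, 3
true_weights = np.array([0.5, -1.0, 2.0])
X = rng.normal(size=(N, M))
y = X @ true_weights + rng.normal(scale=0.3, size=N)

with pm.Model():
    weights = pm.Normal("weights", mu=0, sigma=1, shape=M)
    noise = pm.Exponential("noise", lam=1)
    pm.Normal("y_observed", mu=X @ weights, sigma=noise, observed=y)

    idata, posterior = pm.quadratic_approximation([weights, noise])

# Posterior means should land near true_weights; az.hdi gives the intervals
# that a containment check like the test's would presumably use.
print(az.summary(idata))
print(az.hdi(idata))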