Skip to content

Commit f8fc0e2

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 5f0cc28 commit f8fc0e2

File tree

2 files changed

+6
-12
lines changed

2 files changed

+6
-12
lines changed

pymc_experimental/inference/laplace.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,7 @@ def laplace(
103103
# See https://www.pymc.io/projects/docs/en/stable/api/model/generated/pymc.model.transform.conditioning.remove_value_transforms.html
104104
untransformed_m = remove_value_transforms(transformed_m)
105105
untransformed_vars = [untransformed_m[v.name] for v in vars]
106-
hessian = pm.find_hessian(
107-
point=map, vars=untransformed_vars, model=untransformed_m)
106+
hessian = pm.find_hessian(point=map, vars=untransformed_vars, model=untransformed_m)
108107

109108
if np.linalg.det(hessian) == 0:
110109
raise np.linalg.LinAlgError("Hessian is singular.")
@@ -119,8 +118,7 @@ def laplace(
119118

120119
data_vars = {}
121120
for i, var in enumerate(vars):
122-
data_vars[str(var)] = xr.DataArray(
123-
samples[:, :, i], dims=("chain", "draw"))
121+
data_vars[str(var)] = xr.DataArray(samples[:, :, i], dims=("chain", "draw"))
124122

125123
coords = {"chain": np.arange(chains), "draw": np.arange(draws)}
126124
ds = xr.Dataset(data_vars, coords=coords)
@@ -138,15 +136,13 @@ def laplace(
138136
def addFitToInferenceData(vars, idata, mean, covariance):
139137
coord_names = [v.name for v in vars]
140138
# Convert to xarray DataArray
141-
mean_dataarray = xr.DataArray(
142-
mean, dims=["rows"], coords={"rows": coord_names})
139+
mean_dataarray = xr.DataArray(mean, dims=["rows"], coords={"rows": coord_names})
143140
cov_dataarray = xr.DataArray(
144141
covariance, dims=["rows", "columns"], coords={"rows": coord_names, "columns": coord_names}
145142
)
146143

147144
# Create xarray dataset
148-
dataset = xr.Dataset({"mean_vector": mean_dataarray,
149-
"covariance_matrix": cov_dataarray})
145+
dataset = xr.Dataset({"mean_vector": mean_dataarray, "covariance_matrix": cov_dataarray})
150146

151147
idata.add_groups(fit=dataset)
152148

pymc_experimental/tests/test_laplace.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ def test_laplace():
6060
bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]])
6161

6262
assert np.allclose(idata.fit["mean_vector"].values, bda_map)
63-
assert np.allclose(
64-
idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)
63+
assert np.allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)
6564

6665

6766
@pytest.mark.filterwarnings(
@@ -99,5 +98,4 @@ def test_laplace_only_fit():
9998
bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]])
10099

101100
assert np.allclose(idata.fit["mean_vector"].values, bda_map)
102-
assert np.allclose(
103-
idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)
101+
assert np.allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)

0 commit comments

Comments
 (0)