Skip to content

Commit 5f0cc28

Browse files
committed
Parameter draws is now optional with default value 1000
1 parent 071b04b commit 5f0cc28

File tree

2 files changed

+53
-9
lines changed

2 files changed

+53
-9
lines changed

pymc_experimental/inference/laplace.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
def laplace(
3434
vars: Sequence[Variable],
35-
draws=1_000,
35+
draws: Optional[int] = 1000,
3636
model=None,
3737
random_seed: Optional[RandomSeed] = None,
3838
progressbar=True,
@@ -49,9 +49,9 @@ def laplace(
4949
vars : Sequence[Variable]
5050
A sequence of variables for which the Laplace approximation of the posterior distribution
5151
is to be created.
52-
draws : int, optional, default=1_000
52+
draws : Optional[int] with default=1_000
5353
The number of draws to sample from the posterior distribution for creating the approximation.
54-
For draws=0 only the fit of the Laplace approximation is returned
54+
For draws=None only the fit of the Laplace approximation is returned
5555
model : object, optional, default=None
5656
The model object that defines the posterior distribution. If None, the default model will be used.
5757
random_seed : Optional[RandomSeed], optional, default=None
@@ -103,7 +103,8 @@ def laplace(
103103
# See https://www.pymc.io/projects/docs/en/stable/api/model/generated/pymc.model.transform.conditioning.remove_value_transforms.html
104104
untransformed_m = remove_value_transforms(transformed_m)
105105
untransformed_vars = [untransformed_m[v.name] for v in vars]
106-
hessian = pm.find_hessian(point=map, vars=untransformed_vars, model=untransformed_m)
106+
hessian = pm.find_hessian(
107+
point=map, vars=untransformed_vars, model=untransformed_m)
107108

108109
if np.linalg.det(hessian) == 0:
109110
raise np.linalg.LinAlgError("Hessian is singular.")
@@ -113,12 +114,13 @@ def laplace(
113114

114115
chains = 1
115116

116-
if draws != 0:
117+
if draws is not None:
117118
samples = rng.multivariate_normal(mean, cov, size=(chains, draws))
118119

119120
data_vars = {}
120121
for i, var in enumerate(vars):
121-
data_vars[str(var)] = xr.DataArray(samples[:, :, i], dims=("chain", "draw"))
122+
data_vars[str(var)] = xr.DataArray(
123+
samples[:, :, i], dims=("chain", "draw"))
122124

123125
coords = {"chain": np.arange(chains), "draw": np.arange(draws)}
124126
ds = xr.Dataset(data_vars, coords=coords)
@@ -136,13 +138,15 @@ def laplace(
136138
def addFitToInferenceData(vars, idata, mean, covariance):
137139
coord_names = [v.name for v in vars]
138140
# Convert to xarray DataArray
139-
mean_dataarray = xr.DataArray(mean, dims=["rows"], coords={"rows": coord_names})
141+
mean_dataarray = xr.DataArray(
142+
mean, dims=["rows"], coords={"rows": coord_names})
140143
cov_dataarray = xr.DataArray(
141144
covariance, dims=["rows", "columns"], coords={"rows": coord_names, "columns": coord_names}
142145
)
143146

144147
# Create xarray dataset
145-
dataset = xr.Dataset({"mean_vector": mean_dataarray, "covariance_matrix": cov_dataarray})
148+
dataset = xr.Dataset({"mean_vector": mean_dataarray,
149+
"covariance_matrix": cov_dataarray})
146150

147151
idata.add_groups(fit=dataset)
148152

pymc_experimental/tests/test_laplace.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,44 @@ def test_laplace():
6060
bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]])
6161

6262
assert np.allclose(idata.fit["mean_vector"].values, bda_map)
63-
assert np.allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)
63+
assert np.allclose(
64+
idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)
65+
66+
67+
@pytest.mark.filterwarnings(
68+
"ignore:Model.model property is deprecated. Just use Model.:FutureWarning",
69+
"ignore:hessian will stop negating the output in a future version of PyMC.\n"
70+
+ "To suppress this warning set `negate_output=False`:FutureWarning",
71+
)
72+
def test_laplace_only_fit():
73+
74+
# Example originates from Bayesian Data Analyses, 3rd Edition
75+
# By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
76+
# Aki Vehtari, and Donald Rubin.
77+
# See section. 4.1
78+
79+
y = np.array([2642, 3503, 4358], dtype=np.float64)
80+
n = y.size
81+
82+
with pm.Model() as m:
83+
logsigma = pm.Uniform("logsigma", 1, 100)
84+
mu = pm.Uniform("mu", -10000, 10000)
85+
yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
86+
vars = [mu, logsigma]
87+
88+
idata = pmx.fit(
89+
method="laplace",
90+
vars=vars,
91+
model=m,
92+
random_seed=173300,
93+
)
94+
95+
assert idata.fit["mean_vector"].shape == (len(vars),)
96+
assert idata.fit["covariance_matrix"].shape == (len(vars), len(vars))
97+
98+
bda_map = [y.mean(), np.log(y.std())]
99+
bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]])
100+
101+
assert np.allclose(idata.fit["mean_vector"].values, bda_map)
102+
assert np.allclose(
103+
idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)

0 commit comments

Comments
 (0)