Skip to content

Commit 9fa3295

Browse files
committed
Add warning if numbers of variables in vars does not equal number of model variables
1 parent f8fc0e2 commit 9fa3295

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

pymc_experimental/inference/laplace.py

+8
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import warnings
1516
from collections.abc import Sequence
1617
from typing import Optional
1718

@@ -98,6 +99,13 @@ def laplace(
9899
rng = np.random.default_rng(seed=random_seed)
99100

100101
transformed_m = pm.modelcontext(model)
102+
103+
if len(vars) != len(transformed_m.free_RVs):
104+
warnings.warn(
105+
"Number of variables in vars does not eqaul the number of variables in the model.",
106+
UserWarning,
107+
)
108+
101109
map = pm.find_MAP(vars=vars, progressbar=progressbar, model=transformed_m)
102110

103111
# See https://www.pymc.io/projects/docs/en/stable/api/model/generated/pymc.model.transform.conditioning.remove_value_transforms.html

pymc_experimental/tests/test_laplace.py

+39
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def test_laplace_only_fit():
8787
idata = pmx.fit(
8888
method="laplace",
8989
vars=vars,
90+
draws=None,
9091
model=m,
9192
random_seed=173300,
9293
)
@@ -99,3 +100,41 @@ def test_laplace_only_fit():
99100

100101
assert np.allclose(idata.fit["mean_vector"].values, bda_map)
101102
assert np.allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)
103+
104+
105+
@pytest.mark.filterwarnings(
106+
"ignore:Model.model property is deprecated. Just use Model.:FutureWarning",
107+
"ignore:hessian will stop negating the output in a future version of PyMC.\n"
108+
+ "To suppress this warning set `negate_output=False`:FutureWarning",
109+
)
110+
def test_laplace_subset_of_rv(recwarn):
111+
112+
# Example originates from Bayesian Data Analyses, 3rd Edition
113+
# By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
114+
# Aki Vehtari, and Donald Rubin.
115+
# See section. 4.1
116+
117+
y = np.array([2642, 3503, 4358], dtype=np.float64)
118+
n = y.size
119+
120+
with pm.Model() as m:
121+
logsigma = pm.Uniform("logsigma", 1, 100)
122+
mu = pm.Uniform("mu", -10000, 10000)
123+
yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
124+
vars = [mu]
125+
126+
idata = pmx.fit(
127+
method="laplace",
128+
vars=vars,
129+
draws=None,
130+
model=m,
131+
random_seed=173300,
132+
)
133+
134+
assert len(recwarn) == 4
135+
w = recwarn.pop(UserWarning)
136+
assert issubclass(w.category, UserWarning)
137+
assert (
138+
str(w.message)
139+
== "Number of variables in vars does not eqaul the number of variables in the model."
140+
)

0 commit comments

Comments
 (0)