Skip to content

Commit 8ffb951

Browse files
tvwengeraloctavodia
authored andcommitted
revert formatting
1 parent 8eaa9be commit 8ffb951

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

pymc/smc/kernels.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,13 @@ def initialize_population(self) -> dict[str, np.ndarray]:
193193
)
194194

195195
model = self.model
196+
196197
prior_expression = make_initial_point_expression(
197198
free_rvs=model.free_RVs,
198199
rvs_to_transforms=model.rvs_to_transforms,
199-
initval_strategies={},
200+
initval_strategies={
201+
**model.rvs_to_initial_values,
202+
},
200203
default_strategy="prior",
201204
return_transformed=True,
202205
)

tests/smc/test_smc.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import pymc as pm
2626

2727
from pymc.backends.base import MultiTrace
28+
from pymc.distributions.transforms import Ordered
2829
from pymc.pytensorf import floatX
2930
from pymc.smc.kernels import IMH, systematic_resampling
3031
from tests.helpers import assert_random_state_equal
@@ -269,6 +270,30 @@ def test_deprecated_abc_args(self):
269270
):
270271
pm.sample_smc(draws=10, chains=1, save_log_pseudolikelihood=True)
271272

273+
def test_ordered(self):
274+
"""
275+
Test that initial population respects custom initval, especially when applied
276+
to the Ordered transformation. Regression test for #7438.
277+
"""
278+
with pm.Model() as m:
279+
pm.Normal(
280+
"a",
281+
mu=0.0,
282+
sigma=1.0,
283+
size=(2,),
284+
transform=Ordered(),
285+
initval=[-1.0, 1.0],
286+
)
287+
288+
smc = IMH(model=m)
289+
out = smc.initialize_population()
290+
291+
# initial point should not include NaNs
292+
assert not np.any(np.isnan(out["a_ordered__"]))
293+
294+
# initial point should match for all particles
295+
assert np.all(out["a_ordered__"][0] == out["a_ordered__"])
296+
272297

273298
class TestMHKernel:
274299
def test_normal_model(self):

0 commit comments

Comments
 (0)