Commit 9cfc521

Adapt to major PyMC changes

1 parent 1b09f82 commit 9cfc521
7 files changed: +30 -13 lines

conda-envs/environment-test.yml

+1 -1

@@ -10,7 +10,7 @@ dependencies:
   - xhistogram
   - statsmodels
   - pip:
-    - pymc>=5.17.0 # CI was failing to resolve
+    - pymc>=5.19.1 # CI was failing to resolve
     - blackjax
     - scikit-learn
     - better_optimize>=0.0.10

conda-envs/windows-environment-test.yml

+1 -1

@@ -10,7 +10,7 @@ dependencies:
   - xhistogram
   - statsmodels
   - pip:
-    - pymc>=5.17.0 # CI was failing to resolve
+    - pymc>=5.19.1 # CI was failing to resolve
     - blackjax
     - scikit-learn
     - better_optimize>=0.0.10

pymc_experimental/distributions/timeseries.py

+15 -3

@@ -281,10 +281,20 @@ def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs):
 class DiscreteMarkovChainGibbsMetropolis(CategoricalGibbsMetropolis):
     name = "discrete_markov_chain_gibbs_metropolis"

-    def __init__(self, vars, proposal="uniform", order="random", model=None):
+    def __init__(
+        self,
+        vars,
+        proposal="uniform",
+        order="random",
+        model=None,
+        initial_point=None,
+        compile_kwargs: dict | None = None,
+        **kwargs,
+    ):
         model = pm.modelcontext(model)
         vars = get_value_vars_from_user_vars(vars, model)
-        initial_point = model.initial_point()
+        if initial_point is None:
+            initial_point = model.initial_point()

         dimcats = []
         # The above variable is a list of pairs (aggregate dimension, number
@@ -332,7 +342,9 @@ def __init__(self, vars, proposal="uniform", order="random", model=None):
         self.tune = True

         # We bypass CategoryGibbsMetropolis's __init__ to avoid it's specialiazed initialization logic
-        ArrayStep.__init__(self, vars, [model.compile_logp()])
+        if compile_kwargs is None:
+            compile_kwargs = {}
+        ArrayStep.__init__(self, vars, [model.compile_logp(**compile_kwargs)], **kwargs)

     @staticmethod
     def competence(var):
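This change tracks PyMC's updated step-sampler API: step methods are now constructed with an initial_point and compile_kwargs, and forward extra keyword arguments to ArrayStep. A minimal sketch of constructing the updated sampler by hand, assuming the imports resolve as laid out in this repository and that "FAST_COMPILE" is an acceptable compile mode; the model itself is hypothetical:

import pymc as pm
from pymc_experimental.distributions.timeseries import (
    DiscreteMarkovChain,
    DiscreteMarkovChainGibbsMetropolis,
)

with pm.Model() as model:
    init_dist = pm.Categorical.dist(p=[0.5, 0.5])
    chain = DiscreteMarkovChain(
        "chain", P=[[0.9, 0.1], [0.2, 0.8]], init_dist=init_dist, shape=(10,)
    )
    # The new signature lets callers reuse a precomputed initial point and
    # route compile options through to model.compile_logp().
    step = DiscreteMarkovChainGibbsMetropolis(
        vars=[chain],
        initial_point=model.initial_point(),
        compile_kwargs={"mode": "FAST_COMPILE"},
    )
    idata = pm.sample(step=step, tune=0, chains=1)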

pymc_experimental/inference/laplace.py

+2 -2

@@ -391,14 +391,14 @@ def sample_laplace_posterior(

     else:
         info = mu.point_map_info
-        flat_shapes = [np.prod(shape).astype(int) for _, shape, _ in info]
+        flat_shapes = [size for _, _, size, _ in info]
         slices = [
             slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1])) for i in range(len(flat_shapes))
         ]

         posterior_draws = [
             posterior_draws[..., idx].reshape((chains, draws, *shape)).astype(dtype)
-            for idx, (name, shape, dtype) in zip(slices, info)
+            for idx, (name, shape, _, dtype) in zip(slices, info)
         ]

         idata = laplace_draws_to_inferencedata(posterior_draws, model)
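These two edits follow a change to point_map_info in newer PyMC: each entry now appears to carry an explicit flat size, (name, shape, size, dtype), instead of (name, shape, dtype), so the sizes no longer need to be recomputed from the shapes. A self-contained illustration of how the slices fall out of that metadata; the entries below are made-up stand-ins:

# Hypothetical point_map_info entries in the new 4-tuple layout.
info = [("mu", (), 1, "float64"), ("sigma", (3,), 3, "float64")]

flat_shapes = [size for _, _, size, _ in info]
slices = [
    slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1]))
    for i in range(len(flat_shapes))
]
print(slices)  # [slice(0, 1), slice(1, 4)]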

pymc_experimental/inference/smc/sampling.py

+5 -2

@@ -206,13 +206,16 @@ def arviz_from_particles(model, particles):
     -------
     """
     n_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
-    by_varname = {k.name: v.squeeze()[np.newaxis, :] for k, v in zip(model.value_vars, particles)}
+    by_varname = {
+        k.name: v.squeeze()[np.newaxis, :].astype(k.dtype)
+        for k, v in zip(model.value_vars, particles)
+    }
     varnames = [v.name for v in model.value_vars]
     with model:
         strace = NDArray(name=model.name)
         strace.setup(n_particles, 0)
         for particle_index in range(0, n_particles):
-            strace.record(point={k: by_varname[k][0][particle_index] for k in varnames})
+            strace.record(point={k: np.asarray(by_varname[k][0][particle_index]) for k in varnames})
     multitrace = MultiTrace((strace,))
     return to_inference_data(multitrace, log_likelihood=False)
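Both additions coerce JAX outputs back into plain NumPy values with the model's dtype before they are recorded into the NDArray trace backend. A standalone sketch of that coercion, assuming 64-bit mode is enabled in JAX and using a made-up particle and target dtype:

import jax

jax.config.update("jax_enable_x64", True)  # assumption: SMC runs with x64 enabled

import jax.numpy as jnp
import numpy as np

particle = jnp.ones((4,), dtype=jnp.float32)  # stand-in for one SMC particle
target_dtype = "float64"  # stand-in for k.dtype of a PyMC value variable

# Cast to the model's dtype, then materialize as a NumPy array for the trace.
coerced = np.asarray(particle.astype(target_dtype))
assert coerced.dtype == np.float64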

requirements.txt

+1 -1

@@ -1,2 +1,2 @@
-pymc>=5.17.0
+pymc>=5.19.1
 scikit-learn

tests/distributions/test_discrete_markov_chain.py

+5 -3

@@ -225,16 +225,18 @@ def test_change_size_univariate(self):
     def test_mcmc_sampling(self):
         with pm.Model(coords={"step": range(100)}) as model:
             init_dist = Categorical.dist(p=[0.5, 0.5])
-            DiscreteMarkovChain(
+            markov_chain = DiscreteMarkovChain(
                 "markov_chain",
                 P=[[0.1, 0.9], [0.1, 0.9]],
                 init_dist=init_dist,
                 shape=(100,),
                 dims="step",
             )

-            step_method = assign_step_methods(model)
-            assert isinstance(step_method, DiscreteMarkovChainGibbsMetropolis)
+            _, assigned_step_methods = assign_step_methods(model)
+            assert assigned_step_methods[DiscreteMarkovChainGibbsMetropolis] == [
+                model.rvs_to_values[markov_chain]
+            ]

             # Sampler needs no tuning
             idata = pm.sample(
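The test adapts to the new return convention of assign_step_methods, which, as the assertion above relies on, now returns a pair whose second element maps step-method classes to the value variables they were assigned. A schematic comparison on a hypothetical model; the exact contents of the returned mapping depend on the PyMC version:

import pymc as pm
from pymc.sampling.mcmc import assign_step_methods

with pm.Model() as model:
    x = pm.Bernoulli("x", p=0.5)

# Old API (pymc < 5.19): the call returned the instantiated step method(s).
# step_method = assign_step_methods(model)

# New API: unpack a pair; the second element is a dict keyed by
# step-method class, valued with the assigned value variables.
_, assigned = assign_step_methods(model)
print(assigned)  # e.g. {BinaryGibbsMetropolis: [value variable of x]}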
