Skip to content

Commit 39736c1

Browse files
committed
Handle latest PyMC API
1 parent dfe3fe0 commit 39736c1

File tree

7 files changed

+32
-24
lines changed

7 files changed

+32
-24
lines changed

conda-envs/environment-test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ dependencies:
1010
- xhistogram
1111
- statsmodels
1212
- pip:
13-
- pymc>=5.13.0 # CI was failing to resolve
13+
- pymc>=5.16.0 # CI was failing to resolve
1414
- blackjax
1515
- scikit-learn

conda-envs/windows-environment-test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ dependencies:
1010
- xhistogram
1111
- statsmodels
1212
- pip:
13-
- pymc>=5.13.0 # CI was failing to resolve
13+
- pymc>=5.16.0 # CI was failing to resolve
1414
- blackjax
1515
- scikit-learn

pymc_experimental/distributions/continuous.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
The imports from pymc are not fully replicated here: add imports as necessary.
2020
"""
2121

22-
from typing import List, Tuple, Union
22+
from typing import Tuple, Union
2323

2424
import numpy as np
2525
import pytensor.tensor as pt
@@ -37,8 +37,7 @@
3737

3838
class GenExtremeRV(RandomVariable):
3939
name: str = "Generalized Extreme Value"
40-
ndim_supp: int = 0
41-
ndims_params: List[int] = [0, 0, 0]
40+
signature = "(),(),()->()"
4241
dtype: str = "floatX"
4342
_print_name: Tuple[str, str] = ("Generalized Extreme Value", "\\operatorname{GEV}")
4443

@@ -275,7 +274,7 @@ def chi_dist(nu: TensorVariable, size: TensorVariable) -> TensorVariable:
275274

276275
def __new__(cls, name, nu, **kwargs):
277276
if "observed" not in kwargs:
278-
kwargs.setdefault("transform", transforms.log)
277+
kwargs.setdefault("default_transform", transforms.log)
279278
return CustomDist(name, nu, dist=cls.chi_dist, class_name="Chi", **kwargs)
280279

281280
@classmethod

pymc_experimental/distributions/discrete.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ def log1mexp(x):
3131

3232
class GeneralizedPoissonRV(RandomVariable):
3333
name = "generalized_poisson"
34-
ndim_supp = 0
35-
ndims_params = [0, 0]
34+
signature = "(),()->()"
3635
dtype = "int64"
3736
_print_name = ("GeneralizedPoisson", "\\operatorname{GeneralizedPoisson}")
3837

pymc_experimental/model/marginal_model.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pytensor.graph.replace import graph_replace, vectorize_graph
2222
from pytensor.scan import map as scan_map
2323
from pytensor.tensor import TensorType, TensorVariable
24-
from pytensor.tensor.elemwise import Elemwise
24+
from pytensor.tensor.elemwise import DimShuffle, Elemwise
2525
from pytensor.tensor.shape import Shape
2626
from pytensor.tensor.special import log_softmax
2727

@@ -598,7 +598,18 @@ def is_elemwise_subgraph(rv_to_marginalize, other_input_rvs, output_rvs):
598598
fg = FunctionGraph(outputs=output_rvs, clone=False)
599599

600600
non_elemwise_blockers = [
601-
o for node in fg.apply_nodes if not isinstance(node.op, Elemwise) for o in node.outputs
601+
o
602+
for node in fg.apply_nodes
603+
if not (
604+
isinstance(node.op, Elemwise)
605+
# Allow expand_dims on the left
606+
or (
607+
isinstance(node.op, DimShuffle)
608+
and not node.op.drop
609+
and node.op.shuffle == sorted(node.op.shuffle)
610+
)
611+
)
612+
for o in node.outputs
602613
]
603614
blocker_candidates = [rv_to_marginalize] + other_input_rvs + non_elemwise_blockers
604615
blockers = [var for var in blocker_candidates if var not in output_rvs]
@@ -698,16 +709,17 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
698709

699710
def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
700711
op = rv.owner.op
712+
dist_params = rv.owner.op.dist_params(rv.owner)
701713
if isinstance(op, Bernoulli):
702714
return (0, 1)
703715
elif isinstance(op, Categorical):
704-
p_param = rv.owner.inputs[3]
716+
[p_param] = dist_params
705717
return tuple(range(pt.get_vector_length(p_param)))
706718
elif isinstance(op, DiscreteUniform):
707-
lower, upper = constant_fold(rv.owner.inputs[3:])
719+
lower, upper = constant_fold(dist_params)
708720
return tuple(np.arange(lower, upper + 1))
709721
elif isinstance(op, DiscreteMarkovChain):
710-
P = rv.owner.inputs[0]
722+
P, *_ = dist_params
711723
return tuple(range(pt.get_vector_length(P[-1])))
712724

713725
raise NotImplementedError(f"Cannot compute domain for op {op}")

pymc_experimental/tests/model/test_marginal_model.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
import pytensor.tensor as pt
88
import pytest
99
from arviz import InferenceData, dict_to_dataset
10-
from pymc import ImputationWarning, inputvars
1110
from pymc.distributions import transforms
1211
from pymc.logprob.abstract import _logprob
1312
from pymc.model.fgraph import fgraph_from_model
13+
from pymc.pytensorf import inputvars
1414
from pymc.util import UNSET
1515
from scipy.special import log_softmax, logsumexp
1616
from scipy.stats import halfnorm, norm
@@ -45,9 +45,7 @@ def disaster_model():
4545
early_rate = pm.Exponential("early_rate", 1.0, initval=3)
4646
late_rate = pm.Exponential("late_rate", 1.0, initval=1)
4747
rate = pm.math.switch(switchpoint >= years, early_rate, late_rate)
48-
with pytest.warns(ImputationWarning), pytest.warns(
49-
RuntimeWarning, match="invalid value encountered in cast"
50-
):
48+
with pytest.warns(Warning):
5149
disasters = pm.Poisson("disasters", rate, observed=disaster_data)
5250

5351
return disaster_model, years
@@ -294,7 +292,7 @@ def test_recover_marginals_basic():
294292

295293
with m:
296294
prior = pm.sample_prior_predictive(
297-
samples=20,
295+
draws=20,
298296
random_seed=rng,
299297
return_inferencedata=False,
300298
)
@@ -337,7 +335,7 @@ def test_recover_marginals_coords():
337335

338336
with m:
339337
prior = pm.sample_prior_predictive(
340-
samples=20,
338+
draws=20,
341339
random_seed=rng,
342340
return_inferencedata=False,
343341
)
@@ -364,7 +362,7 @@ def test_recover_batched_marginal():
364362

365363
with m:
366364
prior = pm.sample_prior_predictive(
367-
samples=20,
365+
draws=20,
368366
random_seed=rng,
369367
return_inferencedata=False,
370368
)
@@ -394,7 +392,7 @@ def test_nested_recover_marginals():
394392

395393
with m:
396394
prior = pm.sample_prior_predictive(
397-
samples=20,
395+
draws=20,
398396
random_seed=rng,
399397
return_inferencedata=False,
400398
)
@@ -565,7 +563,7 @@ def test_marginalized_transforms(transform, expected_warning):
565563
w=w,
566564
comp_dists=pm.HalfNormal.dist([1, 2, 3]),
567565
initval=initval,
568-
transform=transform,
566+
default_transform=transform,
569567
)
570568
y = pm.Normal("y", 0, sigma, observed=data)
571569

@@ -583,7 +581,7 @@ def test_marginalized_transforms(transform, expected_warning):
583581
),
584582
),
585583
initval=initval,
586-
transform=transform,
584+
default_transform=transform,
587585
)
588586
y = pm.Normal("y", 0, sigma, observed=data)
589587

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
pymc>=5.13.0
1+
pymc>=5.16.0
22
scikit-learn

0 commit comments

Comments
 (0)