Skip to content

Commit 0359373

Browse files
rename pymc_experimental -> pymc_extras
1 parent 8de5346 commit 0359373

File tree

12 files changed

+12
-14
lines changed

12 files changed

+12
-14
lines changed

pymc_extras/model/marginal/distributions.py

-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from pymc.logprob.abstract import MeasurableOp, _logprob
88
from pymc.logprob.basic import conditional_logp, logp
99
from pymc.pytensorf import constant_fold
10-
from pytensor import Variable
1110
from pytensor.compile.builders import OpFromGraph
1211
from pytensor.compile.mode import Mode
1312
from pytensor.graph import Op, vectorize_graph

pymc_experimental/sampling/mcmc.py renamed to pymc_extras/sampling/mcmc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from pymc.sampling.mcmc import sample
55
from pytensor.graph.rewriting.basic import GraphRewriter
66

7-
from pymc_experimental.sampling.optimizations.optimize import (
7+
from pymc_extras.sampling.optimizations.optimize import (
88
TAGS_TYPE,
99
optimize_model_for_mcmc_sampling,
1010
)

pymc_experimental/sampling/optimizations/__init__.py renamed to pymc_extras/sampling/optimizations/__init__.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
# ruff: noqa: F401
22
# Add rewrites to the optimization DBs
3-
import pymc_experimental.sampling.optimizations.conjugacy
4-
import pymc_experimental.sampling.optimizations.summary_stats
53

6-
from pymc_experimental.sampling.optimizations.optimize import (
4+
from pymc_extras.sampling.optimizations import conjugacy, summary_stats
5+
from pymc_extras.sampling.optimizations.optimize import (
76
optimize_model_for_mcmc_sampling,
87
posterior_optimization_db,
98
)

pymc_experimental/sampling/optimizations/conjugacy.py renamed to pymc_extras/sampling/optimizations/conjugacy.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
from pytensor.tensor.elemwise import DimShuffle
1111
from pytensor.tensor.subtensor import _sum_grad_over_bcasted_dims as sum_bcasted_dims
1212

13-
from pymc_experimental.sampling.optimizations.conjugate_sampler import (
13+
from pymc_extras.sampling.optimizations.conjugate_sampler import (
1414
ConjugateRV,
1515
)
16-
from pymc_experimental.sampling.optimizations.optimize import posterior_optimization_db
16+
from pymc_extras.sampling.optimizations.optimize import posterior_optimization_db
1717

1818

1919
def register_conjugacy_rewrites_variants(rewrite_fn, tracks=(ModelFreeRV,)):

pymc_experimental/sampling/optimizations/conjugate_sampler.py renamed to pymc_extras/sampling/optimizations/conjugate_sampler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from pytensor.link.jax.linker import JAXLinker
1414
from pytensor.tensor.random.type import RandomGeneratorType
1515

16-
from pymc_experimental.utils.ofg import inline_ofg_outputs
16+
from pymc_extras.utils.ofg import inline_ofg_outputs
1717

1818

1919
class ConjugateRV(OpFromGraph, MeasurableOp):

pymc_experimental/sampling/optimizations/summary_stats.py renamed to pymc_extras/sampling/optimizations/summary_stats.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pytensor.graph.fg import FunctionGraph
66
from pytensor.graph.rewriting.basic import node_rewriter
77

8-
from pymc_experimental.sampling.optimizations.optimize import posterior_optimization_db
8+
from pymc_extras.sampling.optimizations.optimize import posterior_optimization_db
99

1010

1111
@node_rewriter(tracks=[ModelObservedRV])

pymc_extras/utils/ofg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pytensor.graph.replace import clone_replace
66

77

8-
def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
8+
def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> list[Variable]:
99
"""Inline the inner graph (outputs) of an OpFromGraph Op.
1010
1111
Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"

tests/sampling/mcmc/test_mcmc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pymc.sampling.mcmc import sample
99
from pymc.step_methods import Slice
1010

11-
from pymc_experimental import opt_sample
11+
from pymc_extras import opt_sample
1212

1313

1414
def test_custom_step_raises():

tests/sampling/optimizations/test_conjugacy.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from pymc.model.transform.conditioning import remove_value_transforms
77
from pymc.sampling import draw
88

9-
from pymc_experimental.sampling.optimizations.conjugate_sampler import ConjugateRV
10-
from pymc_experimental.sampling.optimizations.optimize import optimize_model_for_mcmc_sampling
9+
from pymc_extras.sampling.optimizations import optimize_model_for_mcmc_sampling
10+
from pymc_extras.sampling.optimizations.conjugate_sampler import ConjugateRV
1111

1212

1313
@pytest.mark.parametrize("eager", [False, True])

tests/sampling/optimizations/test_summary_stats.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pymc.distributions import HalfNormal, Normal
44
from pymc.model.core import Model
55

6-
from pymc_experimental.sampling.optimizations.optimize import optimize_model_for_mcmc_sampling
6+
from pymc_extras.sampling.optimizations import optimize_model_for_mcmc_sampling
77

88

99
def test_summary_stats_normal():

0 commit comments

Comments
 (0)