diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a179e59f..e218b506 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -19,7 +19,6 @@ jobs: matrix: os: [ubuntu-18.04] floatx: [float32, float64] - test-subset: pymc-experimental/tests fail-fast: false runs-on: ${{ matrix.os }} env: @@ -81,7 +80,6 @@ jobs: matrix: os: [windows-latest] floatx: [float32, float64] - test-subset: pymc-experimental/tests fail-fast: false runs-on: ${{ matrix.os }} env: diff --git a/conda-envs/environment-test-py38.yml b/conda-envs/environment-test-py38.yml index 4567e0b8..0ecbe95d 100644 --- a/conda-envs/environment-test-py38.yml +++ b/conda-envs/environment-test-py38.yml @@ -11,4 +11,3 @@ dependencies: - xhistogram - pip: - "git+https://github.com/pymc-devs/pymc.git@main" - - blackjax diff --git a/conda-envs/environment-test-py39.yml b/conda-envs/environment-test-py39.yml index 19a8adc4..570af385 100644 --- a/conda-envs/environment-test-py39.yml +++ b/conda-envs/environment-test-py39.yml @@ -11,4 +11,3 @@ dependencies: - xhistogram - pip: - "git+https://github.com/pymc-devs/pymc.git@main" - - blackjax diff --git a/pymc_experimental/__init__.py b/pymc_experimental/__init__.py index 882ccda5..8946a3a1 100644 --- a/pymc_experimental/__init__.py +++ b/pymc_experimental/__init__.py @@ -11,6 +11,5 @@ _log.addHandler(handler) -from pymc_experimental import distributions, gp, inference, utils +from pymc_experimental import distributions, gp, utils from pymc_experimental.bart import * -from pymc_experimental.inference.fit import fit diff --git a/pymc_experimental/bart/__init__.py b/pymc_experimental/bart/__init__.py index 33127b7d..4b9379fe 100644 --- a/pymc_experimental/bart/__init__.py +++ b/pymc_experimental/bart/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The PyMC Developers +# Copyright 2020 The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pymc_experimental/inference/__init__.py b/pymc_experimental/inference/__init__.py deleted file mode 100644 index 2339ccd4..00000000 --- a/pymc_experimental/inference/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from pymc_experimental.inference.fit import fit diff --git a/pymc_experimental/inference/fit.py b/pymc_experimental/inference/fit.py deleted file mode 100644 index 6b9835c0..00000000 --- a/pymc_experimental/inference/fit.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2022 The PyMC Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -def fit(method, **kwargs): - """ - Fit a model with an inference algorithm - - Parameters - ---------- - method : str - Which inference method to run. - Supported: pathfinder - - kwargs are passed on. - - Returns - ------- - arviz.InferenceData - """ - if method == "pathfinder": - try: - from pymc_experimental.inference.pathfinder import fit_pathfinder - except ImportError as exc: - raise RuntimeError("Need BlackJAX to use `pathfinder`") from exc - return fit_pathfinder(**kwargs) diff --git a/pymc_experimental/inference/pathfinder.py b/pymc_experimental/inference/pathfinder.py deleted file mode 100644 index df2c616a..00000000 --- a/pymc_experimental/inference/pathfinder.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2022 The PyMC Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import collections -import sys -from typing import Optional - -import arviz as az -import blackjax -import jax -import jax.numpy as jnp -import jax.random as random -import numpy as np -import pymc as pm -from pymc import modelcontext -from pymc.sampling import RandomSeed, _get_seeds_per_chain -from pymc.sampling_jax import get_jaxified_graph -from pymc.util import get_default_varnames - - -def convert_flat_trace_to_idata( - samples, - dims=None, - coords=None, - include_transformed=False, - postprocessing_backend="cpu", - model=None, -): - - model = modelcontext(model) - init_position_dict = model.initial_point() - trace = collections.defaultdict(list) - astart = pm.blocking.DictToArrayBijection.map(init_position_dict) - for sample in samples: - raveld_vars = pm.blocking.RaveledVars(sample, astart.point_map_info) - point = pm.blocking.DictToArrayBijection.rmap(raveld_vars, init_position_dict) - for p, v in point.items(): - trace[p].append(v.tolist()) - - trace = {k: np.asarray(v)[None, ...] for k, v in trace.items()} - - var_names = model.unobserved_value_vars - vars_to_sample = list(get_default_varnames(var_names, include_transformed=include_transformed)) - print("Transforming variables...", file=sys.stdout) - jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample) - result = jax.vmap(jax.vmap(jax_fn))( - *jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0]) - ) - - trace = {v.name: r for v, r in zip(vars_to_sample, result)} - idata = az.from_dict(trace, dims=dims, coords=coords) - - return idata - - -def fit_pathfinder( - iterations=5_000, - random_seed: Optional[RandomSeed] = None, - postprocessing_backend="cpu", - ftol=1e-4, - model=None, -): - """ - Fit the pathfinder algorithm as implemented in blackjax - - Requires the JAX backend - - Parameters - ---------- - iterations : int - Number of iterations to run. - random_seed : int - Random seed to set. - postprocessing_backend : str - Where to compute transformations of the trace. - "cpu" or "gpu". - ftol : float - Floating point tolerance - - Returns - ------- - arviz.InferenceData - - Reference - --------- - https://arxiv.org/abs/2108.03782 - """ - - (random_seed,) = _get_seeds_per_chain(random_seed, 1) - - model = modelcontext(model) - - rvs = [rv.name for rv in model.value_vars] - init_position_dict = model.initial_point() - init_position = [init_position_dict[rv] for rv in rvs] - - new_logprob, new_input = pm.aesaraf.join_nonshared_inputs( - init_position_dict, (model.logp(),), model.value_vars, () - ) - - logprob_fn_list = get_jaxified_graph([new_input], new_logprob) - - def logprob_fn(x): - return logprob_fn_list(x)[0] - - dim = sum(v.size for v in init_position_dict.values()) - - rng_key = random.PRNGKey(random_seed) - w0 = random.multivariate_normal(rng_key, 2.0 + jnp.zeros(dim), jnp.eye(dim)) - path = blackjax.vi.pathfinder.init(rng_key, logprob_fn, w0, return_path=True, ftol=ftol) - - pathfinder = blackjax.kernels.pathfinder(rng_key, logprob_fn, ftol=ftol) - state = pathfinder.init(w0) - - def inference_loop(rng_key, kernel, initial_state, num_samples): - @jax.jit - def one_step(state, rng_key): - state, info = kernel(rng_key, state) - return state, (state, info) - - keys = jax.random.split(rng_key, num_samples) - return jax.lax.scan(one_step, initial_state, keys) - - _, rng_key = random.split(rng_key) - print("Running pathfinder...", file=sys.stdout) - _, (_, samples) = inference_loop(rng_key, pathfinder.step, state, iterations) - - dims = { - var_name: [dim for dim in dims if dim is not None] - for var_name, dims in model.RV_dims.items() - } - - idata = convert_flat_trace_to_idata( - samples, postprocessing_backend=postprocessing_backend, coords=model.coords, dims=dims - ) - - return idata diff --git a/pymc_experimental/tests/test_pathfinder.py b/pymc_experimental/tests/test_pathfinder.py deleted file mode 100644 index 6d92e0dd..00000000 --- a/pymc_experimental/tests/test_pathfinder.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2022 The PyMC Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys - -import numpy as np -import pymc as pm -import pytest - -import pymc_experimental as pmx - - -@pytest.mark.skipif(sys.platform == "win32", reason="JAX not supported on windows.") -def test_pathfinder(): - # Data of the Eight Schools Model - J = 8 - y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) - sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) - - with pm.Model() as model: - - mu = pm.Normal("mu", mu=0.0, sigma=10.0) - tau = pm.HalfCauchy("tau", 5.0) - - theta = pm.Normal("theta", mu=0, sigma=1, shape=J) - theta_1 = mu + tau * theta - obs = pm.Normal("obs", mu=theta, sigma=sigma, shape=J, observed=y) - - idata = pmx.fit(method="pathfinder", iterations=100) - - assert idata is not None - assert "theta" in idata.posterior._variables.keys() - assert "tau" in idata.posterior._variables.keys() - assert "mu" in idata.posterior._variables.keys() - assert idata.posterior["mu"].shape == (1, 100) - assert idata.posterior["tau"].shape == (1, 100) - assert idata.posterior["theta"].shape == (1, 100, 8) diff --git a/requirements-dev.txt b/requirements-dev.txt index 4c58c77a..6c049906 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,2 +1 @@ dask[all] -blackjax