From 26d4db013b0a9bed921289f396d2393f6d963153 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Wed, 7 Sep 2022 15:16:01 +0200 Subject: [PATCH 01/31] Add wrapper for running blackjax pathfinder. --- pymc_experimental/__init__.py | 2 +- pymc_experimental/inference/pathfinder.py | 90 ++++++++++++++++++++++ pymc_experimental/tests/test_pathfinder.py | 22 ++++++ 3 files changed, 113 insertions(+), 1 deletion(-) create mode 100644 pymc_experimental/inference/pathfinder.py create mode 100644 pymc_experimental/tests/test_pathfinder.py diff --git a/pymc_experimental/__init__.py b/pymc_experimental/__init__.py index 8946a3a1..74b88488 100644 --- a/pymc_experimental/__init__.py +++ b/pymc_experimental/__init__.py @@ -11,5 +11,5 @@ _log.addHandler(handler) -from pymc_experimental import distributions, gp, utils +from pymc_experimental import distributions, gp, utils, inference from pymc_experimental.bart import * diff --git a/pymc_experimental/inference/pathfinder.py b/pymc_experimental/inference/pathfinder.py new file mode 100644 index 00000000..d9622c68 --- /dev/null +++ b/pymc_experimental/inference/pathfinder.py @@ -0,0 +1,90 @@ +import sys +import collections + +import arviz as az +import numpy as np +import pymc as pm +from pymc import modelcontext +from pymc.util import get_default_varnames +from pymc.sampling_jax import get_jaxified_logp, get_jaxified_graph +import jax +import jax.numpy as jnp +import jax.random as random +import blackjax + +def convert_flat_trace_to_idata( + samples, dims=None, coords=None, include_transformed=False, postprocessing_backend="cpu", model=None, +): + model = modelcontext(model) + init_position_dict = model.initial_point() + trace = collections.defaultdict(list) + astart = pm.blocking.DictToArrayBijection.map(init_position_dict) + for sample in samples: + raveld_vars = pm.blocking.RaveledVars(sample, astart.point_map_info) + point = pm.blocking.DictToArrayBijection.rmap(raveld_vars, init_position_dict) + for p, v in point.items(): + trace[p].append(v.tolist()) + + print("Creating trace...", file=sys.stdout) + trace = {k: np.asarray(v)[None, ...] for k, v in trace.items()} + + var_names = model.unobserved_value_vars + vars_to_sample = list(get_default_varnames(var_names, include_transformed=include_transformed)) + print("Transforming variables...", file=sys.stdout) + jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample) + result = jax.vmap(jax.vmap(jax_fn))( + *jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0]) + ) + + trace = {v.name: r for v, r in zip(vars_to_sample, result)} + idata = az.from_dict(trace, dims=dims, coords=coords) + + return idata + + +def fit_pathfinder(iterations=5_000, model=None): + model = modelcontext(model) + + rvs = [rv.name for rv in model.value_vars] + init_position_dict = model.initial_point() + init_position = [init_position_dict[rv] for rv in rvs] + + new_logprob, new_input = pm.aesaraf.join_nonshared_inputs( + init_position_dict, (model.logp(),), model.value_vars, () + ) + + logprob_fn_list = get_jaxified_graph([new_input], new_logprob) + + def logprob_fn(x): + return logprob_fn_list(x)[0] + + dim = sum(v.size for v in init_position_dict.values()) + + rng_key = random.PRNGKey(314) + w0 = random.multivariate_normal(rng_key, 2.0 + jnp.zeros(dim), jnp.eye(dim)) + path = blackjax.vi.pathfinder.init(rng_key, logprob_fn, w0, return_path=True, ftol=1e-4) + + pathfinder = blackjax.kernels.pathfinder(rng_key, logprob_fn, ftol=1e-4) + state = pathfinder.init(w0) + + def inference_loop(rng_key, kernel, initial_state, num_samples): + @jax.jit + def one_step(state, rng_key): + state, info = kernel(rng_key, state) + return state, (state, info) + + keys = jax.random.split(rng_key, num_samples) + return jax.lax.scan(one_step, initial_state, keys) + + _, rng_key = random.split(rng_key) + print("Running pathfinder...", file=sys.stdout) + _, (_, samples) = inference_loop(rng_key, pathfinder.step, state, iterations) + + dims = { + var_name: [dim for dim in dims if dim is not None] + for var_name, dims in model.RV_dims.items() + } + + idata = convert_flat_trace_to_idata(samples, coords=model.coords, dims=dims) + + return idata diff --git a/pymc_experimental/tests/test_pathfinder.py b/pymc_experimental/tests/test_pathfinder.py new file mode 100644 index 00000000..a6f68b02 --- /dev/null +++ b/pymc_experimental/tests/test_pathfinder.py @@ -0,0 +1,22 @@ +import numpy as np +import pymc as pm +import pytest + +import pymc_experimental as pmx + +def test_pathfinder(): + # Data of the Eight Schools Model + J = 8 + y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) + sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) + + with pm.Model() as model: + + mu = pm.Normal("mu", mu=0.0, sigma=10.0) + tau = pm.HalfCauchy("tau", 5.0) + + theta = pm.Normal("theta", mu=0, sigma=1, shape=J) + theta_1 = mu + tau * theta + obs = pm.Normal("obs", mu=theta, sigma=sigma, shape=J, observed=y) + + idata = pmx.inference.fit_pathfinder() \ No newline at end of file From 9cc95ec2ce7df08b084364b37817a8b7e74ec331 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Wed, 7 Sep 2022 15:17:52 +0200 Subject: [PATCH 02/31] Run black. --- pymc_experimental/inference/pathfinder.py | 8 +++++++- pymc_experimental/tests/test_pathfinder.py | 3 ++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/pymc_experimental/inference/pathfinder.py b/pymc_experimental/inference/pathfinder.py index d9622c68..5fe7b55b 100644 --- a/pymc_experimental/inference/pathfinder.py +++ b/pymc_experimental/inference/pathfinder.py @@ -12,8 +12,14 @@ import jax.random as random import blackjax + def convert_flat_trace_to_idata( - samples, dims=None, coords=None, include_transformed=False, postprocessing_backend="cpu", model=None, + samples, + dims=None, + coords=None, + include_transformed=False, + postprocessing_backend="cpu", + model=None, ): model = modelcontext(model) init_position_dict = model.initial_point() diff --git a/pymc_experimental/tests/test_pathfinder.py b/pymc_experimental/tests/test_pathfinder.py index a6f68b02..c496931a 100644 --- a/pymc_experimental/tests/test_pathfinder.py +++ b/pymc_experimental/tests/test_pathfinder.py @@ -4,6 +4,7 @@ import pymc_experimental as pmx + def test_pathfinder(): # Data of the Eight Schools Model J = 8 @@ -19,4 +20,4 @@ def test_pathfinder(): theta_1 = mu + tau * theta obs = pm.Normal("obs", mu=theta, sigma=sigma, shape=J, observed=y) - idata = pmx.inference.fit_pathfinder() \ No newline at end of file + idata = pmx.inference.fit_pathfinder() From 1e8ad4a31d1f5d295e250f948ba4d88417599277 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Wed, 7 Sep 2022 15:26:08 +0200 Subject: [PATCH 03/31] Run precommit. --- pymc_experimental/__init__.py | 2 +- pymc_experimental/inference/pathfinder.py | 12 ++++++------ pymc_experimental/tests/test_pathfinder.py | 1 - 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/pymc_experimental/__init__.py b/pymc_experimental/__init__.py index 74b88488..fd786611 100644 --- a/pymc_experimental/__init__.py +++ b/pymc_experimental/__init__.py @@ -11,5 +11,5 @@ _log.addHandler(handler) -from pymc_experimental import distributions, gp, utils, inference +from pymc_experimental import distributions, gp, inference, utils from pymc_experimental.bart import * diff --git a/pymc_experimental/inference/pathfinder.py b/pymc_experimental/inference/pathfinder.py index 5fe7b55b..c74a512c 100644 --- a/pymc_experimental/inference/pathfinder.py +++ b/pymc_experimental/inference/pathfinder.py @@ -1,16 +1,16 @@ -import sys import collections +import sys import arviz as az +import blackjax +import jax +import jax.numpy as jnp +import jax.random as random import numpy as np import pymc as pm from pymc import modelcontext +from pymc.sampling_jax import get_jaxified_graph from pymc.util import get_default_varnames -from pymc.sampling_jax import get_jaxified_logp, get_jaxified_graph -import jax -import jax.numpy as jnp -import jax.random as random -import blackjax def convert_flat_trace_to_idata( diff --git a/pymc_experimental/tests/test_pathfinder.py b/pymc_experimental/tests/test_pathfinder.py index c496931a..77c5506f 100644 --- a/pymc_experimental/tests/test_pathfinder.py +++ b/pymc_experimental/tests/test_pathfinder.py @@ -1,6 +1,5 @@ import numpy as np import pymc as pm -import pytest import pymc_experimental as pmx From a4cf33977fe085eae81d4a88704a9f20a577246d Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Wed, 7 Sep 2022 15:29:16 +0200 Subject: [PATCH 04/31] Add blackjax to requirements. --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index c40987ae..8f607415 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ pymc>=4.0.1 xhistogram +blackjax From 43b6f8e950287aa366664f4eae6dd30de8bd68fb Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Wed, 7 Sep 2022 15:38:48 +0200 Subject: [PATCH 05/31] Do not make import optional. --- pymc_experimental/inference/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 pymc_experimental/inference/__init__.py diff --git a/pymc_experimental/inference/__init__.py b/pymc_experimental/inference/__init__.py new file mode 100644 index 00000000..16b3ead7 --- /dev/null +++ b/pymc_experimental/inference/__init__.py @@ -0,0 +1 @@ +from pymc_experimental.inference.pathfinder import fit_pathfinder From 3a3b2d7531f0e7b5ae966c4238c2b2d89b794f8b Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Thu, 8 Sep 2022 11:02:07 +0200 Subject: [PATCH 06/31] Add more kwargs. Add license. Improve tests. Add doc string. Add fit function. --- pymc_experimental/inference/__init__.py | 1 + pymc_experimental/inference/pathfinder.py | 50 ++++++++++++++++++++-- pymc_experimental/tests/test_pathfinder.py | 21 ++++++++- 3 files changed, 67 insertions(+), 5 deletions(-) diff --git a/pymc_experimental/inference/__init__.py b/pymc_experimental/inference/__init__.py index 16b3ead7..f54d9535 100644 --- a/pymc_experimental/inference/__init__.py +++ b/pymc_experimental/inference/__init__.py @@ -1 +1,2 @@ +from pymc_experimental.inference.fit import fit from pymc_experimental.inference.pathfinder import fit_pathfinder diff --git a/pymc_experimental/inference/pathfinder.py b/pymc_experimental/inference/pathfinder.py index c74a512c..8e788af8 100644 --- a/pymc_experimental/inference/pathfinder.py +++ b/pymc_experimental/inference/pathfinder.py @@ -1,3 +1,17 @@ +# Copyright 2022 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import collections import sys @@ -31,7 +45,6 @@ def convert_flat_trace_to_idata( for p, v in point.items(): trace[p].append(v.tolist()) - print("Creating trace...", file=sys.stdout) trace = {k: np.asarray(v)[None, ...] for k, v in trace.items()} var_names = model.unobserved_value_vars @@ -48,7 +61,34 @@ def convert_flat_trace_to_idata( return idata -def fit_pathfinder(iterations=5_000, model=None): +def fit_pathfinder( + iterations=5_000, random_seed=312, postprocessing_backend="cpu", ftol=1e-4, model=None +): + """ + Fit the pathfinder algorithm as implemented in blackjax + + Requires the JAX backend + + Parameters + ---------- + iterations : int + Number of iterations to run. + random_seed : int + Random seed to set. + postprocessing_backend : str + Where to compute transformations of the trace. + "cpu" or "gpu". + ftol : float + Floating point tolerance + + Returns + ------- + arviz.InferenceData + + Reference + --------- + https://arxiv.org/abs/2108.03782 + """ model = modelcontext(model) rvs = [rv.name for rv in model.value_vars] @@ -66,7 +106,7 @@ def logprob_fn(x): dim = sum(v.size for v in init_position_dict.values()) - rng_key = random.PRNGKey(314) + rng_key = random.PRNGKey(random_seed) w0 = random.multivariate_normal(rng_key, 2.0 + jnp.zeros(dim), jnp.eye(dim)) path = blackjax.vi.pathfinder.init(rng_key, logprob_fn, w0, return_path=True, ftol=1e-4) @@ -91,6 +131,8 @@ def one_step(state, rng_key): for var_name, dims in model.RV_dims.items() } - idata = convert_flat_trace_to_idata(samples, coords=model.coords, dims=dims) + idata = convert_flat_trace_to_idata( + samples, postprocessing_backend=postprocessing_backend, coords=model.coords, dims=dims + ) return idata diff --git a/pymc_experimental/tests/test_pathfinder.py b/pymc_experimental/tests/test_pathfinder.py index 77c5506f..a7bbb649 100644 --- a/pymc_experimental/tests/test_pathfinder.py +++ b/pymc_experimental/tests/test_pathfinder.py @@ -1,3 +1,17 @@ +# Copyright 2022 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import numpy as np import pymc as pm @@ -19,4 +33,9 @@ def test_pathfinder(): theta_1 = mu + tau * theta obs = pm.Normal("obs", mu=theta, sigma=sigma, shape=J, observed=y) - idata = pmx.inference.fit_pathfinder() + idata = pmx.inference.fit_pathfinder(iterations=100) + + assert idata is not None + assert "theta" in idata.posterior._variables.keys() + assert "tau" in idata.posterior._variables.keys() + assert "mu" in idata.posterior._variables.keys() From 74de0f9ee4bef8175a1b26d080ff3e6a7a937f73 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Thu, 8 Sep 2022 11:05:57 +0200 Subject: [PATCH 07/31] Add fit function to base namespace. --- pymc_experimental/bart/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc_experimental/bart/__init__.py b/pymc_experimental/bart/__init__.py index 4b9379fe..3b44a0e1 100644 --- a/pymc_experimental/bart/__init__.py +++ b/pymc_experimental/bart/__init__.py @@ -20,6 +20,7 @@ plot_variable_importance, predict, ) +from pymc_experimental.inference.fit import fit __all__ = ["BART", "PGBART"] From d4e9ab4d0ce82b108a3c7c43f0ca522d96bce2c4 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Thu, 8 Sep 2022 11:07:13 +0200 Subject: [PATCH 08/31] Update copyright year. --- pymc_experimental/bart/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/bart/__init__.py b/pymc_experimental/bart/__init__.py index 3b44a0e1..fe0b4c02 100644 --- a/pymc_experimental/bart/__init__.py +++ b/pymc_experimental/bart/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2020 The PyMC Developers +# Copyright 2022 The PyMC Developers # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 1bf3473e4f217cff38335218316617f7a07b6606 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Thu, 8 Sep 2022 15:39:56 +0200 Subject: [PATCH 09/31] Add type to random_seed and better init. Test for correct shapes. --- pymc_experimental/inference/__init__.py | 1 - pymc_experimental/inference/pathfinder.py | 10 +++++++++- pymc_experimental/tests/test_pathfinder.py | 3 +++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/pymc_experimental/inference/__init__.py b/pymc_experimental/inference/__init__.py index f54d9535..16b3ead7 100644 --- a/pymc_experimental/inference/__init__.py +++ b/pymc_experimental/inference/__init__.py @@ -1,2 +1 @@ -from pymc_experimental.inference.fit import fit from pymc_experimental.inference.pathfinder import fit_pathfinder diff --git a/pymc_experimental/inference/pathfinder.py b/pymc_experimental/inference/pathfinder.py index 8e788af8..678847aa 100644 --- a/pymc_experimental/inference/pathfinder.py +++ b/pymc_experimental/inference/pathfinder.py @@ -14,6 +14,7 @@ import collections import sys +from typing import Optional import arviz as az import blackjax @@ -23,6 +24,7 @@ import numpy as np import pymc as pm from pymc import modelcontext +from pymc.sampling import RandomSeed, _get_seeds_per_chain from pymc.sampling_jax import get_jaxified_graph from pymc.util import get_default_varnames @@ -62,7 +64,11 @@ def convert_flat_trace_to_idata( def fit_pathfinder( - iterations=5_000, random_seed=312, postprocessing_backend="cpu", ftol=1e-4, model=None + iterations=5_000, + random_seed: Optional[RandomSeed] = None, + postprocessing_backend="cpu", + ftol=1e-4, + model=None, ): """ Fit the pathfinder algorithm as implemented in blackjax @@ -89,6 +95,8 @@ def fit_pathfinder( --------- https://arxiv.org/abs/2108.03782 """ + (random_seed,) = _get_seeds_per_chain(random_seed, 1) + model = modelcontext(model) rvs = [rv.name for rv in model.value_vars] diff --git a/pymc_experimental/tests/test_pathfinder.py b/pymc_experimental/tests/test_pathfinder.py index a7bbb649..7e7f13c5 100644 --- a/pymc_experimental/tests/test_pathfinder.py +++ b/pymc_experimental/tests/test_pathfinder.py @@ -39,3 +39,6 @@ def test_pathfinder(): assert "theta" in idata.posterior._variables.keys() assert "tau" in idata.posterior._variables.keys() assert "mu" in idata.posterior._variables.keys() + assert idata.posterior["mu"].shape == (1, 100) + assert idata.posterior["tau"].shape == (1, 100) + assert idata.posterior["theta"].shape == (1, 100, 8) From 79e89df861e60bb3d07ca87a82ee04f2bc4ed369 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Thu, 8 Sep 2022 17:40:04 +0200 Subject: [PATCH 10/31] Update pymc_experimental/inference/pathfinder.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc_experimental/inference/pathfinder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/inference/pathfinder.py b/pymc_experimental/inference/pathfinder.py index 678847aa..bba6535e 100644 --- a/pymc_experimental/inference/pathfinder.py +++ b/pymc_experimental/inference/pathfinder.py @@ -116,7 +116,7 @@ def logprob_fn(x): rng_key = random.PRNGKey(random_seed) w0 = random.multivariate_normal(rng_key, 2.0 + jnp.zeros(dim), jnp.eye(dim)) - path = blackjax.vi.pathfinder.init(rng_key, logprob_fn, w0, return_path=True, ftol=1e-4) + path = blackjax.vi.pathfinder.init(rng_key, logprob_fn, w0, return_path=True, ftol=ftol) pathfinder = blackjax.kernels.pathfinder(rng_key, logprob_fn, ftol=1e-4) state = pathfinder.init(w0) From 03406cc179366a0b2cdc6e654310726053f4228e Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Thu, 8 Sep 2022 17:40:13 +0200 Subject: [PATCH 11/31] Update pymc_experimental/inference/pathfinder.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc_experimental/inference/pathfinder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/inference/pathfinder.py b/pymc_experimental/inference/pathfinder.py index bba6535e..323dc6c4 100644 --- a/pymc_experimental/inference/pathfinder.py +++ b/pymc_experimental/inference/pathfinder.py @@ -118,7 +118,7 @@ def logprob_fn(x): w0 = random.multivariate_normal(rng_key, 2.0 + jnp.zeros(dim), jnp.eye(dim)) path = blackjax.vi.pathfinder.init(rng_key, logprob_fn, w0, return_path=True, ftol=ftol) - pathfinder = blackjax.kernels.pathfinder(rng_key, logprob_fn, ftol=1e-4) + pathfinder = blackjax.kernels.pathfinder(rng_key, logprob_fn, ftol=ftol) state = pathfinder.init(w0) def inference_loop(rng_key, kernel, initial_state, num_samples): From 4f0dc4e7372a6b443a3665c61a7bc47665412742 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Thu, 8 Sep 2022 18:00:20 +0200 Subject: [PATCH 12/31] Fix import of fit function. --- pymc_experimental/__init__.py | 1 + pymc_experimental/bart/__init__.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/__init__.py b/pymc_experimental/__init__.py index fd786611..882ccda5 100644 --- a/pymc_experimental/__init__.py +++ b/pymc_experimental/__init__.py @@ -13,3 +13,4 @@ from pymc_experimental import distributions, gp, inference, utils from pymc_experimental.bart import * +from pymc_experimental.inference.fit import fit diff --git a/pymc_experimental/bart/__init__.py b/pymc_experimental/bart/__init__.py index fe0b4c02..33127b7d 100644 --- a/pymc_experimental/bart/__init__.py +++ b/pymc_experimental/bart/__init__.py @@ -20,7 +20,6 @@ plot_variable_importance, predict, ) -from pymc_experimental.inference.fit import fit __all__ = ["BART", "PGBART"] From 2cfeed9d7167538754a3d2b86e061f1eada9f892 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Thu, 8 Sep 2022 18:01:03 +0200 Subject: [PATCH 13/31] Add fit.py. --- pymc_experimental/inference/fit.py | 35 ++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 pymc_experimental/inference/fit.py diff --git a/pymc_experimental/inference/fit.py b/pymc_experimental/inference/fit.py new file mode 100644 index 00000000..709c41a1 --- /dev/null +++ b/pymc_experimental/inference/fit.py @@ -0,0 +1,35 @@ +# Copyright 2022 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pymc_experimental.inference import fit_pathfinder + + +def fit(method, *kwargs): + """ + Fit a model with an inference algorithm + + Parameters + ---------- + method : str + Which inference method to run. + Supported: pathfinder + + kwargs are passed on. + + Returns + ------- + arviz.InferenceData + """ + if method == "pathfinder": + return fit_pathfinder(**kwargs) From 746d1d160baa290e951d47b0d22d72acbf63ada7 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Thu, 8 Sep 2022 22:27:16 +0200 Subject: [PATCH 14/31] Skip on windows. --- pymc_experimental/tests/test_pathfinder.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pymc_experimental/tests/test_pathfinder.py b/pymc_experimental/tests/test_pathfinder.py index 7e7f13c5..a4af160e 100644 --- a/pymc_experimental/tests/test_pathfinder.py +++ b/pymc_experimental/tests/test_pathfinder.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys + import numpy as np import pymc as pm +import pytest import pymc_experimental as pmx +@pytest.skip(sys.platform == "win32") def test_pathfinder(): # Data of the Eight Schools Model J = 8 From 75e13589aebc3b9a40264dd60a1205b5a22ee5aa Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Thu, 8 Sep 2022 22:38:33 +0200 Subject: [PATCH 15/31] Skip on windows. Move imports inside so that we do not error on windows. --- pymc_experimental/inference/pathfinder.py | 11 +++++++---- pymc_experimental/tests/test_pathfinder.py | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pymc_experimental/inference/pathfinder.py b/pymc_experimental/inference/pathfinder.py index 323dc6c4..3e94aada 100644 --- a/pymc_experimental/inference/pathfinder.py +++ b/pymc_experimental/inference/pathfinder.py @@ -17,10 +17,6 @@ from typing import Optional import arviz as az -import blackjax -import jax -import jax.numpy as jnp -import jax.random as random import numpy as np import pymc as pm from pymc import modelcontext @@ -37,6 +33,8 @@ def convert_flat_trace_to_idata( postprocessing_backend="cpu", model=None, ): + import jax + model = modelcontext(model) init_position_dict = model.initial_point() trace = collections.defaultdict(list) @@ -95,6 +93,11 @@ def fit_pathfinder( --------- https://arxiv.org/abs/2108.03782 """ + import blackjax + import jax + import jax.numpy as jnp + import jax.random as random + (random_seed,) = _get_seeds_per_chain(random_seed, 1) model = modelcontext(model) diff --git a/pymc_experimental/tests/test_pathfinder.py b/pymc_experimental/tests/test_pathfinder.py index a4af160e..334a99fb 100644 --- a/pymc_experimental/tests/test_pathfinder.py +++ b/pymc_experimental/tests/test_pathfinder.py @@ -21,7 +21,7 @@ import pymc_experimental as pmx -@pytest.skip(sys.platform == "win32") +@pytest.mark.xfail(sys.platform == "win32", reason="JAX not supported on windows.") def test_pathfinder(): # Data of the Eight Schools Model J = 8 From ca2001d514f528c4430f19b950418526e3ea97d6 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 9 Sep 2022 12:51:51 +0200 Subject: [PATCH 16/31] skipif instead of xfail. --- pymc_experimental/tests/test_pathfinder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/tests/test_pathfinder.py b/pymc_experimental/tests/test_pathfinder.py index 334a99fb..350e8a81 100644 --- a/pymc_experimental/tests/test_pathfinder.py +++ b/pymc_experimental/tests/test_pathfinder.py @@ -21,7 +21,7 @@ import pymc_experimental as pmx -@pytest.mark.xfail(sys.platform == "win32", reason="JAX not supported on windows.") +@pytest.mark.skipif(sys.platform == "win32", reason="JAX not supported on windows.") def test_pathfinder(): # Data of the Eight Schools Model J = 8 From b6d284315da1a887201b916bc12fbd4f64e8f285 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 9 Sep 2022 15:07:29 +0200 Subject: [PATCH 17/31] try/except blackjax import. --- pymc_experimental/inference/pathfinder.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/pymc_experimental/inference/pathfinder.py b/pymc_experimental/inference/pathfinder.py index 3e94aada..964dc45d 100644 --- a/pymc_experimental/inference/pathfinder.py +++ b/pymc_experimental/inference/pathfinder.py @@ -12,6 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings + +try: + import blackjax + import jax + import jax.numpy as jnp + import jax.random as random +except ImportError: + warnings.warn("Can't import blackjax. Pathfinder will not be available.") + import collections import sys from typing import Optional @@ -33,7 +43,6 @@ def convert_flat_trace_to_idata( postprocessing_backend="cpu", model=None, ): - import jax model = modelcontext(model) init_position_dict = model.initial_point() @@ -93,10 +102,6 @@ def fit_pathfinder( --------- https://arxiv.org/abs/2108.03782 """ - import blackjax - import jax - import jax.numpy as jnp - import jax.random as random (random_seed,) = _get_seeds_per_chain(random_seed, 1) From 73f9e5cf4e0e3ff167f8ef781871b3b38e673223 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 9 Sep 2022 15:21:24 +0200 Subject: [PATCH 18/31] try/except blackjax import. --- pymc_experimental/inference/pathfinder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/inference/pathfinder.py b/pymc_experimental/inference/pathfinder.py index 964dc45d..1d91a2a2 100644 --- a/pymc_experimental/inference/pathfinder.py +++ b/pymc_experimental/inference/pathfinder.py @@ -19,6 +19,7 @@ import jax import jax.numpy as jnp import jax.random as random + from pymc.sampling_jax import get_jaxified_graph except ImportError: warnings.warn("Can't import blackjax. Pathfinder will not be available.") @@ -31,7 +32,6 @@ import pymc as pm from pymc import modelcontext from pymc.sampling import RandomSeed, _get_seeds_per_chain -from pymc.sampling_jax import get_jaxified_graph from pymc.util import get_default_varnames From cf3d0de961bea3a67c884bc345f1b2b09b2ad6ce Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 9 Sep 2022 16:20:06 +0200 Subject: [PATCH 19/31] Update pymc_experimental/inference/fit.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc_experimental/inference/fit.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pymc_experimental/inference/fit.py b/pymc_experimental/inference/fit.py index 709c41a1..963cc690 100644 --- a/pymc_experimental/inference/fit.py +++ b/pymc_experimental/inference/fit.py @@ -32,4 +32,8 @@ def fit(method, *kwargs): arviz.InferenceData """ if method == "pathfinder": + try: + from pymc_experimental.inference import fit_pathfinder + except ImportError as exc: + raise RuntimeError("Need JAX/ Blackjax / wahever to use `pathfinder`") from exc return fit_pathfinder(**kwargs) From 6deca7b3965d58df3eb68538517ed6184cc6d590 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 9 Sep 2022 16:20:13 +0200 Subject: [PATCH 20/31] Update pymc_experimental/inference/fit.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc_experimental/inference/fit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc_experimental/inference/fit.py b/pymc_experimental/inference/fit.py index 963cc690..b3a08284 100644 --- a/pymc_experimental/inference/fit.py +++ b/pymc_experimental/inference/fit.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pymc_experimental.inference import fit_pathfinder def fit(method, *kwargs): From 0bbfb560adcc9e52d05241960a5d05cf1ca2e7a6 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 9 Sep 2022 16:21:03 +0200 Subject: [PATCH 21/31] Move blackjax to dev reqs. --- requirements-dev.txt | 1 + requirements.txt | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 6c049906..4c58c77a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1 +1,2 @@ dask[all] +blackjax diff --git a/requirements.txt b/requirements.txt index 8f607415..c40987ae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,2 @@ pymc>=4.0.1 xhistogram -blackjax From 7f7fcb3095883c0c3dce33d30ace4accf5ff1be0 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 9 Sep 2022 16:22:54 +0200 Subject: [PATCH 22/31] Make import non-optional. --- pymc_experimental/inference/pathfinder.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/pymc_experimental/inference/pathfinder.py b/pymc_experimental/inference/pathfinder.py index 1d91a2a2..df2c616a 100644 --- a/pymc_experimental/inference/pathfinder.py +++ b/pymc_experimental/inference/pathfinder.py @@ -12,26 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings - -try: - import blackjax - import jax - import jax.numpy as jnp - import jax.random as random - from pymc.sampling_jax import get_jaxified_graph -except ImportError: - warnings.warn("Can't import blackjax. Pathfinder will not be available.") import collections import sys from typing import Optional import arviz as az +import blackjax +import jax +import jax.numpy as jnp +import jax.random as random import numpy as np import pymc as pm from pymc import modelcontext from pymc.sampling import RandomSeed, _get_seeds_per_chain +from pymc.sampling_jax import get_jaxified_graph from pymc.util import get_default_varnames From e2dad01dbd4f359c5ab9a57c5bdc0219c13a1f32 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 9 Sep 2022 16:25:41 +0200 Subject: [PATCH 23/31] Precommit. --- pymc_experimental/inference/fit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc_experimental/inference/fit.py b/pymc_experimental/inference/fit.py index b3a08284..9053fab4 100644 --- a/pymc_experimental/inference/fit.py +++ b/pymc_experimental/inference/fit.py @@ -13,7 +13,6 @@ # limitations under the License. - def fit(method, *kwargs): """ Fit a model with an inference algorithm From 4fc5e89e8c8164ffe7fd0ea422ccb3c4871e791a Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 9 Sep 2022 16:30:19 +0200 Subject: [PATCH 24/31] Change imports. --- pymc_experimental/inference/__init__.py | 2 +- pymc_experimental/inference/fit.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc_experimental/inference/__init__.py b/pymc_experimental/inference/__init__.py index 16b3ead7..e504b5cb 100644 --- a/pymc_experimental/inference/__init__.py +++ b/pymc_experimental/inference/__init__.py @@ -1 +1 @@ -from pymc_experimental.inference.pathfinder import fit_pathfinder +from pymc_experimental.inference import fit diff --git a/pymc_experimental/inference/fit.py b/pymc_experimental/inference/fit.py index 9053fab4..34b886b1 100644 --- a/pymc_experimental/inference/fit.py +++ b/pymc_experimental/inference/fit.py @@ -31,7 +31,7 @@ def fit(method, *kwargs): """ if method == "pathfinder": try: - from pymc_experimental.inference import fit_pathfinder + from pymc_experimental.inference.pathfinder import fit_pathfinder except ImportError as exc: - raise RuntimeError("Need JAX/ Blackjax / wahever to use `pathfinder`") from exc + raise RuntimeError("Need BlackJAX to use `pathfinder`") from exc return fit_pathfinder(**kwargs) From 7bdb10a4efbc7f2d6202a80abd47ce338d6789da Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 9 Sep 2022 16:31:27 +0200 Subject: [PATCH 25/31] Call fit() from test. --- pymc_experimental/tests/test_pathfinder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/tests/test_pathfinder.py b/pymc_experimental/tests/test_pathfinder.py index 350e8a81..63bece11 100644 --- a/pymc_experimental/tests/test_pathfinder.py +++ b/pymc_experimental/tests/test_pathfinder.py @@ -37,7 +37,7 @@ def test_pathfinder(): theta_1 = mu + tau * theta obs = pm.Normal("obs", mu=theta, sigma=sigma, shape=J, observed=y) - idata = pmx.inference.fit_pathfinder(iterations=100) + idata = pmx.inference.fit(method="pathfinder", iterations=100) assert idata is not None assert "theta" in idata.posterior._variables.keys() From 8f07c25f7742c1f123b5770079bf742d03a678a6 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 9 Sep 2022 16:38:47 +0200 Subject: [PATCH 26/31] Fix fit import. --- pymc_experimental/inference/__init__.py | 2 +- pymc_experimental/tests/test_pathfinder.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc_experimental/inference/__init__.py b/pymc_experimental/inference/__init__.py index e504b5cb..2339ccd4 100644 --- a/pymc_experimental/inference/__init__.py +++ b/pymc_experimental/inference/__init__.py @@ -1 +1 @@ -from pymc_experimental.inference import fit +from pymc_experimental.inference.fit import fit diff --git a/pymc_experimental/tests/test_pathfinder.py b/pymc_experimental/tests/test_pathfinder.py index 63bece11..6d92e0dd 100644 --- a/pymc_experimental/tests/test_pathfinder.py +++ b/pymc_experimental/tests/test_pathfinder.py @@ -37,7 +37,7 @@ def test_pathfinder(): theta_1 = mu + tau * theta obs = pm.Normal("obs", mu=theta, sigma=sigma, shape=J, observed=y) - idata = pmx.inference.fit(method="pathfinder", iterations=100) + idata = pmx.fit(method="pathfinder", iterations=100) assert idata is not None assert "theta" in idata.posterior._variables.keys() From 350b77d9e20257bb527b895d45028891d5fe4bbe Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 9 Sep 2022 16:48:08 +0200 Subject: [PATCH 27/31] Fix kwargs. --- pymc_experimental/inference/fit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_experimental/inference/fit.py b/pymc_experimental/inference/fit.py index 34b886b1..6b9835c0 100644 --- a/pymc_experimental/inference/fit.py +++ b/pymc_experimental/inference/fit.py @@ -13,7 +13,7 @@ # limitations under the License. -def fit(method, *kwargs): +def fit(method, **kwargs): """ Fit a model with an inference algorithm From 7734b03f8dd9c5a35742273052d9fdfabace48c0 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 9 Sep 2022 16:57:40 +0200 Subject: [PATCH 28/31] Add blackjax to test env. --- conda-envs/environment-test-py38.yml | 1 + conda-envs/environment-test-py39.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/conda-envs/environment-test-py38.yml b/conda-envs/environment-test-py38.yml index 0ecbe95d..4567e0b8 100644 --- a/conda-envs/environment-test-py38.yml +++ b/conda-envs/environment-test-py38.yml @@ -11,3 +11,4 @@ dependencies: - xhistogram - pip: - "git+https://github.com/pymc-devs/pymc.git@main" + - blackjax diff --git a/conda-envs/environment-test-py39.yml b/conda-envs/environment-test-py39.yml index 570af385..19a8adc4 100644 --- a/conda-envs/environment-test-py39.yml +++ b/conda-envs/environment-test-py39.yml @@ -11,3 +11,4 @@ dependencies: - xhistogram - pip: - "git+https://github.com/pymc-devs/pymc.git@main" + - blackjax From 625caf03100e69f2081589147e117276415f0cc6 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 9 Sep 2022 17:17:38 +0200 Subject: [PATCH 29/31] Only look for tests in test subdir. --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e218b506..05743f8f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -68,7 +68,7 @@ jobs: - name: Run tests run: | conda activate pymc-test-py38 - python -m pytest -vv --cov=pymc_experimental --cov-append --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET + python -m pytest -vv --cov=pymc_experimental/tests --cov-append --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET - name: Upload coverage to Codecov uses: codecov/codecov-action@v2 with: From 2d0a43556601c19e40134511780938e6ad18339d Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 9 Sep 2022 17:23:12 +0200 Subject: [PATCH 30/31] Only look for tests in test subdir. --- .github/workflows/test.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 05743f8f..992e245c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -19,6 +19,7 @@ jobs: matrix: os: [ubuntu-18.04] floatx: [float32, float64] + test-subset: pymc-experimental/tests fail-fast: false runs-on: ${{ matrix.os }} env: @@ -68,7 +69,7 @@ jobs: - name: Run tests run: | conda activate pymc-test-py38 - python -m pytest -vv --cov=pymc_experimental/tests --cov-append --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET + python -m pytest -vv --cov=pymc_experimental --cov-append --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET - name: Upload coverage to Codecov uses: codecov/codecov-action@v2 with: From 0dca76c47e6328d6233fb3708a6e3c9006faabf2 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Fri, 9 Sep 2022 17:24:01 +0200 Subject: [PATCH 31/31] Only look for tests in test subdir. --- .github/workflows/test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 992e245c..a179e59f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -81,6 +81,7 @@ jobs: matrix: os: [windows-latest] floatx: [float32, float64] + test-subset: pymc-experimental/tests fail-fast: false runs-on: ${{ matrix.os }} env: