From 97e2f753fe59cb55c885877f466575397990e487 Mon Sep 17 00:00:00 2001 From: aphc14 <177544929+aphc14@users.noreply.github.com> Date: Sun, 30 Mar 2025 02:19:32 +1100 Subject: [PATCH 1/3] Minor fix of blackjax import in fit_pathfinder function * Moved the import statement for blackjax to ensure it is only imported when needed. * Moved blackjax import statement prevents import errors for users on Windows. * Updated the fit function to specify the return type as az.InferenceData. --- pymc_extras/inference/fit.py | 3 ++- pymc_extras/inference/pathfinder/pathfinder.py | 6 ++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pymc_extras/inference/fit.py b/pymc_extras/inference/fit.py index bb695113..60d89777 100644 --- a/pymc_extras/inference/fit.py +++ b/pymc_extras/inference/fit.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import arviz as az -def fit(method, **kwargs): +def fit(method: str, **kwargs) -> az.InferenceData: """ Fit a model with an inference algorithm diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py index dfe5fc6a..531efc56 100644 --- a/pymc_extras/inference/pathfinder/pathfinder.py +++ b/pymc_extras/inference/pathfinder/pathfinder.py @@ -21,11 +21,9 @@ from collections.abc import Callable, Iterator from dataclasses import asdict, dataclass, field, replace from enum import Enum, auto -from importlib.util import find_spec from typing import Literal, TypeAlias import arviz as az -import blackjax import filelock import jax import numpy as np @@ -1736,8 +1734,8 @@ def fit_pathfinder( ) pathfinder_samples = mp_result.samples elif inference_backend == "blackjax": - if find_spec("blackjax") is None: - raise RuntimeError("Need BlackJAX to use `pathfinder`") + import blackjax + if version.parse(blackjax.__version__).major < 1: raise ImportError("fit_pathfinder requires blackjax 1.0 or above") From ab39ab17d98334be60ad2875a972b4a6726ab124 Mon Sep 17 00:00:00 2001 From: aphc14 <177544929+aphc14@users.noreply.github.com> Date: Fri, 11 Apr 2025 00:22:22 +1000 Subject: [PATCH 2/3] Fix failed test from PR #443 and replaced batched_dot for einsum * Modified the calculation of in and to use instead of the deprecated function. * Added calculations for in to improve readability. * Renamed to in pathfinder functions to reflect API changes. * Updated warning filters to ignore UserWarnings related to einsum subscripts. --- .../inference/pathfinder/pathfinder.py | 25 +++++++++---------- tests/test_pathfinder.py | 4 ++- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py index 531efc56..dd99fff4 100644 --- a/pymc_extras/inference/pathfinder/pathfinder.py +++ b/pymc_extras/inference/pathfinder/pathfinder.py @@ -40,7 +40,7 @@ from pymc.model import modelcontext from pymc.model.core import Point from pymc.pytensorf import ( - compile_pymc, + compile, find_rng_nodes, reseed_rngs, ) @@ -77,7 +77,9 @@ logger = logging.getLogger(__name__) _warnings.filterwarnings( - "ignore", category=FutureWarning, message="compile_pymc was renamed to compile" + "ignore", + category=UserWarning, + message="The same einsum subscript is used for a broadcastable and non-broadcastable dimension", ) REGULARISATION_TERM = 1e-8 @@ -142,7 +144,7 @@ def get_logp_dlogp_of_ravel_inputs( [model.logp(jacobian=jacobian), model.dlogp(jacobian=jacobian)], model.value_vars, ) - logp_dlogp_fn = compile_pymc([inputs], (logP, dlogP), **compile_kwargs) + logp_dlogp_fn = compile([inputs], (logP, dlogP), **compile_kwargs) logp_dlogp_fn.trust_input = True return logp_dlogp_fn @@ -502,7 +504,7 @@ def bfgs_sample_dense( logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1) - mu = x - pt.batched_dot(H_inv, g) + mu = x - pt.einsum("ijk,ik->ij", H_inv, g) phi = pt.matrix_transpose( # (L, N, 1) @@ -571,15 +573,12 @@ def bfgs_sample_sparse( logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1) logdet += pt.sum(pt.log(alpha), axis=-1) + # inverse Hessian + # (L, N, N) + (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N) + H_inv = alpha_diag + (beta @ gamma @ pt.matrix_transpose(beta)) + # NOTE: changed the sign from "x + " to "x -" of the expression to match Stan which differs from Zhang et al., (2022). same for dense version. - mu = x - ( - # (L, N), (L, N) -> (L, N) - pt.batched_dot(alpha_diag, g) - # beta @ gamma @ beta.T - # (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N) - # (L, N, N), (L, N) -> (L, N) - + pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g) - ) + mu = x - pt.einsum("ijk,ik->ij", H_inv, g) phi = pt.matrix_transpose( # (L, N, 1) @@ -853,7 +852,7 @@ def make_pathfinder_body( # return psi, logP_psi, logQ_psi, elbo_argmax - pathfinder_body_fn = compile_pymc( + pathfinder_body_fn = compile( [x_full, g_full], [psi, logP_psi, logQ_psi, elbo_argmax], **compile_kwargs, diff --git a/tests/test_pathfinder.py b/tests/test_pathfinder.py index af9213ff..a2aa958e 100644 --- a/tests/test_pathfinder.py +++ b/tests/test_pathfinder.py @@ -18,7 +18,9 @@ import pymc as pm import pytest -pytestmark = pytest.mark.filterwarnings("ignore:compile_pymc was renamed to compile:FutureWarning") +pytestmark = pytest.mark.filterwarnings( + "ignore:The same einsum subscript is used for a broadcastable and non-broadcastable dimension:UserWarning" +) import pymc_extras as pmx From 530d74f8c2dbc66911f11f4cafba3f42eb239674 Mon Sep 17 00:00:00 2001 From: aphc14 <177544929+aphc14@users.noreply.github.com> Date: Sun, 13 Apr 2025 01:43:32 +1000 Subject: [PATCH 3/3] Fix Multiple destroyers error by replacing pt.einsum with pt.vectorize(pt.dot,...) * Fixed errors with deprecated einsum usage in bfgs_sample_dense and bfgs_sample_sparse functions by implementing pt.vectorize(pt.dot,...). * Updated test_pathfinder to filter out deprecation warnings related to JAXopt. --- pymc_extras/inference/pathfinder/pathfinder.py | 16 ++++++++-------- tests/test_pathfinder.py | 5 +---- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py index dd99fff4..3817f9fa 100644 --- a/pymc_extras/inference/pathfinder/pathfinder.py +++ b/pymc_extras/inference/pathfinder/pathfinder.py @@ -15,7 +15,6 @@ import collections import logging import time -import warnings as _warnings from collections import Counter from collections.abc import Callable, Iterator @@ -76,11 +75,6 @@ ) logger = logging.getLogger(__name__) -_warnings.filterwarnings( - "ignore", - category=UserWarning, - message="The same einsum subscript is used for a broadcastable and non-broadcastable dimension", -) REGULARISATION_TERM = 1e-8 DEFAULT_LINKER = "cvm_nogc" @@ -504,7 +498,10 @@ def bfgs_sample_dense( logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1) - mu = x - pt.einsum("ijk,ik->ij", H_inv, g) + # mu = x - pt.einsum("ijk,ik->ij", H_inv, g) # causes error: Multiple destroyers of g + + batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)") + mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None])) phi = pt.matrix_transpose( # (L, N, 1) @@ -578,7 +575,10 @@ def bfgs_sample_sparse( H_inv = alpha_diag + (beta @ gamma @ pt.matrix_transpose(beta)) # NOTE: changed the sign from "x + " to "x -" of the expression to match Stan which differs from Zhang et al., (2022). same for dense version. - mu = x - pt.einsum("ijk,ik->ij", H_inv, g) + # mu = x - pt.einsum("ijk,ik->ij", H_inv, g) # causes error: Multiple destroyers of g + + batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)") + mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None])) phi = pt.matrix_transpose( # (L, N, 1) diff --git a/tests/test_pathfinder.py b/tests/test_pathfinder.py index a2aa958e..b2f4b815 100644 --- a/tests/test_pathfinder.py +++ b/tests/test_pathfinder.py @@ -18,10 +18,6 @@ import pymc as pm import pytest -pytestmark = pytest.mark.filterwarnings( - "ignore:The same einsum subscript is used for a broadcastable and non-broadcastable dimension:UserWarning" -) - import pymc_extras as pmx @@ -55,6 +51,7 @@ def reference_idata(): @pytest.mark.parametrize("inference_backend", ["pymc", "blackjax"]) +@pytest.mark.filterwarnings("ignore:JAXopt is no longer maintained.:DeprecationWarning") def test_pathfinder(inference_backend, reference_idata): if inference_backend == "blackjax" and sys.platform == "win32": pytest.skip("JAX not supported on windows")