From 97e2f753fe59cb55c885877f466575397990e487 Mon Sep 17 00:00:00 2001 From: aphc14 <177544929+aphc14@users.noreply.github.com> Date: Sun, 30 Mar 2025 02:19:32 +1100 Subject: [PATCH] Minor fix of blackjax import in fit_pathfinder function * Moved the import statement for blackjax to ensure it is only imported when needed. * Moved blackjax import statement prevents import errors for users on Windows. * Updated the fit function to specify the return type as az.InferenceData. --- pymc_extras/inference/fit.py | 3 ++- pymc_extras/inference/pathfinder/pathfinder.py | 6 ++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pymc_extras/inference/fit.py b/pymc_extras/inference/fit.py index bb695113..60d89777 100644 --- a/pymc_extras/inference/fit.py +++ b/pymc_extras/inference/fit.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import arviz as az -def fit(method, **kwargs): +def fit(method: str, **kwargs) -> az.InferenceData: """ Fit a model with an inference algorithm diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py index dfe5fc6a..531efc56 100644 --- a/pymc_extras/inference/pathfinder/pathfinder.py +++ b/pymc_extras/inference/pathfinder/pathfinder.py @@ -21,11 +21,9 @@ from collections.abc import Callable, Iterator from dataclasses import asdict, dataclass, field, replace from enum import Enum, auto -from importlib.util import find_spec from typing import Literal, TypeAlias import arviz as az -import blackjax import filelock import jax import numpy as np @@ -1736,8 +1734,8 @@ def fit_pathfinder( ) pathfinder_samples = mp_result.samples elif inference_backend == "blackjax": - if find_spec("blackjax") is None: - raise RuntimeError("Need BlackJAX to use `pathfinder`") + import blackjax + if version.parse(blackjax.__version__).major < 1: raise ImportError("fit_pathfinder requires blackjax 1.0 or above")