diff --git a/pymc_experimental/model_builder.py b/pymc_experimental/model_builder.py index b81e7b461..992051fc9 100644 --- a/pymc_experimental/model_builder.py +++ b/pymc_experimental/model_builder.py @@ -15,6 +15,7 @@ import hashlib import json +import warnings from abc import abstractmethod from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -496,7 +497,13 @@ def fit( X_df = pd.DataFrame(X, columns=X.columns) combined_data = pd.concat([X_df, y], axis=1) assert all(combined_data.columns), "All columns must have non-empty names" - self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message="The group fit_data is not defined in the InferenceData scheme", + ) + self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore return self.idata # type: ignore