diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py index cba07fdf1c..fe76825471 100644 --- a/pymc/backends/ndarray.py +++ b/pymc/backends/ndarray.py @@ -58,29 +58,35 @@ def save_trace(trace: MultiTrace, directory: Optional[str] = None, overwrite=Fal warnings.warn( "The `save_trace` function will soon be removed." "Instead, use `arviz.to_netcdf` to save traces.", - DeprecationWarning, + FutureWarning, ) - if directory is None: - directory = ".pymc_{}.trace" - idx = 1 - while os.path.exists(directory.format(idx)): - idx += 1 - directory = directory.format(idx) - - if os.path.isdir(directory): - if overwrite: - shutil.rmtree(directory) - else: - raise OSError( - "Cautiously refusing to overwrite the already existing {}! Please supply " - "a different directory, or set `overwrite=True`".format(directory) - ) - os.makedirs(directory) - - for chain, ndarray in trace._straces.items(): - SerializeNDArray(os.path.join(directory, str(chain))).save(ndarray) - return directory + if isinstance(trace, MultiTrace): + if directory is None: + directory = ".pymc_{}.trace" + idx = 1 + while os.path.exists(directory.format(idx)): + idx += 1 + directory = directory.format(idx) + + if os.path.isdir(directory): + if overwrite: + shutil.rmtree(directory) + else: + raise OSError( + "Cautiously refusing to overwrite the already existing {}! Please supply " + "a different directory, or set `overwrite=True`".format(directory) + ) + os.makedirs(directory) + + for chain, ndarray in trace._straces.items(): + SerializeNDArray(os.path.join(directory, str(chain))).save(ndarray) + return directory + else: + raise TypeError( + f"You are attempting to save an InferenceData object but this function " + "works only for MultiTrace objects. Use `arviz.to_netcdf` instead" + ) def load_trace(directory: str, model=None) -> MultiTrace: @@ -103,7 +109,7 @@ def load_trace(directory: str, model=None) -> MultiTrace: warnings.warn( "The `load_trace` function will soon be removed." "Instead, use `arviz.from_netcdf` to load traces.", - DeprecationWarning, + FutureWarning, ) straces = [] for subdir in glob.glob(os.path.join(directory, "*")): @@ -125,7 +131,7 @@ def __init__(self, directory: str): warnings.warn( "The `SerializeNDArray` class will soon be removed. " "Instead, use ArviZ to save/load traces.", - DeprecationWarning, + FutureWarning, ) self.directory = directory self.metadata_path = os.path.join(self.directory, self.metadata_file)