Skip to content

Commit 2b92a05

Browse files
aloctavodiatwiecki
authored andcommitted
fix deprecation messages
1 parent 9e7e8aa commit 2b92a05

File tree

1 file changed

+29
-23
lines changed

1 file changed

+29
-23
lines changed

pymc/backends/ndarray.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -58,29 +58,35 @@ def save_trace(trace: MultiTrace, directory: Optional[str] = None, overwrite=Fal
5858
warnings.warn(
5959
"The `save_trace` function will soon be removed."
6060
"Instead, use `arviz.to_netcdf` to save traces.",
61-
DeprecationWarning,
61+
FutureWarning,
6262
)
6363

64-
if directory is None:
65-
directory = ".pymc_{}.trace"
66-
idx = 1
67-
while os.path.exists(directory.format(idx)):
68-
idx += 1
69-
directory = directory.format(idx)
70-
71-
if os.path.isdir(directory):
72-
if overwrite:
73-
shutil.rmtree(directory)
74-
else:
75-
raise OSError(
76-
"Cautiously refusing to overwrite the already existing {}! Please supply "
77-
"a different directory, or set `overwrite=True`".format(directory)
78-
)
79-
os.makedirs(directory)
80-
81-
for chain, ndarray in trace._straces.items():
82-
SerializeNDArray(os.path.join(directory, str(chain))).save(ndarray)
83-
return directory
64+
if isinstance(trace, MultiTrace):
65+
if directory is None:
66+
directory = ".pymc_{}.trace"
67+
idx = 1
68+
while os.path.exists(directory.format(idx)):
69+
idx += 1
70+
directory = directory.format(idx)
71+
72+
if os.path.isdir(directory):
73+
if overwrite:
74+
shutil.rmtree(directory)
75+
else:
76+
raise OSError(
77+
"Cautiously refusing to overwrite the already existing {}! Please supply "
78+
"a different directory, or set `overwrite=True`".format(directory)
79+
)
80+
os.makedirs(directory)
81+
82+
for chain, ndarray in trace._straces.items():
83+
SerializeNDArray(os.path.join(directory, str(chain))).save(ndarray)
84+
return directory
85+
else:
86+
raise TypeError(
87+
f"You are attempting to save an InferenceData object but this function "
88+
"works only for MultiTrace objects. Use `arviz.to_netcdf` instead"
89+
)
8490

8591

8692
def load_trace(directory: str, model=None) -> MultiTrace:
@@ -103,7 +109,7 @@ def load_trace(directory: str, model=None) -> MultiTrace:
103109
warnings.warn(
104110
"The `load_trace` function will soon be removed."
105111
"Instead, use `arviz.from_netcdf` to load traces.",
106-
DeprecationWarning,
112+
FutureWarning,
107113
)
108114
straces = []
109115
for subdir in glob.glob(os.path.join(directory, "*")):
@@ -125,7 +131,7 @@ def __init__(self, directory: str):
125131
warnings.warn(
126132
"The `SerializeNDArray` class will soon be removed. "
127133
"Instead, use ArviZ to save/load traces.",
128-
DeprecationWarning,
134+
FutureWarning,
129135
)
130136
self.directory = directory
131137
self.metadata_path = os.path.join(self.directory, self.metadata_file)

0 commit comments

Comments
 (0)