@@ -58,29 +58,35 @@ def save_trace(trace: MultiTrace, directory: Optional[str] = None, overwrite=Fal
58
58
warnings .warn (
59
59
"The `save_trace` function will soon be removed."
60
60
"Instead, use `arviz.to_netcdf` to save traces." ,
61
- DeprecationWarning ,
61
+ FutureWarning ,
62
62
)
63
63
64
- if directory is None :
65
- directory = ".pymc_{}.trace"
66
- idx = 1
67
- while os .path .exists (directory .format (idx )):
68
- idx += 1
69
- directory = directory .format (idx )
70
-
71
- if os .path .isdir (directory ):
72
- if overwrite :
73
- shutil .rmtree (directory )
74
- else :
75
- raise OSError (
76
- "Cautiously refusing to overwrite the already existing {}! Please supply "
77
- "a different directory, or set `overwrite=True`" .format (directory )
78
- )
79
- os .makedirs (directory )
80
-
81
- for chain , ndarray in trace ._straces .items ():
82
- SerializeNDArray (os .path .join (directory , str (chain ))).save (ndarray )
83
- return directory
64
+ if isinstance (trace , MultiTrace ):
65
+ if directory is None :
66
+ directory = ".pymc_{}.trace"
67
+ idx = 1
68
+ while os .path .exists (directory .format (idx )):
69
+ idx += 1
70
+ directory = directory .format (idx )
71
+
72
+ if os .path .isdir (directory ):
73
+ if overwrite :
74
+ shutil .rmtree (directory )
75
+ else :
76
+ raise OSError (
77
+ "Cautiously refusing to overwrite the already existing {}! Please supply "
78
+ "a different directory, or set `overwrite=True`" .format (directory )
79
+ )
80
+ os .makedirs (directory )
81
+
82
+ for chain , ndarray in trace ._straces .items ():
83
+ SerializeNDArray (os .path .join (directory , str (chain ))).save (ndarray )
84
+ return directory
85
+ else :
86
+ raise TypeError (
87
+ f"You are attempting to save an InferenceData object but this function "
88
+ "works only for MultiTrace objects. Use `arviz.to_netcdf` instead"
89
+ )
84
90
85
91
86
92
def load_trace (directory : str , model = None ) -> MultiTrace :
@@ -103,7 +109,7 @@ def load_trace(directory: str, model=None) -> MultiTrace:
103
109
warnings .warn (
104
110
"The `load_trace` function will soon be removed."
105
111
"Instead, use `arviz.from_netcdf` to load traces." ,
106
- DeprecationWarning ,
112
+ FutureWarning ,
107
113
)
108
114
straces = []
109
115
for subdir in glob .glob (os .path .join (directory , "*" )):
@@ -125,7 +131,7 @@ def __init__(self, directory: str):
125
131
warnings .warn (
126
132
"The `SerializeNDArray` class will soon be removed. "
127
133
"Instead, use ArviZ to save/load traces." ,
128
- DeprecationWarning ,
134
+ FutureWarning ,
129
135
)
130
136
self .directory = directory
131
137
self .metadata_path = os .path .join (self .directory , self .metadata_file )
0 commit comments