Skip to content

Commit cb6b864

Browse files
fixing second loading issue (#220)
* fixing second loading issue * adding ignores for warnings from numpy and arviz
1 parent 4051624 commit cb6b864

File tree

4 files changed

+21
-13
lines changed

4 files changed

+21
-13
lines changed

pymc_experimental/model_builder.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -381,14 +381,21 @@ def save(self, fname: str) -> None:
381381
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
382382

383383
@classmethod
384-
def _convert_dims_to_tuple(cls, model_config: Dict) -> Dict:
384+
def _model_config_formatting(cls, model_config: Dict) -> Dict:
385+
"""
386+
Because of json serialization, model_config values that were originally tuples or numpy are being encoded as lists.
387+
This function converts them back to tuples and numpy arrays to ensure correct id encoding.
388+
"""
385389
for key in model_config:
386-
if (
387-
isinstance(model_config[key], dict)
388-
and "dims" in model_config[key]
389-
and isinstance(model_config[key]["dims"], list)
390-
):
391-
model_config[key]["dims"] = tuple(model_config[key]["dims"])
390+
if isinstance(model_config[key], dict):
391+
for sub_key in model_config[key]:
392+
if isinstance(model_config[key][sub_key], list):
393+
# Check if "dims" key to convert it to tuple
394+
if sub_key == "dims":
395+
model_config[key][sub_key] = tuple(model_config[key][sub_key])
396+
# Convert all other lists to numpy arrays
397+
else:
398+
model_config[key][sub_key] = np.array(model_config[key][sub_key])
392399
return model_config
393400

394401
@classmethod
@@ -420,7 +427,7 @@ def load(cls, fname: str):
420427
filepath = Path(str(fname))
421428
idata = az.from_netcdf(filepath)
422429
# needs to be converted, because json.loads was changing tuple to list
423-
model_config = cls._convert_dims_to_tuple(json.loads(idata.attrs["model_config"]))
430+
model_config = cls._model_config_formatting(json.loads(idata.attrs["model_config"]))
424431
model = cls(
425432
model_config=model_config,
426433
sampler_config=json.loads(idata.attrs["sampler_config"]),

pymc_experimental/tests/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
15-
from numpy.testing import Tester
16-
17-
test = Tester().test

pymc_experimental/tests/test_model_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def test_convert_dims_to_tuple(fitted_model_instance):
160160
],
161161
},
162162
}
163-
converted_model_config = fitted_model_instance._convert_dims_to_tuple(model_config)
163+
converted_model_config = fitted_model_instance._model_config_formatting(model_config)
164164
assert converted_model_config["a"]["dims"] == ("x",)
165165

166166

setup.cfg

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
[tool:pytest]
22
testpaths = tests
3+
filterwarnings =
4+
error
5+
ignore::DeprecationWarning:numpy.core.fromnumeric
6+
ignore:::arviz.*
7+
ignore:DeprecationWarning
38

49
[isort]
510
lines_between_types = 1

0 commit comments

Comments
 (0)