diff --git a/pymc_experimental/model/marginal_model.py b/pymc_experimental/model/marginal_model.py index 1eb23ff2..1f8e4531 100644 --- a/pymc_experimental/model/marginal_model.py +++ b/pymc_experimental/model/marginal_model.py @@ -84,7 +84,7 @@ class MarginalModel(Model): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.marginalized_rvs = [] - self._marginalized_named_vars_to_dims = treedict() + self._marginalized_named_vars_to_dims = {} def _delete_rv_mappings(self, rv: TensorVariable) -> None: """Remove all model mappings referring to rv