Skip to content

Commit f8d0023

Browse files
danhphanmichaelosthege
authored andcommitted
refactor constant_data_to_xarray with find_constants
1 parent 4a2bf0e commit f8d0023

File tree

2 files changed

+32
-45
lines changed

2 files changed

+32
-45
lines changed

pymc/backends/arviz.py

Lines changed: 31 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@
4343
Var = Any # pylint: disable=invalid-name
4444

4545

46-
def find_observations(model: Optional["Model"]) -> Optional[Dict[str, Var]]:
46+
def find_observations(model: Optional["Model"]) -> Dict[str, Var]:
4747
"""If there are observations available, return them as a dictionary."""
4848
if model is None:
49-
return None
49+
return {}
5050

5151
observations = {}
5252
for obs in model.observed_RVs:
@@ -63,20 +63,37 @@ def find_observations(model: Optional["Model"]) -> Optional[Dict[str, Var]]:
6363
return observations
6464

6565

66-
def find_constants(model: Optional["Model"]) -> Optional[Dict[str, Var]]:
66+
def find_constants(model: Optional["Model"]) -> Dict[str, Var]:
6767
"""If there are constants available, return them as a dictionary."""
68-
if model is None or not model.named_vars:
69-
return None
68+
# The constant data vars must be either pm.Data or TensorConstant or SharedVariable
69+
if model is None:
70+
return {}
71+
72+
def is_data(name, var, model) -> bool:
73+
observations = find_observations(model)
74+
return (
75+
var not in model.deterministics
76+
and var not in model.observed_RVs
77+
and var not in model.free_RVs
78+
and var not in model.potentials
79+
and var not in model.value_vars
80+
and name not in observations
81+
and isinstance(var, (Constant, SharedVariable))
82+
)
7083

71-
constants = {}
84+
# The assumption is that constants (like pm.Data) are named
85+
# variables that aren't observed or free RVs, nor are they
86+
# deterministics, and then we eliminate observations.
87+
constant_data = {}
7288
for name, var in model.named_vars.items():
73-
if isinstance(var, (Constant, SharedVariable)):
74-
if hasattr(var, "data"):
75-
var = var.data
76-
elif hasattr(var, "get_value"):
89+
if is_data(name, var, model):
90+
if hasattr(var, "get_value"):
7791
var = var.get_value()
78-
constants[name] = var
79-
return constants
92+
elif hasattr(var, "data"):
93+
var = var.data
94+
constant_data[name] = var
95+
96+
return constant_data
8097

8198

8299
class _DefaultTrace:
@@ -483,41 +500,10 @@ def observed_data_to_xarray(self):
483500
@requires("model")
484501
def constant_data_to_xarray(self):
485502
"""Convert constant data to xarray."""
486-
# For constant data, we are concerned only with deterministics and
487-
# data. The constant data vars must be either pm.Data
488-
# (TensorConstant/SharedVariable) or pm.Deterministic
489-
constant_data_vars = {} # type: Dict[str, Var]
490-
491-
def is_data(name, var) -> bool:
492-
assert self.model is not None
493-
return (
494-
var not in self.model.deterministics
495-
and var not in self.model.observed_RVs
496-
and var not in self.model.free_RVs
497-
and var not in self.model.potentials
498-
and var not in self.model.value_vars
499-
and (self.observations is None or name not in self.observations)
500-
and isinstance(var, (Constant, SharedVariable))
501-
)
502-
503-
# I don't know how to find pm.Data, except that they are named
504-
# variables that aren't observed or free RVs, nor are they
505-
# deterministics, and then we eliminate observations.
506-
for name, var in self.model.named_vars.items():
507-
if is_data(name, var):
508-
constant_data_vars[name] = var
509-
510-
if not constant_data_vars:
503+
constant_data = find_constants(self.model)
504+
if not constant_data:
511505
return None
512506

513-
constant_data = {}
514-
for name, vals in constant_data_vars.items():
515-
if hasattr(vals, "get_value"):
516-
vals = vals.get_value()
517-
elif hasattr(vals, "data"):
518-
vals = vals.data
519-
constant_data[name] = vals
520-
521507
return dict_to_dataset(
522508
constant_data,
523509
library=pymc,

pymc/sampling_jax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@ def sample_blackjax_nuts(
370370
posterior=posterior,
371371
log_likelihood=log_likelihood,
372372
observed_data=find_observations(model),
373+
constant_data=find_constants(model),
373374
coords=coords,
374375
dims=dims,
375376
attrs=make_attrs(attrs, library=blackjax),

0 commit comments

Comments
 (0)