Skip to content

Commit face4b7

Browse files
committed
update conversion of observations and disable log likelihood
1 parent d46dae4 commit face4b7

File tree

1 file changed

+9
-55
lines changed

1 file changed

+9
-55
lines changed

pymc3/backends/arviz.py

Lines changed: 9 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ def __init__(
118118
dims: Optional[DimSpec] = None,
119119
model=None,
120120
save_warmup: Optional[bool] = None,
121-
density_dist_obs: bool = True,
122121
index_origin: Optional[int] = None,
123122
):
124123

@@ -190,28 +189,18 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
190189
model_dims = {k: list(v) for k, v in self.model.RV_dims.items()}
191190
self.dims = {**model_dims, **self.dims}
192191

193-
self.density_dist_obs = density_dist_obs
194-
self.observations, self.multi_observations = self.find_observations()
192+
self.observations = self.find_observations()
195193

196-
def find_observations(self) -> Tuple[Optional[Dict[str, Var]], Optional[Dict[str, Var]]]:
194+
def find_observations(self) -> Optional[Dict[str, Var]]:
197195
"""If there are observations available, return them as a dictionary."""
198196
if self.model is None:
199-
return (None, None)
197+
return None
200198
observations = {}
201-
multi_observations = {}
202199
for obs in self.model.observed_RVs:
203-
if hasattr(obs, "observations"):
204-
aux_obs = obs.observations
205-
observations[obs.name] = (
206-
aux_obs.get_value() if hasattr(aux_obs, "get_value") else aux_obs
207-
)
208-
elif hasattr(obs, "data") and self.density_dist_obs:
209-
for key, val in obs.data.items():
210-
aux_obs = val.eval() if hasattr(val, "eval") else val
211-
multi_observations[key] = (
212-
aux_obs.get_value() if hasattr(aux_obs, "get_value") else aux_obs
213-
)
214-
return observations, multi_observations
200+
if hasattr(obs.tag, "observations"):
201+
aux_obs = obs.tag.observations
202+
observations[obs.name] = aux_obs.data if hasattr(aux_obs, "data") else aux_obs
203+
return observations
215204

216205
def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
217206
"""Split MultiTrace object into posterior and warmup.
@@ -233,41 +222,6 @@ def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrac
233222
trace_posterior = self.trace[self.ntune :]
234223
return trace_posterior, trace_warmup
235224

236-
def log_likelihood_vals_point(self, point, var, log_like_fun):
237-
"""Compute log likelihood for each observed point."""
238-
log_like_val = utils.one_de(log_like_fun(point))
239-
if var.missing_values:
240-
mask = var.observations.mask
241-
if np.ndim(mask) > np.ndim(log_like_val):
242-
mask = np.any(mask, axis=-1)
243-
log_like_val = np.where(mask, np.nan, log_like_val)
244-
return log_like_val
245-
246-
def _extract_log_likelihood(self, trace):
247-
"""Compute log likelihood of each observation."""
248-
if self.trace is None:
249-
return None
250-
if self.model is None:
251-
return None
252-
253-
if self.log_likelihood is True:
254-
cached = [(var, var.logp_elemwise) for var in self.model.observed_RVs]
255-
else:
256-
cached = [
257-
(var, var.logp_elemwise)
258-
for var in self.model.observed_RVs
259-
if var.name in self.log_likelihood
260-
]
261-
log_likelihood_dict = _DefaultTrace(len(trace.chains))
262-
for var, log_like_fun in cached:
263-
for k, chain in enumerate(trace.chains):
264-
log_like_chain = [
265-
self.log_likelihood_vals_point(point, var, log_like_fun)
266-
for point in trace.points([chain])
267-
]
268-
log_likelihood_dict.insert(var.name, np.stack(log_like_chain), k)
269-
return log_likelihood_dict.trace_dict
270-
271225
@requires("trace")
272226
def posterior_to_xarray(self):
273227
"""Convert the posterior to an xarray dataset."""
@@ -348,6 +302,8 @@ def sample_stats_to_xarray(self):
348302
@requires("model")
349303
def log_likelihood_to_xarray(self):
350304
"""Extract log likelihood and log_p data from PyMC3 trace."""
305+
# TODO: add pointwise log likelihood extraction to the converter
306+
return None
351307
if self.predictions or not self.log_likelihood:
352308
return None
353309
data_warmup = {}
@@ -540,7 +496,6 @@ def to_inference_data(
540496
dims: Optional[DimSpec] = None,
541497
model: Optional["Model"] = None,
542498
save_warmup: Optional[bool] = None,
543-
density_dist_obs: bool = True,
544499
) -> InferenceData:
545500
"""Convert pymc3 data into an InferenceData object.
546501
@@ -590,7 +545,6 @@ def to_inference_data(
590545
dims=dims,
591546
model=model,
592547
save_warmup=save_warmup,
593-
density_dist_obs=density_dist_obs,
594548
).to_inference_data()
595549

596550

0 commit comments

Comments
 (0)