@@ -118,7 +118,6 @@ def __init__(
         dims: Optional[DimSpec] = None,
         model=None,
         save_warmup: Optional[bool] = None,
-        density_dist_obs: bool = True,
         index_origin: Optional[int] = None,
     ):

@@ -190,28 +189,18 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
             model_dims = {k: list(v) for k, v in self.model.RV_dims.items()}
             self.dims = {**model_dims, **self.dims}

-        self.density_dist_obs = density_dist_obs
-        self.observations, self.multi_observations = self.find_observations()
+        self.observations = self.find_observations()

-    def find_observations(self) -> Tuple[Optional[Dict[str, Var]], Optional[Dict[str, Var]]]:
+    def find_observations(self) -> Optional[Dict[str, Var]]:
         """If there are observations available, return them as a dictionary."""
         if self.model is None:
-            return (None, None)
+            return None
         observations = {}
-        multi_observations = {}
         for obs in self.model.observed_RVs:
-            if hasattr(obs, "observations"):
-                aux_obs = obs.observations
-                observations[obs.name] = (
-                    aux_obs.get_value() if hasattr(aux_obs, "get_value") else aux_obs
-                )
-            elif hasattr(obs, "data") and self.density_dist_obs:
-                for key, val in obs.data.items():
-                    aux_obs = val.eval() if hasattr(val, "eval") else val
-                    multi_observations[key] = (
-                        aux_obs.get_value() if hasattr(aux_obs, "get_value") else aux_obs
-                    )
-        return observations, multi_observations
+            if hasattr(obs.tag, "observations"):
+                aux_obs = obs.tag.observations
+                observations[obs.name] = aux_obs.data if hasattr(aux_obs, "data") else aux_obs
+        return observations

     def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
         """Split MultiTrace object into posterior and warmup.
@@ -233,41 +222,6 @@ def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
         trace_posterior = self.trace[self.ntune :]
         return trace_posterior, trace_warmup

-    def log_likelihood_vals_point(self, point, var, log_like_fun):
-        """Compute log likelihood for each observed point."""
-        log_like_val = utils.one_de(log_like_fun(point))
-        if var.missing_values:
-            mask = var.observations.mask
-            if np.ndim(mask) > np.ndim(log_like_val):
-                mask = np.any(mask, axis=-1)
-            log_like_val = np.where(mask, np.nan, log_like_val)
-        return log_like_val
-
-    def _extract_log_likelihood(self, trace):
-        """Compute log likelihood of each observation."""
-        if self.trace is None:
-            return None
-        if self.model is None:
-            return None
-
-        if self.log_likelihood is True:
-            cached = [(var, var.logp_elemwise) for var in self.model.observed_RVs]
-        else:
-            cached = [
-                (var, var.logp_elemwise)
-                for var in self.model.observed_RVs
-                if var.name in self.log_likelihood
-            ]
-        log_likelihood_dict = _DefaultTrace(len(trace.chains))
-        for var, log_like_fun in cached:
-            for k, chain in enumerate(trace.chains):
-                log_like_chain = [
-                    self.log_likelihood_vals_point(point, var, log_like_fun)
-                    for point in trace.points([chain])
-                ]
-                log_likelihood_dict.insert(var.name, np.stack(log_like_chain), k)
-        return log_likelihood_dict.trace_dict
-
     @requires("trace")
     def posterior_to_xarray(self):
         """Convert the posterior to an xarray dataset."""
@@ -348,6 +302,8 @@ def sample_stats_to_xarray(self):
     @requires("model")
     def log_likelihood_to_xarray(self):
         """Extract log likelihood and log_p data from PyMC3 trace."""
+        # TODO: add pointwise log likelihood extraction to the converter
+        return None
         if self.predictions or not self.log_likelihood:
             return None
         data_warmup = {}
@@ -540,7 +496,6 @@ def to_inference_data(
     dims: Optional[DimSpec] = None,
     model: Optional["Model"] = None,
     save_warmup: Optional[bool] = None,
-    density_dist_obs: bool = True,
 ) -> InferenceData:
     """Convert pymc3 data into an InferenceData object.

@@ -590,7 +545,6 @@ def to_inference_data(
         dims=dims,
         model=model,
         save_warmup=save_warmup,
-        density_dist_obs=density_dist_obs,
     ).to_inference_data()


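To round out the last two hunks, a hedged usage sketch of the updated public entry point: density_dist_obs is no longer a keyword of to_inference_data, so callers simply drop it. The sampling setup below and the assumption that the function is re-exported as pm.to_inference_data are illustrative, not taken from the commit.

    import pymc3 as pm

    with pm.Model() as model:
        mu = pm.Normal("mu", 0.0, 1.0)
        pm.Normal("y", mu, 1.0, observed=[0.1, -0.3, 0.7])
        trace = pm.sample(500, tune=500, return_inferencedata=False)

    # Before this change: to_inference_data(trace=trace, model=model, density_dist_obs=True)
    # Now the keyword is gone; passing it raises a TypeError.
    idata = pm.to_inference_data(trace=trace, model=model, save_warmup=True)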