20
20
from aesara .graph .basic import Constant
21
21
from aesara .tensor .sharedvar import SharedVariable
22
22
from arviz import InferenceData , concat , rcParams
23
- from arviz .data .base import CoordSpec , DimSpec , dict_to_dataset , requires
23
+ from arviz .data .base import CoordSpec , DimSpec
24
+ from arviz .data .base import dict_to_dataset as _dict_to_dataset
25
+ from arviz .data .base import generate_dims_coords , make_attrs , requires
24
26
25
27
import pymc3
26
28
@@ -98,6 +100,37 @@ def insert(self, k: str, v, idx: int):
98
100
self .trace_dict [k ][idx , :] = v
99
101
100
102
103
def dict_to_dataset(
    data,
    library=None,
    coords=None,
    dims=None,
    attrs=None,
    default_dims=None,
    skip_event_dims=None,
    index_origin=None,
):
    """Temporary workaround for ``arviz.dict_to_dataset``.

    Once an ArviZ > 0.11.2 release is available, only two changes are needed for
    everything to work: 1) this should be deleted, 2) ``dict_to_dataset`` should
    be imported as-is from arviz (no underscore); also remove the now-unnecessary
    imports.

    Parameters mirror the upcoming arviz signature. ``attrs`` and
    ``index_origin`` are accepted for forward compatibility but are ignored by
    this shim (the installed arviz does not support them yet).
    """
    if default_dims is None:
        # Installed arviz handles this case directly.
        return _dict_to_dataset(
            data, library=library, coords=coords, dims=dims, skip_event_dims=skip_event_dims
        )

    # default_dims was given (e.g. [] for observed/constant data): build the
    # Dataset manually, without prepending chain/draw dimensions.
    if dims is None:
        dims = {}
    out_data = {}
    for name, vals in data.items():
        vals = np.atleast_1d(vals)
        # NOTE: use per-variable locals here — the previous implementation
        # rebound `coords` inside the loop, so every variable after the first
        # saw the prior variable's IndexVariables instead of the user coords.
        val_dims, val_coords = generate_dims_coords(
            vals.shape, name, dims=dims.get(name), coords=coords
        )
        index_coords = {key: xr.IndexVariable((key,), data=val_coords[key]) for key in val_dims}
        out_data[name] = xr.DataArray(vals, dims=val_dims, coords=index_coords)
    return xr.Dataset(data_vars=out_data, attrs=make_attrs(library=library))
132
+
133
+
101
134
class InferenceDataConverter : # pylint: disable=too-many-instance-attributes
102
135
"""Encapsulate InferenceData specific logic."""
103
136
@@ -196,14 +229,13 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
196
229
self .dims = {** model_dims , ** self .dims }
197
230
198
231
self .density_dist_obs = density_dist_obs
199
- self .observations , self . multi_observations = self .find_observations ()
232
+ self .observations = self .find_observations ()
200
233
201
- def find_observations (self ) -> Tuple [ Optional [Dict [str , Var ]], Optional [ Dict [ str , Var ] ]]:
234
+ def find_observations (self ) -> Optional [Dict [str , Var ]]:
202
235
"""If there are observations available, return them as a dictionary."""
203
236
if self .model is None :
204
- return ( None , None )
237
+ return None
205
238
observations = {}
206
- multi_observations = {}
207
239
for obs in self .model .observed_RVs :
208
240
aux_obs = getattr (obs .tag , "observations" , None )
209
241
if aux_obs is not None :
@@ -215,7 +247,7 @@ def find_observations(self) -> Tuple[Optional[Dict[str, Var]], Optional[Dict[str
215
247
else :
216
248
warnings .warn (f"No data for observation { obs } " )
217
249
218
- return observations , multi_observations
250
+ return observations
219
251
220
252
def split_trace (self ) -> Tuple [Union [None , "MultiTrace" ], Union [None , "MultiTrace" ]]:
221
253
"""Split MultiTrace object into posterior and warmup.
@@ -302,15 +334,15 @@ def posterior_to_xarray(self):
302
334
coords = self .coords ,
303
335
dims = self .dims ,
304
336
attrs = self .attrs ,
305
- # index_origin=self.index_origin,
337
+ index_origin = self .index_origin ,
306
338
),
307
339
dict_to_dataset (
308
340
data_warmup ,
309
341
library = pymc3 ,
310
342
coords = self .coords ,
311
343
dims = self .dims ,
312
344
attrs = self .attrs ,
313
- # index_origin=self.index_origin,
345
+ index_origin = self .index_origin ,
314
346
),
315
347
)
316
348
@@ -344,15 +376,15 @@ def sample_stats_to_xarray(self):
344
376
dims = None ,
345
377
coords = self .coords ,
346
378
attrs = self .attrs ,
347
- # index_origin=self.index_origin,
379
+ index_origin = self .index_origin ,
348
380
),
349
381
dict_to_dataset (
350
382
data_warmup ,
351
383
library = pymc3 ,
352
384
dims = None ,
353
385
coords = self .coords ,
354
386
attrs = self .attrs ,
355
- # index_origin=self.index_origin,
387
+ index_origin = self .index_origin ,
356
388
),
357
389
)
358
390
@@ -385,15 +417,15 @@ def log_likelihood_to_xarray(self):
385
417
dims = self .dims ,
386
418
coords = self .coords ,
387
419
skip_event_dims = True ,
388
- # index_origin=self.index_origin,
420
+ index_origin = self .index_origin ,
389
421
),
390
422
dict_to_dataset (
391
423
data_warmup ,
392
424
library = pymc3 ,
393
425
dims = self .dims ,
394
426
coords = self .coords ,
395
427
skip_event_dims = True ,
396
- # index_origin=self.index_origin,
428
+ index_origin = self .index_origin ,
397
429
),
398
430
)
399
431
@@ -415,11 +447,7 @@ def translate_posterior_predictive_dict_to_xarray(self, dct) -> xr.Dataset:
415
447
k ,
416
448
)
417
449
return dict_to_dataset (
418
- data ,
419
- library = pymc3 ,
420
- coords = self .coords ,
421
- # dims=self.dims,
422
- # index_origin=self.index_origin
450
+ data , library = pymc3 , coords = self .coords , dims = self .dims , index_origin = self .index_origin
423
451
)
424
452
425
453
@requires (["posterior_predictive" ])
@@ -454,25 +482,25 @@ def priors_to_xarray(self):
454
482
{k : np .expand_dims (self .prior [k ], 0 ) for k in var_names },
455
483
library = pymc3 ,
456
484
coords = self .coords ,
457
- # dims=self.dims,
458
- # index_origin=self.index_origin,
485
+ dims = self .dims ,
486
+ index_origin = self .index_origin ,
459
487
)
460
488
)
461
489
return priors_dict
462
490
463
- @requires ([ "observations" , "multi_observations" ] )
491
+ @requires ("observations" )
464
492
@requires ("model" )
465
493
def observed_data_to_xarray (self ):
466
494
"""Convert observed data to xarray."""
467
495
if self .predictions :
468
496
return None
469
497
return dict_to_dataset (
470
- { ** self .observations , ** self . multi_observations } ,
498
+ self .observations ,
471
499
library = pymc3 ,
472
500
coords = self .coords ,
473
- # dims=self.dims,
474
- # default_dims=[],
475
- # index_origin=self.index_origin,
501
+ dims = self .dims ,
502
+ default_dims = [],
503
+ index_origin = self .index_origin ,
476
504
)
477
505
478
506
@requires (["trace" , "predictions" ])
@@ -517,9 +545,9 @@ def is_data(name, var) -> bool:
517
545
constant_data ,
518
546
library = pymc3 ,
519
547
coords = self .coords ,
520
- # dims=self.dims,
521
- # default_dims=[],
522
- # index_origin=self.index_origin,
548
+ dims = self .dims ,
549
+ default_dims = [],
550
+ index_origin = self .index_origin ,
523
551
)
524
552
525
553
def to_inference_data (self ):
0 commit comments