43
43
Var = Any # pylint: disable=invalid-name
44
44
45
45
46
- def find_observations (model : Optional [ "Model" ] ) -> Optional [ Dict [str , Var ] ]:
46
+ def find_observations (model : "Model" ) -> Dict [str , Var ]:
47
47
"""If there are observations available, return them as a dictionary."""
48
- if model is None :
49
- return None
50
-
51
48
observations = {}
52
49
for obs in model .observed_RVs :
53
50
aux_obs = getattr (obs .tag , "observations" , None )
@@ -63,6 +60,36 @@ def find_observations(model: Optional["Model"]) -> Optional[Dict[str, Var]]:
63
60
return observations
64
61
65
62
63
+ def find_constants (model : "Model" ) -> Dict [str , Var ]:
64
+ """If there are constants available, return them as a dictionary."""
65
+ # The constant data vars must be either pm.Data or TensorConstant or SharedVariable
66
+ def is_data (name , var , model ) -> bool :
67
+ observations = find_observations (model )
68
+ return (
69
+ var not in model .deterministics
70
+ and var not in model .observed_RVs
71
+ and var not in model .free_RVs
72
+ and var not in model .potentials
73
+ and var not in model .value_vars
74
+ and name not in observations
75
+ and isinstance (var , (Constant , SharedVariable ))
76
+ )
77
+
78
+ # The assumption is that constants (like pm.Data) are named
79
+ # variables that aren't observed or free RVs, nor are they
80
+ # deterministics, and then we eliminate observations.
81
+ constant_data = {}
82
+ for name , var in model .named_vars .items ():
83
+ if is_data (name , var , model ):
84
+ if hasattr (var , "get_value" ):
85
+ var = var .get_value ()
86
+ elif hasattr (var , "data" ):
87
+ var = var .data
88
+ constant_data [name ] = var
89
+
90
+ return constant_data
91
+
92
+
66
93
class _DefaultTrace :
67
94
"""
68
95
Utility for collecting samples into a dictionary.
@@ -467,41 +494,10 @@ def observed_data_to_xarray(self):
467
494
@requires ("model" )
468
495
def constant_data_to_xarray (self ):
469
496
"""Convert constant data to xarray."""
470
- # For constant data, we are concerned only with deterministics and
471
- # data. The constant data vars must be either pm.Data
472
- # (TensorConstant/SharedVariable) or pm.Deterministic
473
- constant_data_vars = {} # type: Dict[str, Var]
474
-
475
- def is_data (name , var ) -> bool :
476
- assert self .model is not None
477
- return (
478
- var not in self .model .deterministics
479
- and var not in self .model .observed_RVs
480
- and var not in self .model .free_RVs
481
- and var not in self .model .potentials
482
- and var not in self .model .value_vars
483
- and (self .observations is None or name not in self .observations )
484
- and isinstance (var , (Constant , SharedVariable ))
485
- )
486
-
487
- # I don't know how to find pm.Data, except that they are named
488
- # variables that aren't observed or free RVs, nor are they
489
- # deterministics, and then we eliminate observations.
490
- for name , var in self .model .named_vars .items ():
491
- if is_data (name , var ):
492
- constant_data_vars [name ] = var
493
-
494
- if not constant_data_vars :
497
+ constant_data = find_constants (self .model )
498
+ if not constant_data :
495
499
return None
496
500
497
- constant_data = {}
498
- for name , vals in constant_data_vars .items ():
499
- if hasattr (vals , "get_value" ):
500
- vals = vals .get_value ()
501
- elif hasattr (vals , "data" ):
502
- vals = vals .data
503
- constant_data [name ] = vals
504
-
505
501
return dict_to_dataset (
506
502
constant_data ,
507
503
library = pymc ,
0 commit comments