43
43
Var = Any # pylint: disable=invalid-name
44
44
45
45
46
- def find_observations (model : Optional ["Model" ]) -> Optional [ Dict [str , Var ] ]:
46
+ def find_observations (model : Optional ["Model" ]) -> Dict [str , Var ]:
47
47
"""If there are observations available, return them as a dictionary."""
48
48
if model is None :
49
- return None
49
+ return {}
50
50
51
51
observations = {}
52
52
for obs in model .observed_RVs :
@@ -63,20 +63,37 @@ def find_observations(model: Optional["Model"]) -> Optional[Dict[str, Var]]:
63
63
return observations
64
64
65
65
66
- def find_constants (model : Optional ["Model" ]) -> Optional [ Dict [str , Var ] ]:
66
+ def find_constants (model : Optional ["Model" ]) -> Dict [str , Var ]:
67
67
"""If there are constants available, return them as a dictionary."""
68
- if model is None or not model .named_vars :
69
- return None
68
+ # The constant data vars must be either pm.Data or TensorConstant or SharedVariable
69
+ if model is None :
70
+ return {}
71
+
72
+ def is_data (name , var , model ) -> bool :
73
+ observations = find_observations (model )
74
+ return (
75
+ var not in model .deterministics
76
+ and var not in model .observed_RVs
77
+ and var not in model .free_RVs
78
+ and var not in model .potentials
79
+ and var not in model .value_vars
80
+ and name not in observations
81
+ and isinstance (var , (Constant , SharedVariable ))
82
+ )
70
83
71
- constants = {}
84
+ # The assumption is that constants (like pm.Data) are named
85
+ # variables that aren't observed or free RVs, nor are they
86
+ # deterministics, and then we eliminate observations.
87
+ constant_data = {}
72
88
for name , var in model .named_vars .items ():
73
- if isinstance (var , (Constant , SharedVariable )):
74
- if hasattr (var , "data" ):
75
- var = var .data
76
- elif hasattr (var , "get_value" ):
89
+ if is_data (name , var , model ):
90
+ if hasattr (var , "get_value" ):
77
91
var = var .get_value ()
78
- constants [name ] = var
79
- return constants
92
+ elif hasattr (var , "data" ):
93
+ var = var .data
94
+ constant_data [name ] = var
95
+
96
+ return constant_data
80
97
81
98
82
99
class _DefaultTrace :
@@ -483,41 +500,10 @@ def observed_data_to_xarray(self):
483
500
@requires ("model" )
484
501
def constant_data_to_xarray (self ):
485
502
"""Convert constant data to xarray."""
486
- # For constant data, we are concerned only with deterministics and
487
- # data. The constant data vars must be either pm.Data
488
- # (TensorConstant/SharedVariable) or pm.Deterministic
489
- constant_data_vars = {} # type: Dict[str, Var]
490
-
491
- def is_data (name , var ) -> bool :
492
- assert self .model is not None
493
- return (
494
- var not in self .model .deterministics
495
- and var not in self .model .observed_RVs
496
- and var not in self .model .free_RVs
497
- and var not in self .model .potentials
498
- and var not in self .model .value_vars
499
- and (self .observations is None or name not in self .observations )
500
- and isinstance (var , (Constant , SharedVariable ))
501
- )
502
-
503
- # I don't know how to find pm.Data, except that they are named
504
- # variables that aren't observed or free RVs, nor are they
505
- # deterministics, and then we eliminate observations.
506
- for name , var in self .model .named_vars .items ():
507
- if is_data (name , var ):
508
- constant_data_vars [name ] = var
509
-
510
- if not constant_data_vars :
503
+ constant_data = find_constants (self .model )
504
+ if not constant_data :
511
505
return None
512
506
513
- constant_data = {}
514
- for name , vals in constant_data_vars .items ():
515
- if hasattr (vals , "get_value" ):
516
- vals = vals .get_value ()
517
- elif hasattr (vals , "data" ):
518
- vals = vals .data
519
- constant_data [name ] = vals
520
-
521
507
return dict_to_dataset (
522
508
constant_data ,
523
509
library = pymc ,
0 commit comments