1
+ from __future__ import annotations
2
+
1
3
import contextvars
2
4
import logging
3
5
import numbers
50
52
# test_mixture_random_shape::test_mixture_random_shape
51
53
#
52
54
53
- PosteriorPredictiveTrace = Dict [str , np .ndarray ]
54
55
Point = Dict [str , np .ndarray ]
55
56
56
57
57
58
class HasName (Protocol ):
58
59
name : str
59
60
60
61
61
- if TYPE_CHECKING :
62
- _TraceDictParent = UserDict [str , np .ndarray ]
63
- else :
64
- _TraceDictParent = UserDict
65
-
66
-
67
- class _TraceDict (_TraceDictParent ):
62
+ class _TraceDict (UserDict ):
68
63
"""This class extends the standard trace-based representation
69
64
of traces by adding some helpful attributes used in posterior predictive
70
65
sampling.
@@ -75,24 +70,24 @@ class _TraceDict(_TraceDictParent):
75
70
76
71
varnames : List [str ]
77
72
_len : int
78
- data : Dict [ str , np . ndarray ]
73
+ data : Point
79
74
80
75
def __init__ (
81
76
self ,
82
- point_list : Optional [List [Dict [ str , np . ndarray ] ]] = None ,
77
+ point_list : Optional [List [Point ]] = None ,
83
78
multi_trace : Optional [MultiTrace ] = None ,
84
- dict : Optional [Dict [ str , np . ndarray ] ] = None ,
79
+ dict_ : Optional [Point ] = None ,
85
80
):
86
81
""""""
87
82
if multi_trace :
88
- assert point_list is None and dict is None
89
- self .data = {} # Dict[str, np.ndarray]
83
+ assert point_list is None and dict_ is None
84
+ self .data = {}
90
85
self ._len = sum (len (multi_trace ._straces [chain ]) for chain in multi_trace .chains )
91
86
self .varnames = multi_trace .varnames
92
87
for vn in multi_trace .varnames :
93
88
self .data [vn ] = multi_trace .get_values (vn )
94
89
if point_list is not None :
95
- assert multi_trace is None and dict is None
90
+ assert multi_trace is None and dict_ is None
96
91
self .varnames = varnames = list (point_list [0 ].keys ())
97
92
rep_values = [point_list [0 ][varname ] for varname in varnames ]
98
93
# translate the point list.
@@ -114,18 +109,18 @@ def arr_for(val):
114
109
for i , point in enumerate (point_list ):
115
110
for var , value in point .items ():
116
111
self .data [var ][i ] = value
117
- if dict is not None :
112
+ if dict_ is not None :
118
113
assert point_list is None and multi_trace is None
119
- self .data = dict
120
- self .varnames = list (dict .keys ())
121
- self ._len = dict [self .varnames [0 ]].shape [0 ]
114
+ self .data = dict_
115
+ self .varnames = list (dict_ .keys ())
116
+ self ._len = dict_ [self .varnames [0 ]].shape [0 ]
122
117
assert self .varnames is not None and self ._len is not None and self .data is not None
123
118
124
119
def __len__ (self ) -> int :
125
120
return self ._len
126
121
127
- def _extract_slice (self , slc : slice ) -> " _TraceDict" :
128
- sliced_dict : Dict [ str , np . ndarray ] = {}
122
+ def _extract_slice (self , slc : slice ) -> _TraceDict :
123
+ sliced_dict : Point = {}
129
124
130
125
def apply_slice (arr : np .ndarray ) -> np .ndarray :
131
126
if len (arr .shape ) == 1 :
@@ -135,14 +130,14 @@ def apply_slice(arr: np.ndarray) -> np.ndarray:
135
130
136
131
for vn , arr in self .data .items ():
137
132
sliced_dict [vn ] = apply_slice (arr )
138
- return _TraceDict (dict = sliced_dict )
133
+ return _TraceDict (dict_ = sliced_dict )
139
134
140
135
@overload
141
136
def __getitem__ (self , item : Union [str , HasName ]) -> np .ndarray :
142
137
...
143
138
144
139
@overload
145
- def __getitem__ (self , item : Union [slice , int ]) -> " _TraceDict" :
140
+ def __getitem__ (self , item : Union [slice , int ]) -> _TraceDict :
146
141
...
147
142
148
143
def __getitem__ (self , item ):
@@ -151,7 +146,7 @@ def __getitem__(self, item):
151
146
elif isinstance (item , slice ):
152
147
return self ._extract_slice (item )
153
148
elif isinstance (item , int ):
154
- return _TraceDict (dict = {k : np .atleast_1d (v [item ]) for k , v in self .data .items ()})
149
+ return _TraceDict (dict_ = {k : np .atleast_1d (v [item ]) for k , v in self .data .items ()})
155
150
elif hasattr (item , "name" ):
156
151
return super ().__getitem__ (item .name )
157
152
else :
@@ -302,12 +297,10 @@ def extend_trace(self, trace: Dict[str, np.ndarray]) -> None:
302
297
except KeyboardInterrupt :
303
298
pass
304
299
305
- if keep_size :
306
- return {
307
- k : ary .reshape ((nchains , ndraws , * ary .shape [1 :])) for k , ary in ppc_trace .items ()
308
- }
309
- # this gets us a Dict[str, np.ndarray] instead of my wrapped equiv.
310
- return ppc_trace .data
300
+ if keep_size :
301
+ return {k : ary .reshape ((nchains , ndraws , * ary .shape [1 :])) for k , ary in ppc_trace .items ()}
302
+ # this gets us a Dict[str, np.ndarray] instead of my wrapped equiv.
303
+ return ppc_trace .data
311
304
312
305
313
306
def posterior_predictive_draw_values (
0 commit comments