Skip to content

Commit 935ae4d

Browse files
MarcoGorellimichaelosthege
authored andcommitted
🏷️ fix some typing in _TraceDict
1 parent 8b1f64c commit 935ae4d

File tree

1 file changed

+22
-29
lines changed

1 file changed

+22
-29
lines changed

pymc3/distributions/posterior_predictive.py

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import contextvars
24
import logging
35
import numbers
@@ -50,21 +52,14 @@
5052
# test_mixture_random_shape::test_mixture_random_shape
5153
#
5254

53-
PosteriorPredictiveTrace = Dict[str, np.ndarray]
5455
Point = Dict[str, np.ndarray]
5556

5657

5758
class HasName(Protocol):
5859
name: str
5960

6061

61-
if TYPE_CHECKING:
62-
_TraceDictParent = UserDict[str, np.ndarray]
63-
else:
64-
_TraceDictParent = UserDict
65-
66-
67-
class _TraceDict(_TraceDictParent):
62+
class _TraceDict(UserDict):
6863
"""This class extends the standard trace-based representation
6964
of traces by adding some helpful attributes used in posterior predictive
7065
sampling.
@@ -75,24 +70,24 @@ class _TraceDict(_TraceDictParent):
7570

7671
varnames: List[str]
7772
_len: int
78-
data: Dict[str, np.ndarray]
73+
data: Point
7974

8075
def __init__(
8176
self,
82-
point_list: Optional[List[Dict[str, np.ndarray]]] = None,
77+
point_list: Optional[List[Point]] = None,
8378
multi_trace: Optional[MultiTrace] = None,
84-
dict: Optional[Dict[str, np.ndarray]] = None,
79+
dict_: Optional[Point] = None,
8580
):
8681
""""""
8782
if multi_trace:
88-
assert point_list is None and dict is None
89-
self.data = {} # Dict[str, np.ndarray]
83+
assert point_list is None and dict_ is None
84+
self.data = {}
9085
self._len = sum(len(multi_trace._straces[chain]) for chain in multi_trace.chains)
9186
self.varnames = multi_trace.varnames
9287
for vn in multi_trace.varnames:
9388
self.data[vn] = multi_trace.get_values(vn)
9489
if point_list is not None:
95-
assert multi_trace is None and dict is None
90+
assert multi_trace is None and dict_ is None
9691
self.varnames = varnames = list(point_list[0].keys())
9792
rep_values = [point_list[0][varname] for varname in varnames]
9893
# translate the point list.
@@ -114,18 +109,18 @@ def arr_for(val):
114109
for i, point in enumerate(point_list):
115110
for var, value in point.items():
116111
self.data[var][i] = value
117-
if dict is not None:
112+
if dict_ is not None:
118113
assert point_list is None and multi_trace is None
119-
self.data = dict
120-
self.varnames = list(dict.keys())
121-
self._len = dict[self.varnames[0]].shape[0]
114+
self.data = dict_
115+
self.varnames = list(dict_.keys())
116+
self._len = dict_[self.varnames[0]].shape[0]
122117
assert self.varnames is not None and self._len is not None and self.data is not None
123118

124119
def __len__(self) -> int:
125120
return self._len
126121

127-
def _extract_slice(self, slc: slice) -> "_TraceDict":
128-
sliced_dict: Dict[str, np.ndarray] = {}
122+
def _extract_slice(self, slc: slice) -> _TraceDict:
123+
sliced_dict: Point = {}
129124

130125
def apply_slice(arr: np.ndarray) -> np.ndarray:
131126
if len(arr.shape) == 1:
@@ -135,14 +130,14 @@ def apply_slice(arr: np.ndarray) -> np.ndarray:
135130

136131
for vn, arr in self.data.items():
137132
sliced_dict[vn] = apply_slice(arr)
138-
return _TraceDict(dict=sliced_dict)
133+
return _TraceDict(dict_=sliced_dict)
139134

140135
@overload
141136
def __getitem__(self, item: Union[str, HasName]) -> np.ndarray:
142137
...
143138

144139
@overload
145-
def __getitem__(self, item: Union[slice, int]) -> "_TraceDict":
140+
def __getitem__(self, item: Union[slice, int]) -> _TraceDict:
146141
...
147142

148143
def __getitem__(self, item):
@@ -151,7 +146,7 @@ def __getitem__(self, item):
151146
elif isinstance(item, slice):
152147
return self._extract_slice(item)
153148
elif isinstance(item, int):
154-
return _TraceDict(dict={k: np.atleast_1d(v[item]) for k, v in self.data.items()})
149+
return _TraceDict(dict_={k: np.atleast_1d(v[item]) for k, v in self.data.items()})
155150
elif hasattr(item, "name"):
156151
return super().__getitem__(item.name)
157152
else:
@@ -302,12 +297,10 @@ def extend_trace(self, trace: Dict[str, np.ndarray]) -> None:
302297
except KeyboardInterrupt:
303298
pass
304299

305-
if keep_size:
306-
return {
307-
k: ary.reshape((nchains, ndraws, *ary.shape[1:])) for k, ary in ppc_trace.items()
308-
}
309-
# this gets us a Dict[str, np.ndarray] instead of my wrapped equiv.
310-
return ppc_trace.data
300+
if keep_size:
301+
return {k: ary.reshape((nchains, ndraws, *ary.shape[1:])) for k, ary in ppc_trace.items()}
302+
# this gets us a Dict[str, np.ndarray] instead of my wrapped equiv.
303+
return ppc_trace.data
311304

312305

313306
def posterior_predictive_draw_values(

0 commit comments

Comments
 (0)