|
18 | 18 | Classes for working with subsets of parameters.
|
19 | 19 | """
|
20 | 20 | import collections
|
21 |
| -import copy |
22 | 21 |
|
23 |
| -import numpy as np |
| 22 | +from typing import Dict, List, Optional, Union |
24 | 23 |
|
25 |
| -from pymc3.util import get_var_name |
| 24 | +import numpy as np |
26 | 25 |
|
27 |
| -__all__ = ["ArrayOrdering", "DictToArrayBijection", "DictToVarBijection"] |
| 26 | +__all__ = ["ArrayOrdering", "DictToArrayBijection"] |
28 | 27 |
|
| 28 | +# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for |
| 29 | +# each of the raveled variables. |
| 30 | +RaveledVars = collections.namedtuple("RaveledVars", "data, point_map_info") |
29 | 31 | VarMap = collections.namedtuple("VarMap", "var, slc, shp, dtyp")
|
30 | 32 | DataMap = collections.namedtuple("DataMap", "list_ind, slc, shp, dtype, name")
|
31 | 33 |
|
32 | 34 |
|
33 |
| -# TODO Classes and methods need to be fully documented. |
34 |
| - |
35 |
| - |
36 | 35 | class ArrayOrdering:
|
37 | 36 | """
|
38 | 37 | An ordering for an array space
|
@@ -63,200 +62,67 @@ def __getitem__(self, key):
|
63 | 62 |
|
64 | 63 |
|
65 | 64 | class DictToArrayBijection:
|
66 |
| - """ |
67 |
| - A mapping between a dict space and an array space |
68 |
| - """ |
69 |
| - |
70 |
| - def __init__(self, ordering, dpoint): |
71 |
| - self.ordering = ordering |
72 |
| - self.dpt = dpoint |
| 65 | + """Map between a `dict`s of variables to an array space. |
73 | 66 |
|
74 |
| - # determine smallest float dtype that will fit all data |
75 |
| - if all([x.dtyp == "float16" for x in ordering.vmap]): |
76 |
| - self.array_dtype = "float16" |
77 |
| - elif all([x.dtyp == "float32" for x in ordering.vmap]): |
78 |
| - self.array_dtype = "float32" |
79 |
| - else: |
80 |
| - self.array_dtype = "float64" |
| 67 | + Said array space consists of all the vars raveled and then concatenated. |
81 | 68 |
|
82 |
| - def map(self, dpt): |
83 |
| - """ |
84 |
| - Maps value from dict space to array space |
| 69 | + """ |
85 | 70 |
|
86 |
| - Parameters |
87 |
| - ---------- |
88 |
| - dpt: dict |
89 |
| - """ |
90 |
| - apt = np.empty(self.ordering.size, dtype=self.array_dtype) |
91 |
| - for var, slc, _, _ in self.ordering.vmap: |
92 |
| - apt[slc] = dpt[var].ravel() |
93 |
| - return apt |
| 71 | + @staticmethod |
| 72 | + def map(var_dict: Dict[str, np.ndarray]) -> RaveledVars: |
| 73 | + """Map a dictionary of names and variables to a concatenated 1D array space.""" |
| 74 | + vars_info = tuple((v, k, v.shape, v.dtype) for k, v in var_dict.items()) |
| 75 | + res = np.concatenate([v[0].ravel() for v in vars_info]) |
| 76 | + return RaveledVars(res, tuple(v[1:] for v in vars_info)) |
94 | 77 |
|
95 |
| - def rmap(self, apt): |
96 |
| - """ |
97 |
| - Maps value from array space to dict space |
| 78 | + @staticmethod |
| 79 | + def rmap( |
| 80 | + array: RaveledVars, as_list: Optional[bool] = False |
| 81 | + ) -> Union[Dict[str, np.ndarray], List[np.ndarray]]: |
| 82 | + """Map 1D concatenated array to a dictionary of variables in their original spaces. |
98 | 83 |
|
99 | 84 | Parameters
|
100 |
| - ---------- |
101 |
| - apt: array |
| 85 | + ========== |
| 86 | + array |
| 87 | + The array to map. |
| 88 | + as_list |
| 89 | + When ``True``, return a list of the original variables instead of a |
| 90 | + ``dict`` keyed each variable's name. |
102 | 91 | """
|
103 |
| - dpt = self.dpt.copy() |
| 92 | + if as_list: |
| 93 | + res = [] |
| 94 | + else: |
| 95 | + res = {} |
| 96 | + |
| 97 | + if not isinstance(array, RaveledVars): |
| 98 | + raise TypeError("`apt` must be a `RaveledVars` type") |
104 | 99 |
|
105 |
| - for var, slc, shp, dtyp in self.ordering.vmap: |
106 |
| - dpt[var] = np.atleast_1d(apt)[slc].reshape(shp).astype(dtyp) |
| 100 | + last_idx = 0 |
| 101 | + for name, shape, dtype in array.point_map_info: |
| 102 | + arr_len = np.prod(shape, dtype=int) |
| 103 | + var = array.data[last_idx : last_idx + arr_len].reshape(shape).astype(dtype) |
| 104 | + if as_list: |
| 105 | + res.append(var) |
| 106 | + else: |
| 107 | + res[name] = var |
| 108 | + last_idx += arr_len |
107 | 109 |
|
108 |
| - return dpt |
| 110 | + return res |
109 | 111 |
|
110 |
| - def mapf(self, f): |
| 112 | + @classmethod |
| 113 | + def mapf(cls, f): |
111 | 114 | """
|
112 | 115 | function f: DictSpace -> T to ArraySpace -> T
|
113 | 116 |
|
114 | 117 | Parameters
|
115 | 118 | ----------
|
116 |
| -
|
117 | 119 | f: dict -> T
|
118 | 120 |
|
119 | 121 | Returns
|
120 | 122 | -------
|
121 | 123 | f: array -> T
|
122 | 124 | """
|
123 |
| - return Compose(f, self.rmap) |
124 |
| - |
125 |
| - |
126 |
| -class ListArrayOrdering: |
127 |
| - """ |
128 |
| - An ordering for a list to an array space. Takes also non aesara.tensors. |
129 |
| - Modified from pymc3 blocking. |
130 |
| -
|
131 |
| - Parameters |
132 |
| - ---------- |
133 |
| - list_arrays: list |
134 |
| - :class:`numpy.ndarray` or :class:`aesara.tensor.Tensor` |
135 |
| - intype: str |
136 |
| - defining the input type 'tensor' or 'numpy' |
137 |
| - """ |
138 |
| - |
139 |
| - def __init__(self, list_arrays, intype="numpy"): |
140 |
| - if intype not in {"tensor", "numpy"}: |
141 |
| - raise ValueError("intype not in {'tensor', 'numpy'}") |
142 |
| - self.vmap = [] |
143 |
| - self.intype = intype |
144 |
| - self.size = 0 |
145 |
| - for array in list_arrays: |
146 |
| - if self.intype == "tensor": |
147 |
| - name = array.name |
148 |
| - array = array.tag.test_value |
149 |
| - else: |
150 |
| - name = "numpy" |
151 |
| - |
152 |
| - slc = slice(self.size, self.size + array.size) |
153 |
| - self.vmap.append(DataMap(len(self.vmap), slc, array.shape, array.dtype, name)) |
154 |
| - self.size += array.size |
155 |
| - |
156 |
| - |
157 |
| -class ListToArrayBijection: |
158 |
| - """ |
159 |
| - A mapping between a List of arrays and an array space |
160 |
| -
|
161 |
| - Parameters |
162 |
| - ---------- |
163 |
| - ordering: :class:`ListArrayOrdering` |
164 |
| - list_arrays: list |
165 |
| - of :class:`numpy.ndarray` |
166 |
| - """ |
167 |
| - |
168 |
| - def __init__(self, ordering, list_arrays): |
169 |
| - self.ordering = ordering |
170 |
| - self.list_arrays = list_arrays |
171 |
| - |
172 |
| - def fmap(self, list_arrays): |
173 |
| - """ |
174 |
| - Maps values from List space to array space |
175 |
| -
|
176 |
| - Parameters |
177 |
| - ---------- |
178 |
| - list_arrays: list |
179 |
| - of :class:`numpy.ndarray` |
180 |
| -
|
181 |
| - Returns |
182 |
| - ------- |
183 |
| - array: :class:`numpy.ndarray` |
184 |
| - single array comprising all the input arrays |
185 |
| - """ |
186 |
| - |
187 |
| - array = np.empty(self.ordering.size) |
188 |
| - for list_ind, slc, _, _, _ in self.ordering.vmap: |
189 |
| - array[slc] = list_arrays[list_ind].ravel() |
190 |
| - return array |
191 |
| - |
192 |
| - def dmap(self, dpt): |
193 |
| - """ |
194 |
| - Maps values from dict space to List space |
195 |
| -
|
196 |
| - Parameters |
197 |
| - ---------- |
198 |
| - list_arrays: list |
199 |
| - of :class:`numpy.ndarray` |
200 |
| -
|
201 |
| - Returns |
202 |
| - ------- |
203 |
| - point |
204 |
| - """ |
205 |
| - a_list = copy.copy(self.list_arrays) |
206 |
| - |
207 |
| - for list_ind, _, _, _, var in self.ordering.vmap: |
208 |
| - a_list[list_ind] = dpt[var].ravel() |
209 |
| - |
210 |
| - return a_list |
211 |
| - |
212 |
| - def rmap(self, array): |
213 |
| - """ |
214 |
| - Maps value from array space to List space |
215 |
| - Inverse operation of fmap. |
216 |
| -
|
217 |
| - Parameters |
218 |
| - ---------- |
219 |
| - array: :class:`numpy.ndarray` |
220 |
| -
|
221 |
| - Returns |
222 |
| - ------- |
223 |
| - a_list: list |
224 |
| - of :class:`numpy.ndarray` |
225 |
| - """ |
226 |
| - |
227 |
| - a_list = copy.copy(self.list_arrays) |
228 |
| - |
229 |
| - for list_ind, slc, shp, dtype, _ in self.ordering.vmap: |
230 |
| - a_list[list_ind] = np.atleast_1d(array)[slc].reshape(shp).astype(dtype) |
231 |
| - |
232 |
| - return a_list |
233 |
| - |
234 |
| - |
235 |
| -class DictToVarBijection: |
236 |
| - """ |
237 |
| - A mapping between a dict space and the array space for one element within the dict space |
238 |
| - """ |
239 |
| - |
240 |
| - def __init__(self, var, idx, dpoint): |
241 |
| - self.var = get_var_name(var) |
242 |
| - self.idx = idx |
243 |
| - self.dpt = dpoint |
244 |
| - |
245 |
| - def map(self, dpt): |
246 |
| - return dpt[self.var][self.idx] |
247 |
| - |
248 |
| - def rmap(self, apt): |
249 |
| - dpt = self.dpt.copy() |
250 |
| - |
251 |
| - dvar = dpt[self.var].copy() |
252 |
| - dvar[self.idx] = apt |
253 |
| - |
254 |
| - dpt[self.var] = dvar |
255 |
| - |
256 |
| - return dpt |
257 |
| - |
258 |
| - def mapf(self, f): |
259 |
| - return Compose(f, self.rmap) |
| 125 | + return Compose(f, cls.rmap) |
260 | 126 |
|
261 | 127 |
|
262 | 128 | class Compose:
|
|
0 commit comments