Skip to content

Commit bf084d2

Browse files
Remove shape dependencies from DictToArrayBijection
This commit changes `DictToArrayBijection` so that it returns a `RaveledVars` datatype that contains the original raveled and concatenated vector along with the information needed to revert it back to dictionay/variables form. Simply put, the variables-to-single-vector mapping steps have been pushed away from the model object and its symbolic terms and closer to the (sampling) processes that produce and work with `ndarray` values for said terms. In doing so, we can operate under fewer unnecessarily strong assumptions (e.g. that the shapes of each term are static and equal to the initial test points), and let the sampling processes that require vector-only steps deal with any changes in the mappings.
1 parent d8ac11e commit bf084d2

22 files changed

+270
-496
lines changed

pymc3/aesaraf.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from aesara.tensor.elemwise import Elemwise
2525
from aesara.tensor.var import TensorVariable
2626

27-
from pymc3.blocking import ArrayOrdering
2827
from pymc3.data import GeneratorAdapter
2928
from pymc3.vartypes import continuous_types, int_types, typefilter
3029

@@ -267,14 +266,16 @@ def join_nonshared_inputs(xs, vars, shared, make_shared=False):
267266
else:
268267
inarray = aesara.shared(joined.tag.test_value, "inarray")
269268

270-
ordering = ArrayOrdering(vars)
271269
inarray.tag.test_value = joined.tag.test_value
272270

273-
get_var = {var.name: var for var in vars}
274-
replace = {
275-
get_var[var]: reshape_t(inarray[slc], shp).astype(dtyp)
276-
for var, slc, shp, dtyp in ordering.vmap
277-
}
271+
replace = {}
272+
last_idx = 0
273+
for var in vars:
274+
arr_len = aet.prod(var.shape)
275+
replace[var] = reshape_t(inarray[last_idx : last_idx + arr_len], var.shape).astype(
276+
var.dtype
277+
)
278+
last_idx += arr_len
278279

279280
replace.update(shared)
280281

pymc3/blocking.py

Lines changed: 46 additions & 180 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,20 @@
1818
Classes for working with subsets of parameters.
1919
"""
2020
import collections
21-
import copy
2221

23-
import numpy as np
22+
from typing import Dict, List, Optional, Union
2423

25-
from pymc3.util import get_var_name
24+
import numpy as np
2625

27-
__all__ = ["ArrayOrdering", "DictToArrayBijection", "DictToVarBijection"]
26+
__all__ = ["ArrayOrdering", "DictToArrayBijection"]
2827

28+
# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
29+
# each of the raveled variables.
30+
RaveledVars = collections.namedtuple("RaveledVars", "data, point_map_info")
2931
VarMap = collections.namedtuple("VarMap", "var, slc, shp, dtyp")
3032
DataMap = collections.namedtuple("DataMap", "list_ind, slc, shp, dtype, name")
3133

3234

33-
# TODO Classes and methods need to be fully documented.
34-
35-
3635
class ArrayOrdering:
3736
"""
3837
An ordering for an array space
@@ -63,200 +62,67 @@ def __getitem__(self, key):
6362

6463

6564
class DictToArrayBijection:
66-
"""
67-
A mapping between a dict space and an array space
68-
"""
69-
70-
def __init__(self, ordering, dpoint):
71-
self.ordering = ordering
72-
self.dpt = dpoint
65+
"""Map between a `dict`s of variables to an array space.
7366
74-
# determine smallest float dtype that will fit all data
75-
if all([x.dtyp == "float16" for x in ordering.vmap]):
76-
self.array_dtype = "float16"
77-
elif all([x.dtyp == "float32" for x in ordering.vmap]):
78-
self.array_dtype = "float32"
79-
else:
80-
self.array_dtype = "float64"
67+
Said array space consists of all the vars raveled and then concatenated.
8168
82-
def map(self, dpt):
83-
"""
84-
Maps value from dict space to array space
69+
"""
8570

86-
Parameters
87-
----------
88-
dpt: dict
89-
"""
90-
apt = np.empty(self.ordering.size, dtype=self.array_dtype)
91-
for var, slc, _, _ in self.ordering.vmap:
92-
apt[slc] = dpt[var].ravel()
93-
return apt
71+
@staticmethod
72+
def map(var_dict: Dict[str, np.ndarray]) -> RaveledVars:
73+
"""Map a dictionary of names and variables to a concatenated 1D array space."""
74+
vars_info = tuple((v, k, v.shape, v.dtype) for k, v in var_dict.items())
75+
res = np.concatenate([v[0].ravel() for v in vars_info])
76+
return RaveledVars(res, tuple(v[1:] for v in vars_info))
9477

95-
def rmap(self, apt):
96-
"""
97-
Maps value from array space to dict space
78+
@staticmethod
79+
def rmap(
80+
array: RaveledVars, as_list: Optional[bool] = False
81+
) -> Union[Dict[str, np.ndarray], List[np.ndarray]]:
82+
"""Map 1D concatenated array to a dictionary of variables in their original spaces.
9883
9984
Parameters
100-
----------
101-
apt: array
85+
==========
86+
array
87+
The array to map.
88+
as_list
89+
When ``True``, return a list of the original variables instead of a
90+
``dict`` keyed each variable's name.
10291
"""
103-
dpt = self.dpt.copy()
92+
if as_list:
93+
res = []
94+
else:
95+
res = {}
96+
97+
if not isinstance(array, RaveledVars):
98+
raise TypeError("`apt` must be a `RaveledVars` type")
10499

105-
for var, slc, shp, dtyp in self.ordering.vmap:
106-
dpt[var] = np.atleast_1d(apt)[slc].reshape(shp).astype(dtyp)
100+
last_idx = 0
101+
for name, shape, dtype in array.point_map_info:
102+
arr_len = np.prod(shape, dtype=int)
103+
var = array.data[last_idx : last_idx + arr_len].reshape(shape).astype(dtype)
104+
if as_list:
105+
res.append(var)
106+
else:
107+
res[name] = var
108+
last_idx += arr_len
107109

108-
return dpt
110+
return res
109111

110-
def mapf(self, f):
112+
@classmethod
113+
def mapf(cls, f):
111114
"""
112115
function f: DictSpace -> T to ArraySpace -> T
113116
114117
Parameters
115118
----------
116-
117119
f: dict -> T
118120
119121
Returns
120122
-------
121123
f: array -> T
122124
"""
123-
return Compose(f, self.rmap)
124-
125-
126-
class ListArrayOrdering:
127-
"""
128-
An ordering for a list to an array space. Takes also non aesara.tensors.
129-
Modified from pymc3 blocking.
130-
131-
Parameters
132-
----------
133-
list_arrays: list
134-
:class:`numpy.ndarray` or :class:`aesara.tensor.Tensor`
135-
intype: str
136-
defining the input type 'tensor' or 'numpy'
137-
"""
138-
139-
def __init__(self, list_arrays, intype="numpy"):
140-
if intype not in {"tensor", "numpy"}:
141-
raise ValueError("intype not in {'tensor', 'numpy'}")
142-
self.vmap = []
143-
self.intype = intype
144-
self.size = 0
145-
for array in list_arrays:
146-
if self.intype == "tensor":
147-
name = array.name
148-
array = array.tag.test_value
149-
else:
150-
name = "numpy"
151-
152-
slc = slice(self.size, self.size + array.size)
153-
self.vmap.append(DataMap(len(self.vmap), slc, array.shape, array.dtype, name))
154-
self.size += array.size
155-
156-
157-
class ListToArrayBijection:
158-
"""
159-
A mapping between a List of arrays and an array space
160-
161-
Parameters
162-
----------
163-
ordering: :class:`ListArrayOrdering`
164-
list_arrays: list
165-
of :class:`numpy.ndarray`
166-
"""
167-
168-
def __init__(self, ordering, list_arrays):
169-
self.ordering = ordering
170-
self.list_arrays = list_arrays
171-
172-
def fmap(self, list_arrays):
173-
"""
174-
Maps values from List space to array space
175-
176-
Parameters
177-
----------
178-
list_arrays: list
179-
of :class:`numpy.ndarray`
180-
181-
Returns
182-
-------
183-
array: :class:`numpy.ndarray`
184-
single array comprising all the input arrays
185-
"""
186-
187-
array = np.empty(self.ordering.size)
188-
for list_ind, slc, _, _, _ in self.ordering.vmap:
189-
array[slc] = list_arrays[list_ind].ravel()
190-
return array
191-
192-
def dmap(self, dpt):
193-
"""
194-
Maps values from dict space to List space
195-
196-
Parameters
197-
----------
198-
list_arrays: list
199-
of :class:`numpy.ndarray`
200-
201-
Returns
202-
-------
203-
point
204-
"""
205-
a_list = copy.copy(self.list_arrays)
206-
207-
for list_ind, _, _, _, var in self.ordering.vmap:
208-
a_list[list_ind] = dpt[var].ravel()
209-
210-
return a_list
211-
212-
def rmap(self, array):
213-
"""
214-
Maps value from array space to List space
215-
Inverse operation of fmap.
216-
217-
Parameters
218-
----------
219-
array: :class:`numpy.ndarray`
220-
221-
Returns
222-
-------
223-
a_list: list
224-
of :class:`numpy.ndarray`
225-
"""
226-
227-
a_list = copy.copy(self.list_arrays)
228-
229-
for list_ind, slc, shp, dtype, _ in self.ordering.vmap:
230-
a_list[list_ind] = np.atleast_1d(array)[slc].reshape(shp).astype(dtype)
231-
232-
return a_list
233-
234-
235-
class DictToVarBijection:
236-
"""
237-
A mapping between a dict space and the array space for one element within the dict space
238-
"""
239-
240-
def __init__(self, var, idx, dpoint):
241-
self.var = get_var_name(var)
242-
self.idx = idx
243-
self.dpt = dpoint
244-
245-
def map(self, dpt):
246-
return dpt[self.var][self.idx]
247-
248-
def rmap(self, apt):
249-
dpt = self.dpt.copy()
250-
251-
dvar = dpt[self.var].copy()
252-
dvar[self.idx] = apt
253-
254-
dpt[self.var] = dvar
255-
256-
return dpt
257-
258-
def mapf(self, f):
259-
return Compose(f, self.rmap)
125+
return Compose(f, cls.rmap)
260126

261127

262128
class Compose:

pymc3/distributions/discrete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1348,7 +1348,7 @@ def dist(cls, p, **kwargs):
13481348

13491349

13501350
@_logp.register(CategoricalRV)
1351-
def categorical_logp(op, value, p_, upper):
1351+
def categorical_logp(op, value, p, upper):
13521352
r"""
13531353
Calculate log-probability of Categorical distribution at specified value.
13541354

0 commit comments

Comments
 (0)