Commit 812d985

Cache size in DictToArrayBijection
1 parent acf5175 commit 812d985

File tree

4 files changed: +23 -31 lines changed

pymc/blocking.py

Lines changed: 10 additions & 15 deletions

@@ -39,11 +39,11 @@
 StatShape: TypeAlias = Sequence[int | None] | None


-# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
+# `point_map_info` is a tuple of tuples containing `(name, shape, size, dtype)` for
 # each of the raveled variables.
 class RaveledVars(NamedTuple):
     data: np.ndarray
-    point_map_info: tuple[tuple[str, tuple[int, ...], np.dtype], ...]
+    point_map_info: tuple[tuple[str, tuple[int, ...], int, np.dtype], ...]


 class Compose(Generic[T]):

@@ -67,10 +67,9 @@ class DictToArrayBijection:
     @staticmethod
     def map(var_dict: PointType) -> RaveledVars:
         """Map a dictionary of names and variables to a concatenated 1D array space."""
-        vars_info = tuple((v, k, v.shape, v.dtype) for k, v in var_dict.items())
-        raveled_vars = [v[0].ravel() for v in vars_info]
-        if raveled_vars:
-            result = np.concatenate(raveled_vars)
+        vars_info = tuple((v, k, v.shape, v.size, v.dtype) for k, v in var_dict.items())
+        if vars_info:
+            result = np.concatenate(tuple(v[0].ravel() for v in vars_info))
         else:
             result = np.array([])
         return RaveledVars(result, tuple(v[1:] for v in vars_info))

@@ -91,19 +90,15 @@ def rmap(

         """
         if start_point:
-            result = dict(start_point)
+            result = start_point.copy()
         else:
             result = {}

-        if not isinstance(array, RaveledVars):
-            raise TypeError("`array` must be a `RaveledVars` type")
-
         last_idx = 0
-        for name, shape, dtype in array.point_map_info:
-            arr_len = np.prod(shape, dtype=int)
-            var = array.data[last_idx : last_idx + arr_len].reshape(shape).astype(dtype)
-            result[name] = var
-            last_idx += arr_len
+        for name, shape, size, dtype in array.point_map_info:
+            end = last_idx + size
+            result[name] = array.data[last_idx:end].reshape(shape).astype(dtype)
+            last_idx = end

         return result
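Taken together: map() now records each variable's element count once, and rmap() consumes it directly instead of recomputing np.prod(shape) per variable (the isinstance guard is also dropped, so callers are trusted to pass a RaveledVars). A minimal round-trip sketch, not part of the commit; the point values and names below are made up:

    import numpy as np

    from pymc.blocking import DictToArrayBijection

    # Hypothetical point: two variables with different shapes.
    point = {"mu": np.zeros((2, 3)), "sigma": np.ones(4)}

    # map() ravels and concatenates; point_map_info now carries
    # (name, shape, size, dtype), so unpacking needs no np.prod(shape).
    raveled = DictToArrayBijection.map(point)
    assert raveled.point_map_info == (
        ("mu", (2, 3), 6, np.dtype("float64")),
        ("sigma", (4,), 4, np.dtype("float64")),
    )

    # rmap() inverts the mapping by slicing each cached-size segment.
    roundtrip = DictToArrayBijection.rmap(raveled)
    assert all(np.array_equal(point[k], roundtrip[k]) for k in point)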

pymc/sampling/parallel.py

Lines changed: 4 additions & 7 deletions

@@ -228,15 +228,12 @@ def __init__(
         self._shared_point = {}
         self._point = {}

-        for name, shape, dtype in DictToArrayBijection.map(start).point_map_info:
-            size = 1
-            for dim in shape:
-                size *= int(dim)
-            size *= dtype.itemsize
-            if size != ctypes.c_size_t(size).value:
+        for name, shape, size, dtype in DictToArrayBijection.map(start).point_map_info:
+            byte_size = size * dtype.itemsize
+            if byte_size != ctypes.c_size_t(byte_size).value:
                 raise ValueError(f"Variable {name} is too large")

-            array = mp_ctx.RawArray("c", size)
+            array = mp_ctx.RawArray("c", byte_size)
             self._shared_point[name] = (array, shape, dtype)
             array_np = np.frombuffer(array, dtype).reshape(shape)
             array_np[...] = start[name]
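With size cached, the per-variable byte count is simply size * dtype.itemsize, replacing the hand-rolled loop over shape. A standalone sketch of the same shared-memory pattern; the shape, dtype, and fill value are illustrative:

    import ctypes
    import multiprocessing as mp

    import numpy as np

    shape, dtype = (2, 3), np.dtype("float64")
    size = int(np.prod(shape))           # element count, as now cached in point_map_info
    byte_size = size * dtype.itemsize    # bytes needed for the raw shared buffer

    # Guard against overflowing the C size_t that RawArray uses internally.
    if byte_size != ctypes.c_size_t(byte_size).value:
        raise ValueError("Variable is too large")

    mp_ctx = mp.get_context("spawn")
    array = mp_ctx.RawArray("c", byte_size)                  # untyped shared bytes
    array_np = np.frombuffer(array, dtype).reshape(shape)    # zero-copy NumPy view
    array_np[...] = 1.0  # writes are visible to child processes sharing `array`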

pymc/step_methods/hmc/quadpotential.py

Lines changed: 8 additions & 8 deletions

@@ -363,23 +363,23 @@ def raise_ok(self, map_info):
         if np.any(self._stds == 0):
             errmsg = ["Mass matrix contains zeros on the diagonal. "]
             last_idx = 0
-            for name, shape, dtype in map_info:
-                arr_len = np.prod(shape, dtype=int)
-                index = np.where(self._stds[last_idx : last_idx + arr_len] == 0)[0]
+            for name, shape, size, dtype in map_info:
+                end = last_idx + size
+                index = np.where(self._stds[last_idx:end] == 0)[0]
                 errmsg.append(f"The derivative of RV `{name}`.ravel()[{index}] is zero.")
-                last_idx += arr_len
+                last_idx = end

             raise ValueError("\n".join(errmsg))

         if np.any(~np.isfinite(self._stds)):
             errmsg = ["Mass matrix contains non-finite values on the diagonal. "]

             last_idx = 0
-            for name, shape, dtype in map_info:
-                arr_len = np.prod(shape, dtype=int)
-                index = np.where(~np.isfinite(self._stds[last_idx : last_idx + arr_len]))[0]
+            for name, shape, size, dtype in map_info:
+                end = last_idx + size
+                index = np.where(~np.isfinite(self._stds[last_idx:end]))[0]
                 errmsg.append(f"The derivative of RV `{name}`.ravel()[{index}] is non-finite.")
-                last_idx += arr_len
+                last_idx = end
             raise ValueError("\n".join(errmsg))
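raise_ok() walks the flattened mass-matrix diagonal the same way: each variable owns a contiguous slice of length size. A toy version of that slice walk, standalone, with made-up names and values:

    import numpy as np

    # Two variables flattened into one diagonal: "a" has 3 entries, "b" has 2.
    point_map_info = (
        ("a", (3,), 3, np.dtype("float64")),
        ("b", (2,), 2, np.dtype("float64")),
    )
    stds = np.array([1.0, 0.0, 2.0, 0.0, 3.0])

    last_idx = 0
    for name, shape, size, dtype in point_map_info:
        end = last_idx + size
        index = np.where(stds[last_idx:end] == 0)[0]
        if index.size:
            print(f"The derivative of RV `{name}`.ravel()[{index}] is zero.")
        last_idx = end
    # -> flags `a`.ravel()[[1]] and `b`.ravel()[[0]]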

pymc/tuning/starting.py

Lines changed: 1 addition & 1 deletion

@@ -143,7 +143,7 @@ def find_MAP(
     compiled_logp_func = DictToArrayBijection.mapf(model.compile_logp(jacobian=False), start)
     logp_func = lambda x: compiled_logp_func(RaveledVars(x, x0.point_map_info))  # noqa: E731

-    rvs = [model.values_to_rvs[vars_dict[name]] for name, _, _ in x0.point_map_info]
+    rvs = [model.values_to_rvs[vars_dict[name]] for name, _, _, _ in x0.point_map_info]
     try:
         # This might be needed for calls to `dlogp_func`
         # start_map_info = tuple((v.name, v.shape, v.dtype) for v in vars)
