Skip to content

Commit 3862b4d

Browse files
authored
Add atop(concatenate=False) keyword argument (dask#1609)
This allows atop to pre-concatenate arrays before sending to the user defined function.
1 parent 34660a2 commit 3862b4d

File tree

3 files changed

+102
-12
lines changed

3 files changed

+102
-12
lines changed

dask/array/core.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -260,17 +260,17 @@ def top(func, output, out_indices, *arrind_pairs, **kwargs):
260260
261261
Applies a function, ``func``, across blocks from many different input
262262
dasks. We arrange the pattern with which those blocks interact with sets
263-
of matching indices. E.g.
263+
of matching indices. E.g.::
264264
265-
``top(func, 'z', 'i', 'x', 'i', 'y', 'i')``
265+
top(func, 'z', 'i', 'x', 'i', 'y', 'i')
266266
267267
yield an embarrassingly parallel communication pattern and is read as
268268
269269
$$ z_i = func(x_i, y_i) $$
270270
271-
More complex patterns may emerge, including multiple indices
271+
More complex patterns may emerge, including multiple indices::
272272
273-
``top(func, 'z', 'ij', 'x', 'ij', 'y', 'ji')``
273+
top(func, 'z', 'ij', 'x', 'ij', 'y', 'ji')
274274
275275
$$ z_{ij} = func(x_{ij}, y_{ji}) $$
276276
@@ -324,6 +324,15 @@ def top(func, output, out_indices, *arrind_pairs, **kwargs):
324324
('z', 1, 1): (dotmany, [('x', 1, 0), ('x', 1, 1)],
325325
[('y', 0, 1), ('y', 1, 1)])}
326326
327+
Pass ``concatenate=True`` to concatenate arrays ahead of time
328+
329+
>>> top(f, 'z', 'i', 'x', 'ij', 'y', 'ij', concatenate=True,
330+
... numblocks={'x': (2, 2), 'y': (2, 2,)}) # doctest: +SKIP
331+
{('z', 0): (f, (concatenate_axes, [('x', 0, 0), ('x', 0, 1)], [1]),
332+
(concatenate_axes, [('y', 0, 0), ('y', 0, 1)], [1])),
333+
('z', 1): (f, (concatenate_axes, [('x', 1, 0), ('x', 1, 1)], [1]),
334+
(concatenate_axes, [('y', 1, 0), ('y', 1, 1)], [1]))}
335+
327336
Supports Broadcasting rules
328337
329338
>>> top(add, 'z', 'ij', 'x', 'ij', 'y', 'ij', numblocks={'x': (1, 2),
@@ -336,11 +345,16 @@ def top(func, output, out_indices, *arrind_pairs, **kwargs):
336345
Support keyword arguments with apply
337346
338347
>>> def f(a, b=0): return a + b
339-
>>> top(f, 'z', 'i', 'x', 'i', numblocks={'x': (2,), b=10}) # doctest: +SKIP
348+
>>> top(f, 'z', 'i', 'x', 'i', numblocks={'x': (2,)}, b=10) # doctest: +SKIP
340349
{('z', 0): (apply, f, [('x', 0)], {'b': 10}),
341350
('z', 1): (apply, f, [('x', 1)], {'b': 10})}
351+
352+
See Also
353+
--------
354+
atop
342355
"""
343356
numblocks = kwargs.pop('numblocks')
357+
concatenate = kwargs.pop('concatenate', None)
344358
argpairs = list(partition(2, arrind_pairs))
345359

346360
assert set(numblocks) == set(pluck(0, argpairs))
@@ -366,6 +380,9 @@ def top(func, output, out_indices, *arrind_pairs, **kwargs):
366380
for arg, ind in argpairs:
367381
tups = lol_tuples((arg,), ind, kd, dummies)
368382
tups2 = zero_broadcast_dimensions(tups, numblocks[arg])
383+
if concatenate and isinstance(tups2, list):
384+
axes = [n for n, i in enumerate(ind) if i in dummies]
385+
tups2 = (concatenate_axes, tups2, axes)
369386
args.append(tups2)
370387
valtups.append(tuple(args))
371388

@@ -1731,6 +1748,8 @@ def atop(func, out_ind, *args, **kwargs):
17311748
Function to apply to individual tuples of blocks
17321749
out_ind: iterable
17331750
Block pattern of the output, something like 'ijk' or (1, 2, 3)
1751+
concatenate: bool
1752+
If True, concatenate arrays along dummy indices; otherwise pass lists of blocks
17341753
*args: sequence of Array, index pairs
17351754
Sequence like (x, 'ij', y, 'jk', z, 'i')
17361755
**kwargs: dict
@@ -1767,8 +1786,9 @@ def atop(func, out_ind, *args, **kwargs):
17671786
Any index, like ``i`` missing from the output index is interpreted as a
17681787
contraction (note that this differs from Einstein convention; repeated
17691788
indices do not imply contraction.) In the case of a contraction the passed
1770-
function should expect an iterator of blocks on any array that holds that
1771-
index.
1789+
function should expect an iterable of blocks on any array that holds that
1790+
index. To receive arrays concatenated along contracted dimensions instead
1791+
pass ``concatenate=True``.
17721792
17731793
Inner product multiplying x by y, two 1-d vectors
17741794
@@ -2101,10 +2121,10 @@ def tensordot(lhs, rhs, axes=2):
21012121
out_index.remove(right_index[r])
21022122
right_index[r] = left_index[l]
21032123

2104-
func = partial(np.tensordot, axes=(left_axes, right_axes))
2105-
intermediate = atop(func, out_index,
2124+
intermediate = atop(np.tensordot, out_index,
21062125
lhs, left_index,
2107-
rhs, right_index, dtype=dt)
2126+
rhs, right_index, dtype=dt,
2127+
axes=(left_axes, right_axes))
21082128

21092129
int_index = list(out_index)
21102130
for l in left_axes:
@@ -3160,6 +3180,19 @@ def dtype(x):
31603180
return result
31613181

31623182

3183+
def concatenate_axes(arrays, axes):
    """ Recursively call np.concatenate along axes

    ``arrays`` is a (possibly nested) list of arrays and ``axes`` holds one
    axis per level of nesting: ``axes[0]`` joins the outermost list,
    ``axes[1]`` the lists one level down, and so on.  Inner levels are joined
    first, then the results are joined along ``axes[0]``.

    Raises ``ValueError`` when ``len(axes)`` differs from the nesting depth
    reported by ``ndimlist(arrays)``.

    TODO: This performs many copies.  We should be able to do this in one
          pass.
    TODO: Merge logic on concatenate3 with this
    """
    if ndimlist(arrays) != len(axes):
        raise ValueError("Length of axes should equal depth of nested arrays")

    if len(axes) > 1:
        # Collapse every inner list first so that only a flat list of arrays
        # is left for the final concatenate call.
        remaining = axes[1:]
        arrays = [concatenate_axes(nested, remaining) for nested in arrays]

    return np.concatenate(arrays, axis=axes[0])
3194+
3195+
31633196
def to_hdf5(filename, *args, **kwargs):
31643197
""" Store arrays in HDF5 file
31653198

dask/array/reductions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def reduction(x, chunk, aggregate, axis=None, keepdims=None, dtype=None,
3636

3737
# Map chunk across all blocks
3838
inds = tuple(range(x.ndim))
39-
tmp = atop(partial(chunk, axis=axis, keepdims=True), inds, x, inds)
39+
tmp = atop(chunk, inds, x, inds, axis=axis, keepdims=True)
4040
tmp._chunks = tuple((1, ) * len(c) if i in axis else c for (i, c)
4141
in enumerate(tmp.chunks))
4242

dask/array/tests/test_array_core.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
broadcast_to, reshape, fromfunction,
3030
blockdims_from_blockshape, store, optimize,
3131
from_func, normalize_chunks, broadcast_chunks,
32-
atop, from_delayed)
32+
atop, from_delayed, concatenate_axes)
3333
from dask.array.utils import assert_eq
3434

3535
# temporary until numpy functions migrated
@@ -2119,3 +2119,60 @@ def test_from_array_raises_on_bad_chunks():
21192119

21202120
with pytest.raises(ValueError):
21212121
da.from_array(x, chunks=((5, 5, 5),))
2122+
2123+
2124+
def test_concatenate_axes():
    """ concatenate_axes joins nested lists of blocks along the given axes """
    x = np.ones((2, 2, 2))

    # Each case: (nested blocks, axes, shape of the expected all-ones result)
    cases = [
        ([x, x], [0], (4, 2, 2)),
        ([x, x, x], [0], (6, 2, 2)),
        ([x, x], [1], (2, 4, 2)),
        ([[x, x], [x, x]], [0, 1], (4, 4, 2)),
        ([[x, x], [x, x]], [0, 2], (4, 2, 4)),
        ([[x, x, x], [x, x, x]], [1, 2], (2, 4, 6)),
    ]
    for blocks, axes, shape in cases:
        assert_eq(concatenate_axes(blocks, axes=axes), np.ones(shape))

    # not all nested lists accounted for
    with pytest.raises(ValueError):
        concatenate_axes([[x, x], [x, x]], axes=[0])

    # too many axes
    with pytest.raises(ValueError):
        concatenate_axes([x, x], axes=[0, 1, 2, 3])
2144+
2145+
2146+
def test_atop_concatenate():
    """ With concatenate=True the user function receives arrays, not lists """
    x = da.ones((4, 4, 4), chunks=(2, 2, 2))
    y = da.ones((4, 4), chunks=(2, 2))

    def f(a, b):
        # Indices j and k are missing from the output, so both inputs must
        # arrive already concatenated along them.
        for block in (a, b):
            assert isinstance(block, np.ndarray)

        assert a.shape == (2, 4, 4)
        assert b.shape == (4, 4)

        return (a + b).sum(axis=(1, 2))

    z = atop(f, 'i', x, 'ijk', y, 'jk', concatenate=True)
    assert_eq(z, np.ones(4) * 32)

    # No contracted index here: concatenate=True must not disturb the
    # plain blockwise case.
    z = atop(add, 'ij', y, 'ij', y, 'ij', concatenate=True)
    assert_eq(z, np.ones((4, 4)) * 2)

    def f(a, b, c):
        # Only j survives in the output; i and k are contracted.
        for block in (a, b, c):
            assert isinstance(block, np.ndarray)

        assert a.shape == (4, 2, 4)
        assert b.shape == (4, 4)
        assert c.shape == (4, 2)

        return np.ones(5)

    z = atop(f, 'j', x, 'ijk', y, 'ki', y, 'ij', concatenate=True)
    assert_eq(z, np.ones(10))

0 commit comments

Comments
 (0)