@@ -260,17 +260,17 @@ def top(func, output, out_indices, *arrind_pairs, **kwargs):
260
260
261
261
Applies a function, ``func``, across blocks from many different input
262
262
dasks. We arrange the pattern with which those blocks interact with sets
263
- of matching indices. E.g.
263
+ of matching indices. E.g.::
264
264
265
- `` top(func, 'z', 'i', 'x', 'i', 'y', 'i')``
265
+ top(func, 'z', 'i', 'x', 'i', 'y', 'i')
266
266
267
267
yield an embarrassingly parallel communication pattern and is read as
268
268
269
269
$$ z_i = func(x_i, y_i) $$
270
270
271
- More complex patterns may emerge, including multiple indices
271
+ More complex patterns may emerge, including multiple indices::
272
272
273
- `` top(func, 'z', 'ij', 'x', 'ij', 'y', 'ji')``
273
+ top(func, 'z', 'ij', 'x', 'ij', 'y', 'ji')
274
274
275
275
$$ z_{ij} = func(x_{ij}, y_{ji}) $$
276
276
@@ -324,6 +324,15 @@ def top(func, output, out_indices, *arrind_pairs, **kwargs):
324
324
('z', 1, 1): (dotmany, [('x', 1, 0), ('x', 1, 1)],
325
325
[('y', 0, 1), ('y', 1, 1)])}
326
326
327
+ Pass ``concatenate=True`` to concatenate arrays ahead of time
328
+
329
+ >>> top(f, 'z', 'i', 'x', 'ij', 'y', 'ij', concatenate=True,
330
+ ... numblocks={'x': (2, 2), 'y': (2, 2,)}) # doctest: +SKIP
331
+ {('z', 0): (f, (concatenate_axes, [('x', 0, 0), ('x', 0, 1)], (1,)),
332
+ (concatenate_axes, [('y', 0, 0), ('y', 0, 1)], (1,)))
333
+ ('z', 1): (f, (concatenate_axes, [('x', 1, 0), ('x', 1, 1)], (1,)),
334
+ (concatenate_axes, [('y', 1, 0), ('y', 1, 1)], (1,)))}
335
+
327
336
Supports Broadcasting rules
328
337
329
338
>>> top(add, 'z', 'ij', 'x', 'ij', 'y', 'ij', numblocks={'x': (1, 2),
@@ -336,11 +345,16 @@ def top(func, output, out_indices, *arrind_pairs, **kwargs):
336
345
Support keyword arguments with apply
337
346
338
347
>>> def f(a, b=0): return a + b
339
- >>> top(f, 'z', 'i', 'x', 'i', numblocks={'x': (2,), b=10} ) # doctest: +SKIP
348
+ >>> top(f, 'z', 'i', 'x', 'i', numblocks={'x': (2,)} , b=10) # doctest: +SKIP
340
349
{('z', 0): (apply, f, [('x', 0)], {'b': 10}),
341
350
('z', 1): (apply, f, [('x', 1)], {'b': 10})}
351
+
352
+ See Also
353
+ --------
354
+ atop
342
355
"""
343
356
numblocks = kwargs .pop ('numblocks' )
357
+ concatenate = kwargs .pop ('concatenate' , None )
344
358
argpairs = list (partition (2 , arrind_pairs ))
345
359
346
360
assert set (numblocks ) == set (pluck (0 , argpairs ))
@@ -366,6 +380,9 @@ def top(func, output, out_indices, *arrind_pairs, **kwargs):
366
380
for arg , ind in argpairs :
367
381
tups = lol_tuples ((arg ,), ind , kd , dummies )
368
382
tups2 = zero_broadcast_dimensions (tups , numblocks [arg ])
383
+ if concatenate and isinstance (tups2 , list ):
384
+ axes = [n for n , i in enumerate (ind ) if i in dummies ]
385
+ tups2 = (concatenate_axes , tups2 , axes )
369
386
args .append (tups2 )
370
387
valtups .append (tuple (args ))
371
388
@@ -1731,6 +1748,8 @@ def atop(func, out_ind, *args, **kwargs):
1731
1748
Function to apply to individual tuples of blocks
1732
1749
out_ind: iterable
1733
1750
Block pattern of the output, something like 'ijk' or (1, 2, 3)
1751
+ concatenate: bool
1752
+ If true concatenate arrays along dummy indices, else provide lists
1734
1753
*args: sequence of Array, index pairs
1735
1754
Sequence like (x, 'ij', y, 'jk', z, 'i')
1736
1755
**kwargs: dict
@@ -1767,8 +1786,9 @@ def atop(func, out_ind, *args, **kwargs):
1767
1786
Any index, like ``i`` missing from the output index is interpreted as a
1768
1787
contraction (note that this differs from Einstein convention; repeated
1769
1788
indices do not imply contraction.) In the case of a contraction the passed
1770
- function should expect an iterator of blocks on any array that holds that
1771
- index.
1789
+ function should expect an iterable of blocks on any array that holds that
1790
+ index. To receive arrays concatenated along contracted dimensions instead
1791
+ pass ``concatenate=True``.
1772
1792
1773
1793
Inner product multiplying x by y, two 1-d vectors
1774
1794
@@ -2101,10 +2121,10 @@ def tensordot(lhs, rhs, axes=2):
2101
2121
out_index .remove (right_index [r ])
2102
2122
right_index [r ] = left_index [l ]
2103
2123
2104
- func = partial (np .tensordot , axes = (left_axes , right_axes ))
2105
- intermediate = atop (func , out_index ,
2124
+ intermediate = atop (np .tensordot , out_index ,
2106
2125
lhs , left_index ,
2107
- rhs , right_index , dtype = dt )
2126
+ rhs , right_index , dtype = dt ,
2127
+ axes = (left_axes , right_axes ))
2108
2128
2109
2129
int_index = list (out_index )
2110
2130
for l in left_axes :
@@ -3160,6 +3180,19 @@ def dtype(x):
3160
3180
return result
3161
3181
3162
3182
3183
+ def concatenate_axes (arrays , axes ):
3184
+ """ Recurseively call np.concatenate along axes
3185
+
3186
+ TODO: This performs many copies. We should be able to do this in one
3187
+ TODO: Merge logic on concatenate3 with this
3188
+ """
3189
+ if len (axes ) != ndimlist (arrays ):
3190
+ raise ValueError ("Length of axes should equal depth of nested arrays" )
3191
+ if len (axes ) > 1 :
3192
+ arrays = [concatenate_axes (a , axes [1 :]) for a in arrays ]
3193
+ return np .concatenate (arrays , axis = axes [0 ])
3194
+
3195
+
3163
3196
def to_hdf5 (filename , * args , ** kwargs ):
3164
3197
""" Store arrays in HDF5 file
3165
3198
0 commit comments