Skip to content

Commit d0021b3

Browse files
TomAugspurgertp
authored and
tp
committed
API: Added axis to take (pandas-dev#20999)
1 parent 8fe2ec0 commit d0021b3

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

pandas/core/algorithms.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -1448,7 +1448,7 @@ def func(arr, indexer, out, fill_value=np.nan):
14481448
return func
14491449

14501450

1451-
def take(arr, indices, allow_fill=False, fill_value=None):
1451+
def take(arr, indices, axis=0, allow_fill=False, fill_value=None):
14521452
"""
14531453
Take elements from an array.
14541454
@@ -1461,6 +1461,8 @@ def take(arr, indices, allow_fill=False, fill_value=None):
14611461
to an ndarray.
14621462
indices : sequence of integers
14631463
Indices to be taken.
1464+
axis : int, default 0
1465+
The axis over which to select values.
14641466
allow_fill : bool, default False
14651467
How to handle negative values in `indices`.
14661468
@@ -1476,6 +1478,9 @@ def take(arr, indices, allow_fill=False, fill_value=None):
14761478
This may be ``None``, in which case the default NA value for
14771479
the type (``self.dtype.na_value``) is used.
14781480
1481+
For multi-dimensional `arr`, each *element* is filled with
1482+
`fill_value`.
1483+
14791484
Returns
14801485
-------
14811486
ndarray or ExtensionArray
@@ -1529,10 +1534,11 @@ def take(arr, indices, allow_fill=False, fill_value=None):
15291534
if allow_fill:
15301535
# Pandas style, -1 means NA
15311536
validate_indices(indices, len(arr))
1532-
result = take_1d(arr, indices, allow_fill=True, fill_value=fill_value)
1537+
result = take_1d(arr, indices, axis=axis, allow_fill=True,
1538+
fill_value=fill_value)
15331539
else:
15341540
# NumPy style
1535-
result = arr.take(indices)
1541+
result = arr.take(indices, axis=axis)
15361542
return result
15371543

15381544

pandas/tests/test_take.py

+23
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,29 @@ def test_2d_datetime64(self):
447447
expected[:, [2, 4]] = datetime(2007, 1, 1)
448448
tm.assert_almost_equal(result, expected)
449449

450+
def test_take_axis_0(self):
451+
arr = np.arange(12).reshape(4, 3)
452+
result = algos.take(arr, [0, -1])
453+
expected = np.array([[0, 1, 2], [9, 10, 11]])
454+
tm.assert_numpy_array_equal(result, expected)
455+
456+
# allow_fill=True
457+
result = algos.take(arr, [0, -1], allow_fill=True, fill_value=0)
458+
expected = np.array([[0, 1, 2], [0, 0, 0]])
459+
tm.assert_numpy_array_equal(result, expected)
460+
461+
def test_take_axis_1(self):
462+
arr = np.arange(12).reshape(4, 3)
463+
result = algos.take(arr, [0, -1], axis=1)
464+
expected = np.array([[0, 2], [3, 5], [6, 8], [9, 11]])
465+
tm.assert_numpy_array_equal(result, expected)
466+
467+
# allow_fill=True
468+
result = algos.take(arr, [0, -1], axis=1, allow_fill=True,
469+
fill_value=0)
470+
expected = np.array([[0, 0], [3, 0], [6, 0], [9, 0]])
471+
tm.assert_numpy_array_equal(result, expected)
472+
450473

451474
class TestExtensionTake(object):
452475
# The take method found in pd.api.extensions

0 commit comments

Comments
 (0)