From 82bc8c9ba350d6df4dba3727406e7329a6fece90 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 9 May 2018 12:01:19 -0400 Subject: [PATCH] API: Added axis to take --- pandas/core/algorithms.py | 12 +++++++++--- pandas/tests/test_take.py | 23 +++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index e8f74cf58a262..88bc497f9f22d 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -1448,7 +1448,7 @@ def func(arr, indexer, out, fill_value=np.nan): return func -def take(arr, indices, allow_fill=False, fill_value=None): +def take(arr, indices, axis=0, allow_fill=False, fill_value=None): """ Take elements from an array. @@ -1461,6 +1461,8 @@ def take(arr, indices, allow_fill=False, fill_value=None): to an ndarray. indices : sequence of integers Indices to be taken. + axis : int, default 0 + The axis over which to select values. allow_fill : bool, default False How to handle negative values in `indices`. @@ -1476,6 +1478,9 @@ def take(arr, indices, allow_fill=False, fill_value=None): This may be ``None``, in which case the default NA value for the type (``self.dtype.na_value``) is used. + For multi-dimensional `arr`, each *element* is filled with + `fill_value`. + Returns ------- ndarray or ExtensionArray @@ -1529,10 +1534,11 @@ def take(arr, indices, allow_fill=False, fill_value=None): if allow_fill: # Pandas style, -1 means NA validate_indices(indices, len(arr)) - result = take_1d(arr, indices, allow_fill=True, fill_value=fill_value) + result = take_1d(arr, indices, axis=axis, allow_fill=True, + fill_value=fill_value) else: # NumPy style - result = arr.take(indices) + result = arr.take(indices, axis=axis) return result diff --git a/pandas/tests/test_take.py b/pandas/tests/test_take.py index 2b78c91f9dac5..9ab147edb8d1b 100644 --- a/pandas/tests/test_take.py +++ b/pandas/tests/test_take.py @@ -447,6 +447,29 @@ def test_2d_datetime64(self): expected[:, [2, 4]] = datetime(2007, 1, 1) tm.assert_almost_equal(result, expected) + def test_take_axis_0(self): + arr = np.arange(12).reshape(4, 3) + result = algos.take(arr, [0, -1]) + expected = np.array([[0, 1, 2], [9, 10, 11]]) + tm.assert_numpy_array_equal(result, expected) + + # allow_fill=True + result = algos.take(arr, [0, -1], allow_fill=True, fill_value=0) + expected = np.array([[0, 1, 2], [0, 0, 0]]) + tm.assert_numpy_array_equal(result, expected) + + def test_take_axis_1(self): + arr = np.arange(12).reshape(4, 3) + result = algos.take(arr, [0, -1], axis=1) + expected = np.array([[0, 2], [3, 5], [6, 8], [9, 11]]) + tm.assert_numpy_array_equal(result, expected) + + # allow_fill=True + result = algos.take(arr, [0, -1], axis=1, allow_fill=True, + fill_value=0) + expected = np.array([[0, 0], [3, 0], [6, 0], [9, 0]]) + tm.assert_numpy_array_equal(result, expected) + class TestExtensionTake(object): # The take method found in pd.api.extensions