From 32870a05aae29a0383019316e83fdebf73d558b5 Mon Sep 17 00:00:00 2001 From: Brock Date: Sat, 4 Dec 2021 21:32:05 -0800 Subject: [PATCH] REF: use idiomatic checks in __array_ufunc__ --- pandas/core/arraylike.py | 11 +++++---- pandas/core/arrays/masked.py | 3 ++- pandas/core/arrays/numpy_.py | 9 ++++++- pandas/core/arrays/sparse/array.py | 6 ++--- pandas/tests/arrays/test_numpy.py | 32 +++++++++++++++++++++++++ pandas/tests/extension/decimal/array.py | 2 +- 6 files changed, 53 insertions(+), 10 deletions(-) diff --git a/pandas/core/arraylike.py b/pandas/core/arraylike.py index c496099e3a8d2..1df999b1dbcce 100644 --- a/pandas/core/arraylike.py +++ b/pandas/core/arraylike.py @@ -326,13 +326,16 @@ def array_ufunc(self, ufunc: np.ufunc, method: str, *inputs: Any, **kwargs: Any) reconstruct_kwargs = {} def reconstruct(result): + if ufunc.nout > 1: + # np.modf, np.frexp, np.divmod + return tuple(_reconstruct(x) for x in result) + + return _reconstruct(result) + + def _reconstruct(result): if lib.is_scalar(result): return result - if isinstance(result, tuple): - # np.modf, np.frexp, np.divmod - return tuple(reconstruct(x) for x in result) - if result.ndim != self.ndim: if method == "outer": if self.ndim == 2: diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index a882fe5d2da21..f03a740668540 100644 --- a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -476,7 +476,8 @@ def reconstruct(x): return x result = getattr(ufunc, method)(*inputs2, **kwargs) - if isinstance(result, tuple): + if ufunc.nout > 1: + # e.g. np.divmod return tuple(reconstruct(x) for x in result) else: return reconstruct(result) diff --git a/pandas/core/arrays/numpy_.py b/pandas/core/arrays/numpy_.py index 0afe204b35c68..01e9a6c2e5399 100644 --- a/pandas/core/arrays/numpy_.py +++ b/pandas/core/arrays/numpy_.py @@ -163,7 +163,7 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): ) result = getattr(ufunc, method)(*inputs, **kwargs) - if type(result) is tuple and len(result): + if ufunc.nout > 1: # multiple return values if not lib.is_scalar(result[0]): # re-box array-like results @@ -174,6 +174,13 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): elif method == "at": # no return value return None + elif method == "reduce": + if isinstance(result, np.ndarray): + # e.g. test_np_reduce_2d + return type(self)(result) + + # e.g. test_np_max_nested_tuples + return result else: # one return value if not lib.is_scalar(result): diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py index ffd25c5704211..ab6be9ee9d63d 100644 --- a/pandas/core/arrays/sparse/array.py +++ b/pandas/core/arrays/sparse/array.py @@ -1579,7 +1579,7 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): sp_values = getattr(ufunc, method)(self.sp_values, **kwargs) fill_value = getattr(ufunc, method)(self.fill_value, **kwargs) - if isinstance(sp_values, tuple): + if ufunc.nout > 1: # multiple outputs. e.g. modf arrays = tuple( self._simple_new( @@ -1588,7 +1588,7 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): for sp_value, fv in zip(sp_values, fill_value) ) return arrays - elif is_scalar(sp_values): + elif method == "reduce": # e.g. reductions return sp_values @@ -1602,7 +1602,7 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): out = out[0] return out - if type(result) is tuple: + if ufunc.nout > 1: return tuple(type(self)(x) for x in result) elif method == "at": # no return value diff --git a/pandas/tests/arrays/test_numpy.py b/pandas/tests/arrays/test_numpy.py index 9b9945495a733..66f7bf1f4d743 100644 --- a/pandas/tests/arrays/test_numpy.py +++ b/pandas/tests/arrays/test_numpy.py @@ -194,6 +194,38 @@ def test_validate_reduction_keyword_args(): arr.all(keepdims=True) +def test_np_max_nested_tuples(): + # case where checking in ufunc.nout works while checking for tuples + # does not + vals = [ + (("j", "k"), ("l", "m")), + (("l", "m"), ("o", "p")), + (("o", "p"), ("j", "k")), + ] + ser = pd.Series(vals) + arr = ser.array + + assert arr.max() is arr[2] + assert ser.max() is arr[2] + + result = np.maximum.reduce(arr) + assert result == arr[2] + + result = np.maximum.reduce(ser) + assert result == arr[2] + + +def test_np_reduce_2d(): + raw = np.arange(12).reshape(4, 3) + arr = PandasArray(raw) + + res = np.maximum.reduce(arr, axis=0) + tm.assert_extension_array_equal(res, arr[-1]) + + alt = arr.max(axis=0) + tm.assert_extension_array_equal(alt, arr[-1]) + + # ---------------------------------------------------------------------------- # Ops diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index fe7ebe4f4fb51..062ab9bc2b4d7 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -124,7 +124,7 @@ def reconstruct(x): else: return DecimalArray._from_sequence(x) - if isinstance(result, tuple): + if ufunc.nout > 1: return tuple(reconstruct(x) for x in result) else: return reconstruct(result)