Skip to content

Commit 3397585

Browse files
authored
REF: use idiomatic checks in __array_ufunc__ (#44765)
1 parent 82dcd07 commit 3397585

File tree

6 files changed

+53
-10
lines changed

6 files changed

+53
-10
lines changed

pandas/core/arraylike.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -326,13 +326,16 @@ def array_ufunc(self, ufunc: np.ufunc, method: str, *inputs: Any, **kwargs: Any)
326326
reconstruct_kwargs = {}
327327

328328
def reconstruct(result):
329+
if ufunc.nout > 1:
330+
# np.modf, np.frexp, np.divmod
331+
return tuple(_reconstruct(x) for x in result)
332+
333+
return _reconstruct(result)
334+
335+
def _reconstruct(result):
329336
if lib.is_scalar(result):
330337
return result
331338

332-
if isinstance(result, tuple):
333-
# np.modf, np.frexp, np.divmod
334-
return tuple(reconstruct(x) for x in result)
335-
336339
if result.ndim != self.ndim:
337340
if method == "outer":
338341
if self.ndim == 2:

pandas/core/arrays/masked.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,8 @@ def reconstruct(x):
479479
return x
480480

481481
result = getattr(ufunc, method)(*inputs2, **kwargs)
482-
if isinstance(result, tuple):
482+
if ufunc.nout > 1:
483+
# e.g. np.divmod
483484
return tuple(reconstruct(x) for x in result)
484485
else:
485486
return reconstruct(result)

pandas/core/arrays/numpy_.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
152152
)
153153
result = getattr(ufunc, method)(*inputs, **kwargs)
154154

155-
if type(result) is tuple and len(result):
155+
if ufunc.nout > 1:
156156
# multiple return values
157157
if not lib.is_scalar(result[0]):
158158
# re-box array-like results
@@ -163,6 +163,13 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
163163
elif method == "at":
164164
# no return value
165165
return None
166+
elif method == "reduce":
167+
if isinstance(result, np.ndarray):
168+
# e.g. test_np_reduce_2d
169+
return type(self)(result)
170+
171+
# e.g. test_np_max_nested_tuples
172+
return result
166173
else:
167174
# one return value
168175
if not lib.is_scalar(result):

pandas/core/arrays/sparse/array.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1580,7 +1580,7 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
15801580
sp_values = getattr(ufunc, method)(self.sp_values, **kwargs)
15811581
fill_value = getattr(ufunc, method)(self.fill_value, **kwargs)
15821582

1583-
if isinstance(sp_values, tuple):
1583+
if ufunc.nout > 1:
15841584
# multiple outputs. e.g. modf
15851585
arrays = tuple(
15861586
self._simple_new(
@@ -1589,7 +1589,7 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
15891589
for sp_value, fv in zip(sp_values, fill_value)
15901590
)
15911591
return arrays
1592-
elif is_scalar(sp_values):
1592+
elif method == "reduce":
15931593
# e.g. reductions
15941594
return sp_values
15951595

@@ -1603,7 +1603,7 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
16031603
out = out[0]
16041604
return out
16051605

1606-
if type(result) is tuple:
1606+
if ufunc.nout > 1:
16071607
return tuple(type(self)(x) for x in result)
16081608
elif method == "at":
16091609
# no return value

pandas/tests/arrays/test_numpy.py

+32
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,38 @@ def test_validate_reduction_keyword_args():
194194
arr.all(keepdims=True)
195195

196196

197+
def test_np_max_nested_tuples():
198+
# case where checking in ufunc.nout works while checking for tuples
199+
# does not
200+
vals = [
201+
(("j", "k"), ("l", "m")),
202+
(("l", "m"), ("o", "p")),
203+
(("o", "p"), ("j", "k")),
204+
]
205+
ser = pd.Series(vals)
206+
arr = ser.array
207+
208+
assert arr.max() is arr[2]
209+
assert ser.max() is arr[2]
210+
211+
result = np.maximum.reduce(arr)
212+
assert result == arr[2]
213+
214+
result = np.maximum.reduce(ser)
215+
assert result == arr[2]
216+
217+
218+
def test_np_reduce_2d():
219+
raw = np.arange(12).reshape(4, 3)
220+
arr = PandasArray(raw)
221+
222+
res = np.maximum.reduce(arr, axis=0)
223+
tm.assert_extension_array_equal(res, arr[-1])
224+
225+
alt = arr.max(axis=0)
226+
tm.assert_extension_array_equal(alt, arr[-1])
227+
228+
197229
# ----------------------------------------------------------------------------
198230
# Ops
199231

pandas/tests/extension/decimal/array.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def reconstruct(x):
124124
else:
125125
return DecimalArray._from_sequence(x)
126126

127-
if isinstance(result, tuple):
127+
if ufunc.nout > 1:
128128
return tuple(reconstruct(x) for x in result)
129129
else:
130130
return reconstruct(result)

0 commit comments

Comments
 (0)