Skip to content

Commit 04f5721

Browse files
Add back __array_wrap__ for dask compatibility (#45451)
1 parent 3ac8543 commit 04f5721

File tree

3 files changed

+69
-0
lines changed

3 files changed

+69
-0
lines changed

pandas/core/generic.py

+35
Original file line numberDiff line numberDiff line change
@@ -2052,6 +2052,41 @@ def empty(self) -> bool_t:
20522052
def __array__(self, dtype: npt.DTypeLike | None = None) -> np.ndarray:
20532053
return np.asarray(self._values, dtype=dtype)
20542054

2055+
def __array_wrap__(
2056+
self,
2057+
result: np.ndarray,
2058+
context: tuple[Callable, tuple[Any, ...], int] | None = None,
2059+
):
2060+
"""
2061+
Gets called after a ufunc and other functions.
2062+
2063+
Parameters
2064+
----------
2065+
result: np.ndarray
2066+
The result of the ufunc or other function called on the NumPy array
2067+
returned by __array__
2068+
context: tuple of (func, tuple, int)
2069+
This parameter is returned by ufuncs as a 3-element tuple: (name of the
2070+
ufunc, arguments of the ufunc, domain of the ufunc), but is not set by
2071+
other numpy functions.q
2072+
2073+
Notes
2074+
-----
2075+
Series implements __array_ufunc_ so this not called for ufunc on Series.
2076+
"""
2077+
# Note: at time of dask 2022.01.0, this is still used by dask
2078+
res = lib.item_from_zerodim(result)
2079+
if is_scalar(res):
2080+
# e.g. we get here with np.ptp(series)
2081+
# ptp also requires the item_from_zerodim
2082+
return res
2083+
d = self._construct_axes_dict(self._AXIS_ORDERS, copy=False)
2084+
# error: Argument 1 to "NDFrame" has incompatible type "ndarray";
2085+
# expected "BlockManager"
2086+
return self._constructor(res, **d).__finalize__( # type: ignore[arg-type]
2087+
self, method="__array_wrap__"
2088+
)
2089+
20552090
@final
20562091
def __array_ufunc__(
20572092
self, ufunc: np.ufunc, method: str, *inputs: Any, **kwargs: Any

pandas/tests/base/test_misc.py

+10
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,16 @@ def test_ndarray_compat_properties(index_or_series_obj):
8484
assert Series([1]).item() == 1
8585

8686

87+
def test_array_wrap_compat():
88+
# Note: at time of dask 2022.01.0, this is still used by eg dask
89+
# (https://github.com/dask/dask/issues/8580).
90+
# This test is a small dummy ensuring coverage
91+
orig = Series([1, 2, 3], dtype="int64", index=["a", "b", "c"])
92+
result = orig.__array_wrap__(np.array([2, 4, 6], dtype="int64"))
93+
expected = orig * 2
94+
tm.assert_series_equal(result, expected)
95+
96+
8797
@pytest.mark.skipif(PYPY, reason="not relevant for PyPy")
8898
def test_memory_usage(index_or_series_obj):
8999
obj = index_or_series_obj

pandas/tests/test_downstream.py

+24
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,30 @@ def test_dask(df):
5555
pd.set_option("compute.use_numexpr", olduse)
5656

5757

58+
@pytest.mark.filterwarnings("ignore:.*64Index is deprecated:FutureWarning")
59+
def test_dask_ufunc():
60+
# At the time of dask 2022.01.0, dask is still directly using __array_wrap__
61+
# for some ufuncs (https://github.com/dask/dask/issues/8580).
62+
63+
# dask sets "compute.use_numexpr" to False, so catch the current value
64+
# and ensure to reset it afterwards to avoid impacting other tests
65+
olduse = pd.get_option("compute.use_numexpr")
66+
67+
try:
68+
dask = import_module("dask") # noqa:F841
69+
import dask.array as da
70+
import dask.dataframe as dd
71+
72+
s = pd.Series([1.5, 2.3, 3.7, 4.0])
73+
ds = dd.from_pandas(s, npartitions=2)
74+
75+
result = da.fix(ds).compute()
76+
expected = np.fix(s)
77+
tm.assert_series_equal(result, expected)
78+
finally:
79+
pd.set_option("compute.use_numexpr", olduse)
80+
81+
5882
def test_xarray(df):
5983

6084
xarray = import_module("xarray") # noqa:F841

0 commit comments

Comments
 (0)