Skip to content

Commit cefb74c

Browse files
Backport PR pandas-dev#45451: Add back __array_wrap__ for dask compatibility (pandas-dev#45491)
Co-authored-by: Joris Van den Bossche <[email protected]>
1 parent aaba0ef commit cefb74c

File tree

3 files changed

+69
-0
lines changed

3 files changed

+69
-0
lines changed

pandas/core/generic.py

+35
Original file line number | Diff line number | Diff line change
@@ -2071,6 +2071,41 @@ def empty(self) -> bool_t:
20712071
def __array__(self, dtype: npt.DTypeLike | None = None) -> np.ndarray:
    """Convert ``self._values`` to a NumPy array via ``np.asarray``."""
    values = self._values
    return np.asarray(values, dtype=dtype)
20732073

2074+
def __array_wrap__(
    self,
    result: np.ndarray,
    context: tuple[Callable, tuple[Any, ...], int] | None = None,
):
    """
    Re-wrap the output of a NumPy function applied to this object's array.

    Gets called after a ufunc and other functions.

    Parameters
    ----------
    result : np.ndarray
        The result of the ufunc or other function called on the NumPy array
        returned by ``__array__``.
    context : tuple of (func, tuple, int), optional
        Supplied by ufuncs as a 3-element tuple: (the ufunc, the arguments
        of the ufunc, the domain of the ufunc). It is not set by other
        NumPy functions, and it is not used by this method.

    Returns
    -------
    scalar or object
        If unboxing ``result`` with ``lib.item_from_zerodim`` yields a
        scalar, that scalar. Otherwise a new object built by
        ``self._constructor`` from ``result`` and this object's axes, then
        passed through ``__finalize__``.

    Notes
    -----
    Series implements ``__array_ufunc__``, so this is not called for ufuncs
    on Series.
    """
    # Note: at time of dask 2022.01.0, this is still used by dask
    unboxed = lib.item_from_zerodim(result)
    if not is_scalar(unboxed):
        axes = self._construct_axes_dict(self._AXIS_ORDERS, copy=False)
        # error: Argument 1 to "NDFrame" has incompatible type "ndarray";
        # expected "BlockManager"
        rewrapped = self._constructor(unboxed, **axes)  # type: ignore[arg-type]
        return rewrapped.__finalize__(self, method="__array_wrap__")
    # e.g. we get here with np.ptp(series)
    # ptp also requires the item_from_zerodim
    return unboxed
2108+
20742109
@final
20752110
def __array_ufunc__(
20762111
self, ufunc: np.ufunc, method: str, *inputs: Any, **kwargs: Any

pandas/tests/base/test_misc.py

+10
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,16 @@ def test_ndarray_compat_properties(index_or_series_obj):
8484
assert Series([1]).item() == 1
8585

8686

87+
def test_array_wrap_compat():
    # Note: at the time of dask 2022.01.0, __array_wrap__ is still used by
    # e.g. dask (https://github.com/dask/dask/issues/8580).
    # This is a small dummy test that only ensures the method stays covered.
    labels = ["a", "b", "c"]
    ser = Series([1, 2, 3], dtype="int64", index=labels)
    doubled_values = np.array([2, 4, 6], dtype="int64")

    result = ser.__array_wrap__(doubled_values)

    expected = ser * 2
    tm.assert_series_equal(result, expected)
95+
96+
8797
@pytest.mark.skipif(PYPY, reason="not relevant for PyPy")
8898
def test_memory_usage(index_or_series_obj):
8999
obj = index_or_series_obj

pandas/tests/test_downstream.py

+24
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,30 @@ def test_dask(df):
5050
pd.set_option("compute.use_numexpr", olduse)
5151

5252

53+
@pytest.mark.filterwarnings("ignore:.*64Index is deprecated:FutureWarning")
def test_dask_ufunc():
    # At the time of dask 2022.01.0, dask still calls __array_wrap__ directly
    # for some ufuncs (https://github.com/dask/dask/issues/8580).

    # dask sets "compute.use_numexpr" to False, so remember the value that is
    # in effect now and restore it on the way out to avoid impacting other
    # tests.
    saved_use_numexpr = pd.get_option("compute.use_numexpr")

    try:
        dask = import_module("dask")  # noqa:F841
        import dask.array as da
        import dask.dataframe as dd

        ser = pd.Series([1.5, 2.3, 3.7, 4.0])
        lazy_ser = dd.from_pandas(ser, npartitions=2)
        lazy_fixed = da.fix(lazy_ser)

        result = lazy_fixed.compute()
        expected = np.fix(ser)
        tm.assert_series_equal(result, expected)
    finally:
        pd.set_option("compute.use_numexpr", saved_use_numexpr)
75+
76+
5377
def test_xarray(df):
5478

5579
xarray = import_module("xarray") # noqa:F841

0 commit comments

Comments
 (0)