Skip to content

Commit 5fe8d7d

Browse files
jbrockmendelrhshadrach
authored andcommitted
ENH: implement Index.__array_ufunc__ (pandas-dev#43904)
1 parent 2688ca8 commit 5fe8d7d

File tree

7 files changed

+44
-14
lines changed

7 files changed

+44
-14
lines changed

doc/source/whatsnew/v1.4.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,7 @@ Datetimelike
390390
- Bug in :func:`to_datetime` with ``format`` and ``pandas.NA`` was raising ``ValueError`` (:issue:`42957`)
391391
- :func:`to_datetime` would silently swap ``MM/DD/YYYY`` and ``DD/MM/YYYY`` formats if the given ``dayfirst`` option could not be respected - now, a warning is raised in the case of delimited date strings (e.g. ``31-12-2012``) (:issue:`12585`)
392392
- Bug in :meth:`date_range` and :meth:`bdate_range` do not return right bound when ``start`` = ``end`` and set is closed on one side (:issue:`43394`)
393+
- Bug in inplace addition and subtraction of :class:`DatetimeIndex` or :class:`TimedeltaIndex` with :class:`DatetimeArray` or :class:`TimedeltaArray` (:issue:`43904`)
393394
-
394395

395396
Timedelta

pandas/core/arraylike.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def reconstruct(result):
357357
return result
358358

359359
if "out" in kwargs:
360-
result = _dispatch_ufunc_with_out(self, ufunc, method, *inputs, **kwargs)
360+
result = dispatch_ufunc_with_out(self, ufunc, method, *inputs, **kwargs)
361361
return reconstruct(result)
362362

363363
# We still get here with kwargs `axis` for e.g. np.maximum.accumulate
@@ -410,7 +410,7 @@ def _standardize_out_kwarg(**kwargs) -> dict:
410410
return kwargs
411411

412412

413-
def _dispatch_ufunc_with_out(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
413+
def dispatch_ufunc_with_out(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
414414
"""
415415
If we have an `out` keyword, then call the ufunc without `out` and then
416416
set the result into the given `out`.

pandas/core/arrays/base.py

+5
Original file line numberDiff line numberDiff line change
@@ -1379,6 +1379,11 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
13791379
if result is not NotImplemented:
13801380
return result
13811381

1382+
if "out" in kwargs:
1383+
return arraylike.dispatch_ufunc_with_out(
1384+
self, ufunc, method, *inputs, **kwargs
1385+
)
1386+
13821387
return arraylike.default_array_ufunc(self, ufunc, method, *inputs, **kwargs)
13831388

13841389

pandas/core/arrays/datetimelike.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1414,7 +1414,7 @@ def __iadd__(self, other):
14141414

14151415
if not is_period_dtype(self.dtype):
14161416
# restore freq, which is invalidated by setitem
1417-
self._freq = result._freq
1417+
self._freq = result.freq
14181418
return self
14191419

14201420
def __isub__(self, other):
@@ -1423,7 +1423,7 @@ def __isub__(self, other):
14231423

14241424
if not is_period_dtype(self.dtype):
14251425
# restore freq, which is invalidated by setitem
1426-
self._freq = result._freq
1426+
self._freq = result.freq
14271427
return self
14281428

14291429
# --------------------------------------------------------------

pandas/core/indexes/base.py

+20
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
PeriodDtype,
103103
)
104104
from pandas.core.dtypes.generic import (
105+
ABCDataFrame,
105106
ABCDatetimeIndex,
106107
ABCMultiIndex,
107108
ABCPeriodIndex,
@@ -116,6 +117,7 @@
116117
)
117118

118119
from pandas.core import (
120+
arraylike,
119121
missing,
120122
ops,
121123
)
@@ -844,6 +846,24 @@ def __array__(self, dtype=None) -> np.ndarray:
844846
"""
845847
return np.asarray(self._data, dtype=dtype)
846848

849+
def __array_ufunc__(self, ufunc: np.ufunc, method: str_t, *inputs, **kwargs):
850+
if any(isinstance(other, (ABCSeries, ABCDataFrame)) for other in inputs):
851+
return NotImplemented
852+
853+
result = arraylike.maybe_dispatch_ufunc_to_dunder_op(
854+
self, ufunc, method, *inputs, **kwargs
855+
)
856+
if result is not NotImplemented:
857+
return result
858+
859+
new_inputs = [x if x is not self else x._values for x in inputs]
860+
result = getattr(ufunc, method)(*new_inputs, **kwargs)
861+
if ufunc.nout == 2:
862+
# i.e. np.divmod, np.modf, np.frexp
863+
return tuple(self.__array_wrap__(x) for x in result)
864+
865+
return self.__array_wrap__(result)
866+
847867
def __array_wrap__(self, result, context=None):
848868
"""
849869
Gets called after a ufunc and other functions.

pandas/core/indexes/datetimelike.py

-9
Original file line numberDiff line numberDiff line change
@@ -672,15 +672,6 @@ def insert(self, loc: int, item):
672672
# --------------------------------------------------------------------
673673
# NDArray-Like Methods
674674

675-
def __array_wrap__(self, result, context=None):
676-
"""
677-
Gets called after a ufunc and other functions.
678-
"""
679-
out = super().__array_wrap__(result, context=context)
680-
if isinstance(out, DatetimeTimedeltaMixin) and self.freq is not None:
681-
out = out._with_freq("infer")
682-
return out
683-
684675
@Appender(_index_shared_docs["take"] % _index_doc_kwargs)
685676
def take(self, indices, axis=0, allow_fill=True, fill_value=None, **kwargs):
686677
nv.validate_take((), kwargs)

pandas/tests/arithmetic/test_datetime64.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -2163,6 +2163,15 @@ def test_dti_isub_tdi(self, tz_naive_fixture):
21632163
result -= tdi
21642164
tm.assert_index_equal(result, expected)
21652165

2166+
# DTA.__isub__ GH#43904
2167+
dta = dti._data.copy()
2168+
dta -= tdi
2169+
tm.assert_datetime_array_equal(dta, expected._data)
2170+
2171+
out = dti._data.copy()
2172+
np.subtract(out, tdi, out=out)
2173+
tm.assert_datetime_array_equal(out, expected._data)
2174+
21662175
msg = "cannot subtract .* from a TimedeltaArray"
21672176
with pytest.raises(TypeError, match=msg):
21682177
tdi -= dti
@@ -2172,10 +2181,14 @@ def test_dti_isub_tdi(self, tz_naive_fixture):
21722181
result -= tdi.values
21732182
tm.assert_index_equal(result, expected)
21742183

2175-
msg = "cannot subtract a datelike from a TimedeltaArray"
2184+
msg = "cannot subtract DatetimeArray from ndarray"
21762185
with pytest.raises(TypeError, match=msg):
21772186
tdi.values -= dti
21782187

2188+
msg = "cannot subtract a datelike from a TimedeltaArray"
2189+
with pytest.raises(TypeError, match=msg):
2190+
tdi._values -= dti
2191+
21792192
# -------------------------------------------------------------
21802193
# Binary Operations DatetimeIndex and datetime-like
21812194
# TODO: A couple other tests belong in this section. Move them in

0 commit comments

Comments
 (0)