Skip to content

Commit bdcb5da

Browse files
authored
BUG: 2D DTA/TDA arithmetic with object-dtype (#32185)
1 parent d219c2c commit bdcb5da

File tree

3 files changed

+63
-13
lines changed

3 files changed

+63
-13
lines changed

pandas/core/arrays/datetimelike.py

+8-13
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from pandas.core.algorithms import checked_add_with_arr, take, unique1d, value_counts
4343
from pandas.core.arrays.base import ExtensionArray, ExtensionOpsMixin
4444
import pandas.core.common as com
45+
from pandas.core.construction import array, extract_array
4546
from pandas.core.indexers import check_array_indexer
4647
from pandas.core.ops.common import unpack_zerodim_and_defer
4748
from pandas.core.ops.invalid import invalid_comparison, make_invalid_op
@@ -623,7 +624,7 @@ def astype(self, dtype, copy=True):
623624
dtype = pandas_dtype(dtype)
624625

625626
if is_object_dtype(dtype):
626-
return self._box_values(self.asi8)
627+
return self._box_values(self.asi8.ravel()).reshape(self.shape)
627628
elif is_string_dtype(dtype) and not is_categorical_dtype(dtype):
628629
return self._format_native_types()
629630
elif is_integer_dtype(dtype):
@@ -1256,19 +1257,13 @@ def _addsub_object_array(self, other: np.ndarray, op):
12561257
PerformanceWarning,
12571258
)
12581259

1259-
# For EA self.astype('O') returns a numpy array, not an Index
1260-
left = self.astype("O")
1260+
# Caller is responsible for broadcasting if necessary
1261+
assert self.shape == other.shape, (self.shape, other.shape)
12611262

1262-
res_values = op(left, np.array(other))
1263-
kwargs = {}
1264-
if not is_period_dtype(self):
1265-
kwargs["freq"] = "infer"
1266-
try:
1267-
res = type(self)._from_sequence(res_values, **kwargs)
1268-
except ValueError:
1269-
# e.g. we've passed a Timestamp to TimedeltaArray
1270-
res = res_values
1271-
return res
1263+
res_values = op(self.astype("O"), np.array(other))
1264+
result = array(res_values.ravel())
1265+
result = extract_array(result, extract_numpy=True).reshape(self.shape)
1266+
return result
12721267

12731268
def _time_shift(self, periods, freq=None):
12741269
"""

pandas/tests/arithmetic/test_datetime64.py

+41
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
date_range,
2828
)
2929
import pandas._testing as tm
30+
from pandas.core.arrays import DatetimeArray, TimedeltaArray
3031
from pandas.core.ops import roperator
3132
from pandas.tests.arithmetic.common import (
3233
assert_invalid_addsub_type,
@@ -956,6 +957,18 @@ def test_dt64arr_sub_NaT(self, box_with_array):
956957
# -------------------------------------------------------------
957958
# Subtraction of datetime-like array-like
958959

960+
def test_dt64arr_sub_dt64object_array(self, box_with_array, tz_naive_fixture):
    """
    Subtracting the object-dtype cast of a datetime64 container from the
    container itself must match the native datetime64 subtraction.
    """
    index = pd.date_range("2016-01-01", periods=3, tz=tz_naive_fixture)

    boxed = tm.box_expected(index, box_with_array)
    expected = tm.box_expected(index - index, box_with_array)

    # The DataFrame box is the one case where no PerformanceWarning is expected.
    if box_with_array is pd.DataFrame:
        expected_warning = None
    else:
        expected_warning = PerformanceWarning

    with tm.assert_produces_warning(expected_warning):
        result = boxed - boxed.astype(object)
    tm.assert_equal(result, expected)
971+
959972
def test_dt64arr_naive_sub_dt64ndarray(self, box_with_array):
960973
dti = pd.date_range("2016-01-01", periods=3, tz=None)
961974
dt64vals = dti.values
@@ -2395,3 +2408,31 @@ def test_shift_months(years, months):
23952408
raw = [x + pd.offsets.DateOffset(years=years, months=months) for x in dti]
23962409
expected = DatetimeIndex(raw)
23972410
tm.assert_index_equal(actual, expected)
2411+
2412+
2413+
def test_dt64arr_addsub_object_dtype_2d():
    """
    Check add/sub between a 2D DatetimeArray and object-dtype operands.

    Block-wise DataFrame operations will require operating on 2D
    DatetimeArray/TimedeltaArray, so that case is exercised directly here.
    """
    index = pd.date_range("1994-02-13", freq="2W", periods=4)
    dta_2d = index._data.reshape((4, 1))

    offsets = np.array([[pd.offsets.Day(n)] for n in range(4)])
    assert offsets.shape == dta_2d.shape

    with tm.assert_produces_warning(PerformanceWarning):
        summed = dta_2d + offsets
    # The reference result comes from the same operation done on the single
    # column in 1D, reshaped back to a (n, 1) column.
    with tm.assert_produces_warning(PerformanceWarning):
        flat_sum = dta_2d[:, 0] + offsets[:, 0]
        reference = flat_sum.reshape(-1, 1)

    assert isinstance(summed, DatetimeArray)
    assert summed.freq is None
    tm.assert_numpy_array_equal(summed._data, reference._data)

    # Case where we expect to get a TimedeltaArray back
    with tm.assert_produces_warning(PerformanceWarning):
        difference = dta_2d - dta_2d.astype(object)

    assert isinstance(difference, TimedeltaArray)
    assert difference.shape == (4, 1)
    assert difference.freq is None
    assert (difference.asi8 == 0).all()

pandas/tests/arithmetic/test_timedelta64.py

+14
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,20 @@ def test_tda_add_sub_index(self):
532532
expected = tdi - tdi
533533
tm.assert_index_equal(result, expected)
534534

535+
def test_tda_add_dt64_object_array(self, box_df_fail, tz_naive_fixture):
    """
    Adding object-dtype datetimes to a timedelta container should cast the
    result back to a datetime64 container equal to the original datetimes.
    """
    dt_index = pd.date_range("2016-01-01", periods=3, tz=tz_naive_fixture)
    dt_index._set_freq(None)
    # Subtracting the index from itself gives the timedelta operand.
    delta_index = dt_index - dt_index

    deltas = tm.box_expected(delta_index, box_df_fail)
    datetimes = tm.box_expected(dt_index, box_df_fail)

    # The DataFrame box is the one case where no PerformanceWarning is expected.
    if box_df_fail is pd.DataFrame:
        expected_warning = None
    else:
        expected_warning = PerformanceWarning

    with tm.assert_produces_warning(expected_warning):
        result = deltas + datetimes.astype(object)
    tm.assert_equal(result, datetimes)
548+
535549
# -------------------------------------------------------------
536550
# Binary operations TimedeltaIndex and timedelta-like
537551

0 commit comments

Comments
 (0)