Skip to content

Commit 2d5d2de

Browse files
authored
API: dont do inference on object-dtype arithmetic results (#49999)
* API: dont do inference on object-dtype arithmetic results * suggest infer_objects * remove special case * de-special-case
1 parent e93ee07 commit 2d5d2de

File tree

7 files changed

+55
-21
lines changed

7 files changed

+55
-21
lines changed

doc/source/whatsnew/v2.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ Other API changes
399399
- Passing a sequence containing a type that cannot be converted to :class:`Timedelta` to :func:`to_timedelta` or to the :class:`Series` or :class:`DataFrame` constructor with ``dtype="timedelta64[ns]"`` or to :class:`TimedeltaIndex` now raises ``TypeError`` instead of ``ValueError`` (:issue:`49525`)
400400
- Changed behavior of :class:`Index` constructor with sequence containing at least one ``NaT`` and everything else either ``None`` or ``NaN`` to infer ``datetime64[ns]`` dtype instead of ``object``, matching :class:`Series` behavior (:issue:`49340`)
401401
- :func:`read_stata` with parameter ``index_col`` set to ``None`` (the default) will now set the index on the returned :class:`DataFrame` to a :class:`RangeIndex` instead of a :class:`Int64Index` (:issue:`49745`)
402+
- Changed behavior of :class:`Index`, :class:`Series`, and :class:`DataFrame` arithmetic methods when working with object-dtypes, the results no longer do type inference on the result of the array operations, use ``result.infer_objects()`` to do type inference on the result (:issue:`49999`)
402403
- Changed behavior of :class:`Index` constructor with an object-dtype ``numpy.ndarray`` containing all-``bool`` values or all-complex values, this will now retain object dtype, consistent with the :class:`Series` behavior (:issue:`49594`)
403404
- Changed behavior of :meth:`DataFrame.shift` with ``axis=1``, an integer ``fill_value``, and homogeneous datetime-like dtype, this now fills new columns with integer dtypes instead of casting to datetimelike (:issue:`49842`)
404405
- Files are now closed when encountering an exception in :func:`read_json` (:issue:`49921`)

pandas/core/indexes/base.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -6572,10 +6572,10 @@ def _logical_method(self, other, op):
65726572
def _construct_result(self, result, name):
65736573
if isinstance(result, tuple):
65746574
return (
6575-
Index._with_infer(result[0], name=name),
6576-
Index._with_infer(result[1], name=name),
6575+
Index(result[0], name=name, dtype=result[0].dtype),
6576+
Index(result[1], name=name, dtype=result[1].dtype),
65776577
)
6578-
return Index._with_infer(result, name=name)
6578+
return Index(result, name=name, dtype=result.dtype)
65796579

65806580
def _arith_method(self, other, op):
65816581
if (

pandas/core/ops/__init__.py

+29-5
Original file line numberDiff line numberDiff line change
@@ -230,18 +230,27 @@ def align_method_FRAME(
230230

231231
def to_series(right):
232232
msg = "Unable to coerce to Series, length must be {req_len}: given {given_len}"
233+
234+
# pass dtype to avoid doing inference, which would break consistency
235+
# with Index/Series ops
236+
dtype = None
237+
if getattr(right, "dtype", None) == object:
238+
# can't pass right.dtype unconditionally as that would break on e.g.
239+
# datetime64[h] ndarray
240+
dtype = object
241+
233242
if axis is not None and left._get_axis_name(axis) == "index":
234243
if len(left.index) != len(right):
235244
raise ValueError(
236245
msg.format(req_len=len(left.index), given_len=len(right))
237246
)
238-
right = left._constructor_sliced(right, index=left.index)
247+
right = left._constructor_sliced(right, index=left.index, dtype=dtype)
239248
else:
240249
if len(left.columns) != len(right):
241250
raise ValueError(
242251
msg.format(req_len=len(left.columns), given_len=len(right))
243252
)
244-
right = left._constructor_sliced(right, index=left.columns)
253+
right = left._constructor_sliced(right, index=left.columns, dtype=dtype)
245254
return right
246255

247256
if isinstance(right, np.ndarray):
@@ -250,13 +259,25 @@ def to_series(right):
250259
right = to_series(right)
251260

252261
elif right.ndim == 2:
262+
# We need to pass dtype=right.dtype to retain object dtype
263+
# otherwise we lose consistency with Index and array ops
264+
dtype = None
265+
if getattr(right, "dtype", None) == object:
266+
# can't pass right.dtype unconditionally as that would break on e.g.
267+
# datetime64[h] ndarray
268+
dtype = object
269+
253270
if right.shape == left.shape:
254-
right = left._constructor(right, index=left.index, columns=left.columns)
271+
right = left._constructor(
272+
right, index=left.index, columns=left.columns, dtype=dtype
273+
)
255274

256275
elif right.shape[0] == left.shape[0] and right.shape[1] == 1:
257276
# Broadcast across columns
258277
right = np.broadcast_to(right, left.shape)
259-
right = left._constructor(right, index=left.index, columns=left.columns)
278+
right = left._constructor(
279+
right, index=left.index, columns=left.columns, dtype=dtype
280+
)
260281

261282
elif right.shape[1] == left.shape[1] and right.shape[0] == 1:
262283
# Broadcast along rows
@@ -406,7 +427,10 @@ def _maybe_align_series_as_frame(frame: DataFrame, series: Series, axis: AxisInt
406427
rvalues = rvalues.reshape(1, -1)
407428

408429
rvalues = np.broadcast_to(rvalues, frame.shape)
409-
return type(frame)(rvalues, index=frame.index, columns=frame.columns)
430+
# pass dtype to avoid doing inference
431+
return type(frame)(
432+
rvalues, index=frame.index, columns=frame.columns, dtype=rvalues.dtype
433+
)
410434

411435

412436
def flex_arith_method_FRAME(op):

pandas/core/series.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -3000,9 +3000,10 @@ def _construct_result(
30003000
assert isinstance(res2, Series)
30013001
return (res1, res2)
30023002

3003-
# We do not pass dtype to ensure that the Series constructor
3004-
# does inference in the case where `result` has object-dtype.
3005-
out = self._constructor(result, index=self.index)
3003+
# TODO: result should always be ArrayLike, but this fails for some
3004+
# JSONArray tests
3005+
dtype = getattr(result, "dtype", None)
3006+
out = self._constructor(result, index=self.index, dtype=dtype)
30063007
out = out.__finalize__(self)
30073008

30083009
# Set the result's name after __finalize__ is called because __finalize__

pandas/tests/arithmetic/test_numeric.py

-6
Original file line numberDiff line numberDiff line change
@@ -1147,9 +1147,6 @@ def test_numarr_with_dtype_add_nan(self, dtype, box_with_array):
11471147

11481148
ser = tm.box_expected(ser, box)
11491149
expected = tm.box_expected(expected, box)
1150-
if box is Index and dtype is object:
1151-
# TODO: avoid this; match behavior with Series
1152-
expected = expected.astype(np.float64)
11531150

11541151
result = np.nan + ser
11551152
tm.assert_equal(result, expected)
@@ -1165,9 +1162,6 @@ def test_numarr_with_dtype_add_int(self, dtype, box_with_array):
11651162

11661163
ser = tm.box_expected(ser, box)
11671164
expected = tm.box_expected(expected, box)
1168-
if box is Index and dtype is object:
1169-
# TODO: avoid this; match behavior with Series
1170-
expected = expected.astype(np.int64)
11711165

11721166
result = 1 + ser
11731167
tm.assert_equal(result, expected)

pandas/tests/arithmetic/test_object.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,8 @@ def test_series_with_dtype_radd_timedelta(self, dtype):
187187
dtype=dtype,
188188
)
189189
expected = Series(
190-
[pd.Timedelta("4 days"), pd.Timedelta("5 days"), pd.Timedelta("6 days")]
190+
[pd.Timedelta("4 days"), pd.Timedelta("5 days"), pd.Timedelta("6 days")],
191+
dtype=dtype,
191192
)
192193

193194
result = pd.Timedelta("3 days") + ser
@@ -227,7 +228,9 @@ def test_mixed_timezone_series_ops_object(self):
227228
name="xxx",
228229
)
229230
assert ser2.dtype == object
230-
exp = Series([pd.Timedelta("2 days"), pd.Timedelta("4 days")], name="xxx")
231+
exp = Series(
232+
[pd.Timedelta("2 days"), pd.Timedelta("4 days")], name="xxx", dtype=object
233+
)
231234
tm.assert_series_equal(ser2 - ser, exp)
232235
tm.assert_series_equal(ser - ser2, -exp)
233236

@@ -238,7 +241,11 @@ def test_mixed_timezone_series_ops_object(self):
238241
)
239242
assert ser.dtype == object
240243

241-
exp = Series([pd.Timedelta("01:30:00"), pd.Timedelta("02:30:00")], name="xxx")
244+
exp = Series(
245+
[pd.Timedelta("01:30:00"), pd.Timedelta("02:30:00")],
246+
name="xxx",
247+
dtype=object,
248+
)
242249
tm.assert_series_equal(ser + pd.Timedelta("00:30:00"), exp)
243250
tm.assert_series_equal(pd.Timedelta("00:30:00") + ser, exp)
244251

pandas/tests/arithmetic/test_timedelta64.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1394,7 +1394,7 @@ def test_td64arr_addsub_anchored_offset_arraylike(self, obox, box_with_array):
13941394
# ------------------------------------------------------------------
13951395
# Unsorted
13961396

1397-
def test_td64arr_add_sub_object_array(self, box_with_array):
1397+
def test_td64arr_add_sub_object_array(self, box_with_array, using_array_manager):
13981398
box = box_with_array
13991399
xbox = np.ndarray if box is pd.array else box
14001400

@@ -1410,6 +1410,11 @@ def test_td64arr_add_sub_object_array(self, box_with_array):
14101410
[Timedelta(days=2), Timedelta(days=4), Timestamp("2000-01-07")]
14111411
)
14121412
expected = tm.box_expected(expected, xbox)
1413+
if not using_array_manager:
1414+
# TODO: avoid mismatched behavior. This occurs bc inference
1415+
# can happen within TimedeltaArray method, which means results
1416+
# depend on whether we split blocks.
1417+
expected = expected.astype(object)
14131418
tm.assert_equal(result, expected)
14141419

14151420
msg = "unsupported operand type|cannot subtract a datelike"
@@ -1422,6 +1427,8 @@ def test_td64arr_add_sub_object_array(self, box_with_array):
14221427

14231428
expected = pd.Index([Timedelta(0), Timedelta(0), Timestamp("2000-01-01")])
14241429
expected = tm.box_expected(expected, xbox)
1430+
if not using_array_manager:
1431+
expected = expected.astype(object)
14251432
tm.assert_equal(result, expected)
14261433

14271434

0 commit comments

Comments
 (0)