Skip to content

Commit 0e55e3a

Browse files
committed
TST: fix pyarrow arithmetic xfails
1 parent 0e484ed commit 0e55e3a

File tree

1 file changed

+61
-12
lines changed

1 file changed

+61
-12
lines changed

pandas/tests/extension/test_arrow.py

+61-12
Original file line numberDiff line numberDiff line change
@@ -1025,14 +1025,29 @@ def _patch_combine(self, obj, other, op):
10251025
else:
10261026
expected_data = expected
10271027
original_dtype = obj.dtype
1028-
pa_array = pa.array(expected_data._values).cast(original_dtype.pyarrow_dtype)
1029-
pd_array = type(expected_data._values)(pa_array)
1028+
1029+
pa_expected = pa.array(expected_data._values)
1030+
1031+
if pa.types.is_duration(pa_expected.type):
1032+
# pyarrow sees sequence of datetime/timedelta objects and defaults
1033+
# to "us" but the non-pointwise op retains unit
1034+
unit = original_dtype.pyarrow_dtype.unit
1035+
if type(other) in [datetime, timedelta] and unit in ["s", "ms"]:
1036+
# pydatetime/pytimedelta objects have microsecond reso, so we
1037+
# take the higher reso of the original and microsecond. Note
1038+
# this matches what we would do with DatetimeArray/TimedeltaArray
1039+
unit = "us"
1040+
pa_expected = pa_expected.cast(f"duration[{unit}]")
1041+
else:
1042+
pa_expected = pa_expected.cast(original_dtype.pyarrow_dtype)
1043+
1044+
pd_expected = type(expected_data._values)(pa_expected)
10301045
if was_frame:
10311046
expected = pd.DataFrame(
1032-
pd_array, index=expected.index, columns=expected.columns
1047+
pd_expected, index=expected.index, columns=expected.columns
10331048
)
10341049
else:
1035-
expected = pd.Series(pd_array)
1050+
expected = pd.Series(pd_expected)
10361051
return expected
10371052

10381053
def _is_temporal_supported(self, opname, pa_dtype):
@@ -1112,7 +1127,14 @@ def test_arith_series_with_scalar(
11121127
if mark is not None:
11131128
request.node.add_marker(mark)
11141129

1115-
if all_arithmetic_operators == "__floordiv__" and pa.types.is_integer(pa_dtype):
1130+
if (
1131+
(
1132+
all_arithmetic_operators == "__floordiv__"
1133+
and pa.types.is_integer(pa_dtype)
1134+
)
1135+
or pa.types.is_duration(pa_dtype)
1136+
or pa.types.is_timestamp(pa_dtype)
1137+
):
11161138
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
11171139
# not upcast
11181140
monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
@@ -1136,7 +1158,14 @@ def test_arith_frame_with_scalar(
11361158
if mark is not None:
11371159
request.node.add_marker(mark)
11381160

1139-
if all_arithmetic_operators == "__floordiv__" and pa.types.is_integer(pa_dtype):
1161+
if (
1162+
(
1163+
all_arithmetic_operators == "__floordiv__"
1164+
and pa.types.is_integer(pa_dtype)
1165+
)
1166+
or pa.types.is_duration(pa_dtype)
1167+
or pa.types.is_timestamp(pa_dtype)
1168+
):
11401169
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
11411170
# not upcast
11421171
monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
@@ -1180,18 +1209,38 @@ def test_arith_series_with_array(
11801209
# since ser.iloc[0] is a python scalar
11811210
other = pd.Series(pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype))
11821211

1183-
if pa.types.is_floating(pa_dtype) or (
1184-
pa.types.is_integer(pa_dtype) and all_arithmetic_operators != "__truediv__"
1212+
if (
1213+
pa.types.is_floating(pa_dtype)
1214+
or (
1215+
pa.types.is_integer(pa_dtype)
1216+
and all_arithmetic_operators != "__truediv__"
1217+
)
1218+
or pa.types.is_duration(pa_dtype)
1219+
or pa.types.is_timestamp(pa_dtype)
11851220
):
11861221
monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
11871222
self.check_opname(ser, op_name, other, exc=self.series_array_exc)
11881223

11891224
def test_add_series_with_extension_array(self, data, request):
11901225
pa_dtype = data.dtype.pyarrow_dtype
1191-
if not (
1192-
pa.types.is_integer(pa_dtype)
1193-
or pa.types.is_floating(pa_dtype)
1194-
or (not pa_version_under8p0 and pa.types.is_duration(pa_dtype))
1226+
1227+
if pa.types.is_temporal(pa_dtype) and not pa.types.is_duration(pa_dtype):
1228+
# i.e. timestamp, date, time, but not timedelta; these *should*
1229+
# raise when trying to add
1230+
ser = pd.Series(data)
1231+
msg = "Function 'add_checked' has no kernel matching input types"
1232+
with pytest.raises(NotImplementedError, match=msg):
1233+
# TODO: this is a pa.lib.ArrowNotImplementedError, might
1234+
# be better to reraise a TypeError; more consistent with
1235+
# non-pyarrow cases
1236+
ser + data
1237+
1238+
return
1239+
1240+
if (pa_version_under8p0 and pa.types.is_duration(pa_dtype)) or (
1241+
pa.types.is_binary(pa_dtype)
1242+
or pa.types.is_string(pa_dtype)
1243+
or pa.types.is_boolean(pa_dtype)
11951244
):
11961245
request.node.add_marker(
11971246
pytest.mark.xfail(

0 commit comments

Comments
 (0)