Skip to content

Commit 1187703

Browse files
committed
TST: fix pyarrow arithmetic xfails
1 parent 1128f5e commit 1187703

File tree

1 file changed

+61
-12
lines changed

1 file changed

+61
-12
lines changed

pandas/tests/extension/test_arrow.py

+61-12
Original file line number | Diff line number | Diff line change
@@ -1010,14 +1010,29 @@ def _patch_combine(self, obj, other, op):
10101010
else:
10111011
expected_data = expected
10121012
original_dtype = obj.dtype
1013-
pa_array = pa.array(expected_data._values).cast(original_dtype.pyarrow_dtype)
1014-
pd_array = type(expected_data._values)(pa_array)
1013+
1014+
pa_expected = pa.array(expected_data._values)
1015+
1016+
if pa.types.is_duration(pa_expected.type):
1017+
# pyarrow sees sequence of datetime/timedelta objects and defaults
1018+
# to "us" but the non-pointwise op retains unit
1019+
unit = original_dtype.pyarrow_dtype.unit
1020+
if type(other) in [datetime, timedelta] and unit in ["s", "ms"]:
1021+
# pydatetime/pytimedelta objects have microsecond reso, so we
1022+
# take the higher reso of the original and microsecond. Note
1023+
# this matches what we would do with DatetimeArray/TimedeltaArray
1024+
unit = "us"
1025+
pa_expected = pa_expected.cast(f"duration[{unit}]")
1026+
else:
1027+
pa_expected = pa_expected.cast(original_dtype.pyarrow_dtype)
1028+
1029+
pd_expected = type(expected_data._values)(pa_expected)
10151030
if was_frame:
10161031
expected = pd.DataFrame(
1017-
pd_array, index=expected.index, columns=expected.columns
1032+
pd_expected, index=expected.index, columns=expected.columns
10181033
)
10191034
else:
1020-
expected = pd.Series(pd_array)
1035+
expected = pd.Series(pd_expected)
10211036
return expected
10221037

10231038
def _is_temporal_supported(self, opname, pa_dtype):
@@ -1097,7 +1112,14 @@ def test_arith_series_with_scalar(
10971112
if mark is not None:
10981113
request.node.add_marker(mark)
10991114

1100-
if all_arithmetic_operators == "__floordiv__" and pa.types.is_integer(pa_dtype):
1115+
if (
1116+
(
1117+
all_arithmetic_operators == "__floordiv__"
1118+
and pa.types.is_integer(pa_dtype)
1119+
)
1120+
or pa.types.is_duration(pa_dtype)
1121+
or pa.types.is_timestamp(pa_dtype)
1122+
):
11011123
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
11021124
# not upcast
11031125
monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
@@ -1121,7 +1143,14 @@ def test_arith_frame_with_scalar(
11211143
if mark is not None:
11221144
request.node.add_marker(mark)
11231145

1124-
if all_arithmetic_operators == "__floordiv__" and pa.types.is_integer(pa_dtype):
1146+
if (
1147+
(
1148+
all_arithmetic_operators == "__floordiv__"
1149+
and pa.types.is_integer(pa_dtype)
1150+
)
1151+
or pa.types.is_duration(pa_dtype)
1152+
or pa.types.is_timestamp(pa_dtype)
1153+
):
11251154
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
11261155
# not upcast
11271156
monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
@@ -1165,18 +1194,38 @@ def test_arith_series_with_array(
11651194
# since ser.iloc[0] is a python scalar
11661195
other = pd.Series(pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype))
11671196

1168-
if pa.types.is_floating(pa_dtype) or (
1169-
pa.types.is_integer(pa_dtype) and all_arithmetic_operators != "__truediv__"
1197+
if (
1198+
pa.types.is_floating(pa_dtype)
1199+
or (
1200+
pa.types.is_integer(pa_dtype)
1201+
and all_arithmetic_operators != "__truediv__"
1202+
)
1203+
or pa.types.is_duration(pa_dtype)
1204+
or pa.types.is_timestamp(pa_dtype)
11701205
):
11711206
monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
11721207
self.check_opname(ser, op_name, other, exc=self.series_array_exc)
11731208

11741209
def test_add_series_with_extension_array(self, data, request):
11751210
pa_dtype = data.dtype.pyarrow_dtype
1176-
if not (
1177-
pa.types.is_integer(pa_dtype)
1178-
or pa.types.is_floating(pa_dtype)
1179-
or (not pa_version_under8p0 and pa.types.is_duration(pa_dtype))
1211+
1212+
if pa.types.is_temporal(pa_dtype) and not pa.types.is_duration(pa_dtype):
1213+
# i.e. timestamp, date, time, but not timedelta; these *should*
1214+
# raise when trying to add
1215+
ser = pd.Series(data)
1216+
msg = "Function 'add_checked' has no kernel matching input types"
1217+
with pytest.raises(NotImplementedError, match=msg):
1218+
# TODO: this is a pa.lib.ArrowNotImplementedError, might
1219+
# be better to reraise a TypeError; more consistent with
1220+
# non-pyarrow cases
1221+
ser + data
1222+
1223+
return
1224+
1225+
if (pa_version_under8p0 and pa.types.is_duration(pa_dtype)) or (
1226+
pa.types.is_binary(pa_dtype)
1227+
or pa.types.is_string(pa_dtype)
1228+
or pa.types.is_boolean(pa_dtype)
11801229
):
11811230
request.node.add_marker(
11821231
pytest.mark.xfail(

0 commit comments

Comments
 (0)