Skip to content

Commit 9277f93

Browse files
authored
TST: fix pyarrow arithmetic xfails (#50877)
* TST: fix pyarrow arithmetic xfails
* fix on older pyarrow
1 parent 755a99b commit 9277f93

File tree

1 file changed

+64
-12
lines changed

1 file changed

+64
-12
lines changed

pandas/tests/extension/test_arrow.py

+64-12
Original file line number | Diff line number | Diff line change
@@ -1062,14 +1062,29 @@ def _patch_combine(self, obj, other, op):
10621062
else:
10631063
expected_data = expected
10641064
original_dtype = obj.dtype
1065-
pa_array = pa.array(expected_data._values).cast(original_dtype.pyarrow_dtype)
1066-
pd_array = type(expected_data._values)(pa_array)
1065+
1066+
pa_expected = pa.array(expected_data._values)
1067+
1068+
if pa.types.is_duration(pa_expected.type):
1069+
# pyarrow sees sequence of datetime/timedelta objects and defaults
1070+
# to "us" but the non-pointwise op retains unit
1071+
unit = original_dtype.pyarrow_dtype.unit
1072+
if type(other) in [datetime, timedelta] and unit in ["s", "ms"]:
1073+
# pydatetime/pytimedelta objects have microsecond reso, so we
1074+
# take the higher reso of the original and microsecond. Note
1075+
# this matches what we would do with DatetimeArray/TimedeltaArray
1076+
unit = "us"
1077+
pa_expected = pa_expected.cast(f"duration[{unit}]")
1078+
else:
1079+
pa_expected = pa_expected.cast(original_dtype.pyarrow_dtype)
1080+
1081+
pd_expected = type(expected_data._values)(pa_expected)
10671082
if was_frame:
10681083
expected = pd.DataFrame(
1069-
pd_array, index=expected.index, columns=expected.columns
1084+
pd_expected, index=expected.index, columns=expected.columns
10701085
)
10711086
else:
1072-
expected = pd.Series(pd_array)
1087+
expected = pd.Series(pd_expected)
10731088
return expected
10741089

10751090
def _is_temporal_supported(self, opname, pa_dtype):
@@ -1149,7 +1164,14 @@ def test_arith_series_with_scalar(
11491164
if mark is not None:
11501165
request.node.add_marker(mark)
11511166

1152-
if all_arithmetic_operators == "__floordiv__" and pa.types.is_integer(pa_dtype):
1167+
if (
1168+
(
1169+
all_arithmetic_operators == "__floordiv__"
1170+
and pa.types.is_integer(pa_dtype)
1171+
)
1172+
or pa.types.is_duration(pa_dtype)
1173+
or pa.types.is_timestamp(pa_dtype)
1174+
):
11531175
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
11541176
# not upcast
11551177
monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
@@ -1173,7 +1195,14 @@ def test_arith_frame_with_scalar(
11731195
if mark is not None:
11741196
request.node.add_marker(mark)
11751197

1176-
if all_arithmetic_operators == "__floordiv__" and pa.types.is_integer(pa_dtype):
1198+
if (
1199+
(
1200+
all_arithmetic_operators == "__floordiv__"
1201+
and pa.types.is_integer(pa_dtype)
1202+
)
1203+
or pa.types.is_duration(pa_dtype)
1204+
or pa.types.is_timestamp(pa_dtype)
1205+
):
11771206
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
11781207
# not upcast
11791208
monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
@@ -1217,18 +1246,41 @@ def test_arith_series_with_array(
12171246
# since ser.iloc[0] is a python scalar
12181247
other = pd.Series(pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype))
12191248

1220-
if pa.types.is_floating(pa_dtype) or (
1221-
pa.types.is_integer(pa_dtype) and all_arithmetic_operators != "__truediv__"
1249+
if (
1250+
pa.types.is_floating(pa_dtype)
1251+
or (
1252+
pa.types.is_integer(pa_dtype)
1253+
and all_arithmetic_operators != "__truediv__"
1254+
)
1255+
or pa.types.is_duration(pa_dtype)
1256+
or pa.types.is_timestamp(pa_dtype)
12221257
):
12231258
monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
12241259
self.check_opname(ser, op_name, other, exc=self.series_array_exc)
12251260

12261261
def test_add_series_with_extension_array(self, data, request):
12271262
pa_dtype = data.dtype.pyarrow_dtype
1228-
if not (
1229-
pa.types.is_integer(pa_dtype)
1230-
or pa.types.is_floating(pa_dtype)
1231-
or (not pa_version_under8p0 and pa.types.is_duration(pa_dtype))
1263+
1264+
if pa.types.is_temporal(pa_dtype) and not pa.types.is_duration(pa_dtype):
1265+
# i.e. timestamp, date, time, but not timedelta; these *should*
1266+
# raise when trying to add
1267+
ser = pd.Series(data)
1268+
if pa_version_under7p0:
1269+
msg = "Function add_checked has no kernel matching input types"
1270+
else:
1271+
msg = "Function 'add_checked' has no kernel matching input types"
1272+
with pytest.raises(NotImplementedError, match=msg):
1273+
# TODO: this is a pa.lib.ArrowNotImplementedError, might
1274+
# be better to reraise a TypeError; more consistent with
1275+
# non-pyarrow cases
1276+
ser + data
1277+
1278+
return
1279+
1280+
if (pa_version_under8p0 and pa.types.is_duration(pa_dtype)) or (
1281+
pa.types.is_binary(pa_dtype)
1282+
or pa.types.is_string(pa_dtype)
1283+
or pa.types.is_boolean(pa_dtype)
12321284
):
12331285
request.node.add_marker(
12341286
pytest.mark.xfail(

0 commit comments

Comments
 (0)