TST: simplify pyarrow tests, make mode work with temporal dtypes #50688
@@ -333,15 +333,39 @@ def test_from_sequence_of_strings_pa_array(self, data, request): | |
|
||
|
||
class TestGetitemTests(base.BaseGetitemTests): | ||
@pytest.mark.xfail( | ||
reason=( | ||
"data.dtype.type return pyarrow.DataType " | ||
"but this (intentionally) returns " | ||
"Python scalars or pd.NA" | ||
) | ||
) | ||
def test_getitem_scalar(self, data): | ||
super().test_getitem_scalar(data) | ||
# In the base class we expect data.dtype.type; but this (intentionally) | ||
Review thread on the rewritten test_getitem_scalar (a small illustrative sketch of the dtype.type mismatch discussed here follows this hunk):
Comment: @mroeschke might it make sense to reconsider this? I'm not sure what context the current […]
Comment: Sure, yeah I am not exactly sure how […]
Comment: the main use case that comes to mind is if I want to do e.g. […]
||
# returns Python scalars or pd.NA | ||
pa_type = data._data.type | ||
if pa.types.is_integer(pa_type): | ||
exp_type = int | ||
elif pa.types.is_floating(pa_type): | ||
exp_type = float | ||
elif pa.types.is_string(pa_type): | ||
exp_type = str | ||
elif pa.types.is_binary(pa_type): | ||
exp_type = bytes | ||
elif pa.types.is_boolean(pa_type): | ||
exp_type = bool | ||
elif pa.types.is_duration(pa_type): | ||
exp_type = timedelta | ||
elif pa.types.is_timestamp(pa_type): | ||
if pa_type.unit == "ns": | ||
exp_type = pd.Timestamp | ||
else: | ||
exp_type = datetime | ||
elif pa.types.is_date(pa_type): | ||
exp_type = date | ||
elif pa.types.is_time(pa_type): | ||
exp_type = time | ||
else: | ||
raise NotImplementedError(data.dtype) | ||
|
||
result = data[0] | ||
assert isinstance(result, exp_type), type(result) | ||
|
||
result = pd.Series(data)[0] | ||
assert isinstance(result, exp_type), type(result) | ||
|
||
|
||
class TestBaseAccumulateTests(base.BaseAccumulateTests): | ||
|
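As a side note on the review thread above: a small illustrative sketch of the data.dtype.type mismatch, based on the xfail reason that this PR removes. The exact behavior of ArrowDtype.type may differ across pandas versions; the Arrow-backed lines below only restate what the old xfail reason describes.

```python
import pandas as pd

# For the masked extension dtypes, dtype.type is the scalar class that
# elements satisfy, so a check like this passes:
int_ser = pd.Series([1, 2], dtype="Int64")
assert isinstance(int_ser[0], int_ser.dtype.type)  # np.int64 scalar

# Per the removed xfail reason, ArrowDtype.type returned a pyarrow.DataType
# while scalar indexing returns a Python scalar or pd.NA, so the same check
# would not hold; hence the rewritten test above asserts concrete Python
# scalar types (int, float, datetime, ...) instead.
arrow_ser = pd.Series([1, 2], dtype="int64[pyarrow]")
print(type(arrow_ser[0]))    # a Python scalar type (int here)
print(arrow_ser.dtype.type)  # a pyarrow DataType, per the old xfail reason
```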
@@ -1054,70 +1078,83 @@ def _patch_combine(self, obj, other, op): | |
expected = pd.Series(pd_array) | ||
return expected | ||
|
||
def test_arith_series_with_scalar( | ||
self, data, all_arithmetic_operators, request, monkeypatch | ||
): | ||
pa_dtype = data.dtype.pyarrow_dtype | ||
|
||
arrow_temporal_supported = not pa_version_under8p0 and ( | ||
all_arithmetic_operators in ("__add__", "__radd__") | ||
def _is_temporal_supported(self, opname, pa_dtype): | ||
return not pa_version_under8p0 and ( | ||
opname in ("__add__", "__radd__") | ||
and pa.types.is_duration(pa_dtype) | ||
or all_arithmetic_operators in ("__sub__", "__rsub__") | ||
or opname in ("__sub__", "__rsub__") | ||
and pa.types.is_temporal(pa_dtype) | ||
) | ||
if all_arithmetic_operators == "__rmod__" and ( | ||
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype) | ||
): | ||
pytest.skip("Skip testing Python string formatting") | ||
elif all_arithmetic_operators in { | ||
|
||
def _get_scalar_exception(self, opname, pa_dtype): | ||
arrow_temporal_supported = self._is_temporal_supported(opname, pa_dtype) | ||
if opname in { | ||
"__mod__", | ||
"__rmod__", | ||
}: | ||
self.series_scalar_exc = NotImplementedError | ||
exc = NotImplementedError | ||
elif arrow_temporal_supported: | ||
self.series_scalar_exc = None | ||
elif not ( | ||
pa.types.is_floating(pa_dtype) | ||
or pa.types.is_integer(pa_dtype) | ||
or arrow_temporal_supported | ||
): | ||
self.series_scalar_exc = pa.ArrowNotImplementedError | ||
exc = None | ||
elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)): | ||
exc = pa.ArrowNotImplementedError | ||
else: | ||
self.series_scalar_exc = None | ||
exc = None | ||
return exc | ||
|
||
def _get_arith_xfail_marker(self, opname, pa_dtype): | ||
mark = None | ||
|
||
arrow_temporal_supported = self._is_temporal_supported(opname, pa_dtype) | ||
|
||
if ( | ||
all_arithmetic_operators == "__rpow__" | ||
opname == "__rpow__" | ||
and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) | ||
and not pa_version_under6p0 | ||
): | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
reason=( | ||
f"GH 29997: 1**pandas.NA == 1 while 1**pyarrow.NA == NULL " | ||
f"for {pa_dtype}" | ||
) | ||
mark = pytest.mark.xfail( | ||
reason=( | ||
f"GH#29997: 1**pandas.NA == 1 while 1**pyarrow.NA == NULL " | ||
f"for {pa_dtype}" | ||
) | ||
) | ||
elif arrow_temporal_supported: | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
raises=TypeError, | ||
reason=( | ||
f"{all_arithmetic_operators} not supported between" | ||
f"pd.NA and {pa_dtype} Python scalar" | ||
), | ||
) | ||
mark = pytest.mark.xfail( | ||
raises=TypeError, | ||
reason=( | ||
f"{opname} not supported between" | ||
f"pd.NA and {pa_dtype} Python scalar" | ||
), | ||
) | ||
elif ( | ||
all_arithmetic_operators in {"__rtruediv__", "__rfloordiv__"} | ||
opname in {"__rtruediv__", "__rfloordiv__"} | ||
and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) | ||
and not pa_version_under6p0 | ||
): | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
raises=pa.ArrowInvalid, | ||
reason="divide by 0", | ||
) | ||
mark = pytest.mark.xfail( | ||
raises=pa.ArrowInvalid, | ||
reason="divide by 0", | ||
) | ||
|
||
return mark | ||
|
||
def test_arith_series_with_scalar( | ||
self, data, all_arithmetic_operators, request, monkeypatch | ||
): | ||
pa_dtype = data.dtype.pyarrow_dtype | ||
|
||
if all_arithmetic_operators == "__rmod__" and ( | ||
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype) | ||
): | ||
pytest.skip("Skip testing Python string formatting") | ||
|
||
self.series_scalar_exc = self._get_scalar_exception( | ||
all_arithmetic_operators, pa_dtype | ||
) | ||
|
||
mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype) | ||
if mark is not None: | ||
request.node.add_marker(mark) | ||
|
||
if all_arithmetic_operators == "__floordiv__" and pa.types.is_integer(pa_dtype): | ||
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does | ||
# not upcast | ||
|
@@ -1129,61 +1166,19 @@ def test_arith_frame_with_scalar( | |
): | ||
pa_dtype = data.dtype.pyarrow_dtype | ||
|
||
arrow_temporal_supported = not pa_version_under8p0 and ( | ||
all_arithmetic_operators in ("__add__", "__radd__") | ||
and pa.types.is_duration(pa_dtype) | ||
or all_arithmetic_operators in ("__sub__", "__rsub__") | ||
and pa.types.is_temporal(pa_dtype) | ||
) | ||
if all_arithmetic_operators == "__rmod__" and ( | ||
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype) | ||
): | ||
pytest.skip("Skip testing Python string formatting") | ||
elif all_arithmetic_operators in { | ||
"__mod__", | ||
"__rmod__", | ||
}: | ||
self.frame_scalar_exc = NotImplementedError | ||
elif arrow_temporal_supported: | ||
self.frame_scalar_exc = None | ||
elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)): | ||
self.frame_scalar_exc = pa.ArrowNotImplementedError | ||
else: | ||
self.frame_scalar_exc = None | ||
if ( | ||
all_arithmetic_operators == "__rpow__" | ||
and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) | ||
and not pa_version_under6p0 | ||
): | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
reason=( | ||
f"GH 29997: 1**pandas.NA == 1 while 1**pyarrow.NA == NULL " | ||
f"for {pa_dtype}" | ||
) | ||
) | ||
) | ||
elif arrow_temporal_supported: | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
raises=TypeError, | ||
reason=( | ||
f"{all_arithmetic_operators} not supported between" | ||
f"pd.NA and {pa_dtype} Python scalar" | ||
), | ||
) | ||
) | ||
elif ( | ||
all_arithmetic_operators in {"__rtruediv__", "__rfloordiv__"} | ||
and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) | ||
and not pa_version_under6p0 | ||
): | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
raises=pa.ArrowInvalid, | ||
reason="divide by 0", | ||
) | ||
) | ||
|
||
self.frame_scalar_exc = self._get_scalar_exception( | ||
all_arithmetic_operators, pa_dtype | ||
) | ||
|
||
mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype) | ||
if mark is not None: | ||
request.node.add_marker(mark) | ||
|
||
if all_arithmetic_operators == "__floordiv__" and pa.types.is_integer(pa_dtype): | ||
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does | ||
# not upcast | ||
|
@@ -1195,37 +1190,11 @@ def test_arith_series_with_array( | |
): | ||
pa_dtype = data.dtype.pyarrow_dtype | ||
|
||
arrow_temporal_supported = not pa_version_under8p0 and ( | ||
all_arithmetic_operators in ("__add__", "__radd__") | ||
and pa.types.is_duration(pa_dtype) | ||
or all_arithmetic_operators in ("__sub__", "__rsub__") | ||
and pa.types.is_temporal(pa_dtype) | ||
self.series_array_exc = self._get_scalar_exception( | ||
all_arithmetic_operators, pa_dtype | ||
) | ||
if all_arithmetic_operators in { | ||
"__mod__", | ||
"__rmod__", | ||
}: | ||
self.series_array_exc = NotImplementedError | ||
elif arrow_temporal_supported: | ||
self.series_array_exc = None | ||
elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)): | ||
self.series_array_exc = pa.ArrowNotImplementedError | ||
else: | ||
self.series_array_exc = None | ||
|
||
if ( | ||
all_arithmetic_operators == "__rpow__" | ||
and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) | ||
and not pa_version_under6p0 | ||
): | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
reason=( | ||
f"GH 29997: 1**pandas.NA == 1 while 1**pyarrow.NA == NULL " | ||
f"for {pa_dtype}" | ||
) | ||
) | ||
) | ||
elif ( | ||
all_arithmetic_operators | ||
in ( | ||
"__sub__", | ||
|
@@ -1243,32 +1212,17 @@ def test_arith_series_with_array( | |
), | ||
) | ||
) | ||
elif arrow_temporal_supported: | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
raises=TypeError, | ||
reason=( | ||
f"{all_arithmetic_operators} not supported between" | ||
f"pd.NA and {pa_dtype} Python scalar" | ||
), | ||
) | ||
) | ||
elif ( | ||
all_arithmetic_operators in {"__rtruediv__", "__rfloordiv__"} | ||
and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)) | ||
and not pa_version_under6p0 | ||
): | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
raises=pa.ArrowInvalid, | ||
reason="divide by 0", | ||
) | ||
) | ||
|
||
mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype) | ||
if mark is not None: | ||
request.node.add_marker(mark) | ||
|
||
op_name = all_arithmetic_operators | ||
ser = pd.Series(data) | ||
# pd.Series([ser.iloc[0]] * len(ser)) may not return ArrowExtensionArray | ||
# since ser.iloc[0] is a python scalar | ||
other = pd.Series(pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype)) | ||
|
||
if pa.types.is_floating(pa_dtype) or ( | ||
pa.types.is_integer(pa_dtype) and all_arithmetic_operators != "__truediv__" | ||
): | ||
|
@@ -1364,6 +1318,25 @@ def test_arrowdtype_construct_from_string_type_with_unsupported_parameters(): | |
@pytest.mark.parametrize("quantile", [0.5, [0.5, 0.5]]) | ||
def test_quantile(data, interpolation, quantile, request): | ||
pa_dtype = data.dtype.pyarrow_dtype | ||
|
||
data = data.take([0, 0, 0]) | ||
ser = pd.Series(data) | ||
|
||
if ( | ||
pa.types.is_string(pa_dtype) | ||
or pa.types.is_binary(pa_dtype) | ||
or pa.types.is_boolean(pa_dtype) | ||
): | ||
# For string, bytes, and bool, we don't *expect* to have quantile work | ||
# Note this matches the non-pyarrow behavior | ||
if pa_version_under7p0: | ||
msg = r"Function quantile has no kernel matching input types \(.*\)" | ||
else: | ||
msg = r"Function 'quantile' has no kernel matching input types \(.*\)" | ||
with pytest.raises(pa.ArrowNotImplementedError, match=msg): | ||
ser.quantile(q=quantile, interpolation=interpolation) | ||
return | ||
|
||
if not (pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype)): | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
|
@@ -1398,11 +1371,7 @@ def test_quantile(data, interpolation, quantile, request): | |
) | ||
def test_mode(data_for_grouping, dropna, take_idx, exp_idx, request): | ||
pa_dtype = data_for_grouping.dtype.pyarrow_dtype | ||
if ( | ||
pa.types.is_temporal(pa_dtype) | ||
or pa.types.is_string(pa_dtype) | ||
or pa.types.is_binary(pa_dtype) | ||
): | ||
if pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype): | ||
request.node.add_marker( | ||
pytest.mark.xfail( | ||
raises=pa.ArrowNotImplementedError, | ||
|
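To illustrate the headline change (mode is now exercised for temporal pyarrow dtypes instead of being xfailed, while string and binary keep the xfail above), a hedged usage sketch; the values and exact output are illustrative, not taken from this test:

```python
import pandas as pd

# After this change, mode() on a temporal pyarrow-backed Series is expected
# to return a result rather than hitting pa.ArrowNotImplementedError.
ser = pd.Series(
    [pd.Timestamp("2023-01-01")] * 2 + [pd.Timestamp("2023-01-02")],
    dtype="timestamp[s][pyarrow]",
)
print(ser.mode())  # expected: a single 2023-01-01 value, same dtype
```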
Review thread on the _mode implementation:
Comment: Can we just self._data.cast to int type, continue the logic below, and then most_common.cast back?
Comment: we could. I preferred this way since it was self-contained.
Comment: I suspect going through pa.array.cast may be more performant than astype. Mind checking?
Comment: updated + green
Comment: Looks like _mode here is still using astype.
Comment: yes, in some cases this needs an int32 and in some cases an int64, which makes the .cast version extra-complicated.
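For illustration only, here is a rough sketch of the cast-and-back approach discussed in this thread, not the code merged in this PR: cast a temporal Arrow array to its integer storage type, take the mode there, then cast the most common values back. The bit_width check is one way to handle the int32-vs-int64 wrinkle from the last comment; the helper name and structure are hypothetical.

```python
import pyarrow as pa
import pyarrow.compute as pc


def mode_via_int_cast(values: pa.ChunkedArray, n: int = 1) -> pa.Array:
    """Hypothetical sketch: mode of temporal Arrow data via an integer round-trip.

    pa.compute.mode has no kernel for timestamp/duration/date/time, so cast to
    the matching integer storage type, compute the mode there, then cast back.
    """
    pa_type = values.type
    cast_back = pa.types.is_temporal(pa_type)
    if cast_back:
        # date32/time32 are 32-bit; timestamp/duration/date64/time64 are 64-bit.
        # This is the int32-vs-int64 distinction mentioned in the thread above.
        int_type = pa.int32() if pa_type.bit_width == 32 else pa.int64()
        values = values.cast(int_type)

    # pc.mode returns an array of {"mode": value, "count": count} structs
    most_common = pc.mode(values, n=n).field("mode")

    if cast_back:
        most_common = most_common.cast(pa_type)
    return most_common


# Example: the mode of three timestamps, two of which are equal
arr = pa.chunked_array([pa.array([1, 1, 2], type=pa.timestamp("ns"))])
print(mode_via_int_cast(arr))  # the most common value, 1970-01-01 00:00:00.000000001
```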