Skip to content

TST: simplify pyarrow tests, make mode work with temporal dtypes #50688

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 18, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,16 @@ def _mode(self: ArrowExtensionArrayT, dropna: bool = True) -> ArrowExtensionArra
"""
if pa_version_under6p0:
raise NotImplementedError("mode only supported for pyarrow version >= 6.0")

pa_type = self._data.type
if pa.types.is_temporal(pa_type):
nbits = pa_type.bit_width
dtype = f"int{nbits}[pyarrow]"
obj = cast(ArrowExtensionArrayT, self.astype(dtype, copy=False))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just self._data.cast to int type, continue the logic below, and then most_common.cast back?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could. i preferred this way since it was self-contained

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect going through pa.array.cast may be more performant than astype. Mind checking?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated + green

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like _mode here is still using astype

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, in some cases this needs an int32 and in some cases an int64, which makes the .cast version extra-complicated

result = obj._mode(dropna=dropna)
out = result.astype(self.dtype, copy=False)
return cast(ArrowExtensionArrayT, out)

modes = pc.mode(self._data, pc.count_distinct(self._data).as_py())
values = modes.field(0)
counts = modes.field(1)
Expand Down
283 changes: 126 additions & 157 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,15 +333,39 @@ def test_from_sequence_of_strings_pa_array(self, data, request):


class TestGetitemTests(base.BaseGetitemTests):
def test_getitem_scalar(self, data):
    """Check that scalar access returns a plain Python scalar (or pd.NA).

    In the base class the expected type is ``data.dtype.type``; the
    pyarrow-backed array (intentionally) returns Python scalars or pd.NA
    instead, so we map each pyarrow type to the Python type we expect.
    """
    pa_type = data._data.type
    if pa.types.is_integer(pa_type):
        exp_type = int
    elif pa.types.is_floating(pa_type):
        exp_type = float
    elif pa.types.is_string(pa_type):
        exp_type = str
    elif pa.types.is_binary(pa_type):
        exp_type = bytes
    elif pa.types.is_boolean(pa_type):
        exp_type = bool
    elif pa.types.is_duration(pa_type):
        exp_type = timedelta
    elif pa.types.is_timestamp(pa_type):
        # only nanosecond-unit timestamps round-trip as pd.Timestamp;
        # other units come back as stdlib datetime
        if pa_type.unit == "ns":
            exp_type = pd.Timestamp
        else:
            exp_type = datetime
    elif pa.types.is_date(pa_type):
        exp_type = date
    elif pa.types.is_time(pa_type):
        exp_type = time
    else:
        # fail loudly on any pyarrow type this mapping does not yet cover
        raise NotImplementedError(data.dtype)

    result = data[0]
    assert isinstance(result, exp_type), type(result)

    result = pd.Series(data)[0]
    assert isinstance(result, exp_type), type(result)


class TestBaseAccumulateTests(base.BaseAccumulateTests):
Expand Down Expand Up @@ -1054,70 +1078,83 @@ def _patch_combine(self, obj, other, op):
expected = pd.Series(pd_array)
return expected

def test_arith_series_with_scalar(
self, data, all_arithmetic_operators, request, monkeypatch
):
pa_dtype = data.dtype.pyarrow_dtype

arrow_temporal_supported = not pa_version_under8p0 and (
all_arithmetic_operators in ("__add__", "__radd__")
def _is_temporal_supported(self, opname, pa_dtype):
    """Return True if pyarrow supports `opname` for this temporal dtype.

    pyarrow >= 8.0 supports duration + duration addition and
    temporal - temporal subtraction; earlier versions do not.
    """
    return not pa_version_under8p0 and (
        opname in ("__add__", "__radd__")
        and pa.types.is_duration(pa_dtype)
        or opname in ("__sub__", "__rsub__")
        and pa.types.is_temporal(pa_dtype)
    )
if all_arithmetic_operators == "__rmod__" and (
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
):
pytest.skip("Skip testing Python string formatting")
elif all_arithmetic_operators in {

def _get_scalar_exception(self, opname, pa_dtype):
    """Return the exception type expected when applying `opname` with a
    scalar to an array of `pa_dtype`, or None if no exception is expected.
    """
    arrow_temporal_supported = self._is_temporal_supported(opname, pa_dtype)
    if opname in {
        "__mod__",
        "__rmod__",
    }:
        # modulo is not implemented for ArrowExtensionArray
        exc = NotImplementedError
    elif arrow_temporal_supported:
        exc = None
    elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)):
        # non-numeric, non-supported-temporal dtypes raise from pyarrow
        exc = pa.ArrowNotImplementedError
    else:
        exc = None
    return exc

def _get_arith_xfail_marker(self, opname, pa_dtype):
    """Return a pytest.mark.xfail for known arithmetic failures, else None."""
    mark = None

    arrow_temporal_supported = self._is_temporal_supported(opname, pa_dtype)

    if (
        opname == "__rpow__"
        and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype))
        and not pa_version_under6p0
    ):
        mark = pytest.mark.xfail(
            reason=(
                f"GH#29997: 1**pandas.NA == 1 while 1**pyarrow.NA == NULL "
                f"for {pa_dtype}"
            )
        )
    elif arrow_temporal_supported:
        # NOTE: trailing space added so the two f-string pieces don't run
        # together as "between​pd.NA" in the reported reason
        mark = pytest.mark.xfail(
            raises=TypeError,
            reason=(
                f"{opname} not supported between "
                f"pd.NA and {pa_dtype} Python scalar"
            ),
        )
    elif (
        opname in {"__rtruediv__", "__rfloordiv__"}
        and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype))
        and not pa_version_under6p0
    ):
        mark = pytest.mark.xfail(
            raises=pa.ArrowInvalid,
            reason="divide by 0",
        )

    return mark

def test_arith_series_with_scalar(
self, data, all_arithmetic_operators, request, monkeypatch
):
pa_dtype = data.dtype.pyarrow_dtype

if all_arithmetic_operators == "__rmod__" and (
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
):
pytest.skip("Skip testing Python string formatting")

self.series_scalar_exc = self._get_scalar_exception(
all_arithmetic_operators, pa_dtype
)

mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype)
if mark is not None:
request.node.add_marker(mark)

if all_arithmetic_operators == "__floordiv__" and pa.types.is_integer(pa_dtype):
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
# not upcast
Expand All @@ -1129,61 +1166,19 @@ def test_arith_frame_with_scalar(
):
pa_dtype = data.dtype.pyarrow_dtype

arrow_temporal_supported = not pa_version_under8p0 and (
all_arithmetic_operators in ("__add__", "__radd__")
and pa.types.is_duration(pa_dtype)
or all_arithmetic_operators in ("__sub__", "__rsub__")
and pa.types.is_temporal(pa_dtype)
)
if all_arithmetic_operators == "__rmod__" and (
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
):
pytest.skip("Skip testing Python string formatting")
elif all_arithmetic_operators in {
"__mod__",
"__rmod__",
}:
self.frame_scalar_exc = NotImplementedError
elif arrow_temporal_supported:
self.frame_scalar_exc = None
elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)):
self.frame_scalar_exc = pa.ArrowNotImplementedError
else:
self.frame_scalar_exc = None
if (
all_arithmetic_operators == "__rpow__"
and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype))
and not pa_version_under6p0
):
request.node.add_marker(
pytest.mark.xfail(
reason=(
f"GH 29997: 1**pandas.NA == 1 while 1**pyarrow.NA == NULL "
f"for {pa_dtype}"
)
)
)
elif arrow_temporal_supported:
request.node.add_marker(
pytest.mark.xfail(
raises=TypeError,
reason=(
f"{all_arithmetic_operators} not supported between"
f"pd.NA and {pa_dtype} Python scalar"
),
)
)
elif (
all_arithmetic_operators in {"__rtruediv__", "__rfloordiv__"}
and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype))
and not pa_version_under6p0
):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowInvalid,
reason="divide by 0",
)
)

self.frame_scalar_exc = self._get_scalar_exception(
all_arithmetic_operators, pa_dtype
)

mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype)
if mark is not None:
request.node.add_marker(mark)

if all_arithmetic_operators == "__floordiv__" and pa.types.is_integer(pa_dtype):
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
# not upcast
Expand All @@ -1195,37 +1190,11 @@ def test_arith_series_with_array(
):
pa_dtype = data.dtype.pyarrow_dtype

arrow_temporal_supported = not pa_version_under8p0 and (
all_arithmetic_operators in ("__add__", "__radd__")
and pa.types.is_duration(pa_dtype)
or all_arithmetic_operators in ("__sub__", "__rsub__")
and pa.types.is_temporal(pa_dtype)
self.series_array_exc = self._get_scalar_exception(
all_arithmetic_operators, pa_dtype
)
if all_arithmetic_operators in {
"__mod__",
"__rmod__",
}:
self.series_array_exc = NotImplementedError
elif arrow_temporal_supported:
self.series_array_exc = None
elif not (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype)):
self.series_array_exc = pa.ArrowNotImplementedError
else:
self.series_array_exc = None

if (
all_arithmetic_operators == "__rpow__"
and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype))
and not pa_version_under6p0
):
request.node.add_marker(
pytest.mark.xfail(
reason=(
f"GH 29997: 1**pandas.NA == 1 while 1**pyarrow.NA == NULL "
f"for {pa_dtype}"
)
)
)
elif (
all_arithmetic_operators
in (
"__sub__",
Expand All @@ -1243,32 +1212,17 @@ def test_arith_series_with_array(
),
)
)
elif arrow_temporal_supported:
request.node.add_marker(
pytest.mark.xfail(
raises=TypeError,
reason=(
f"{all_arithmetic_operators} not supported between"
f"pd.NA and {pa_dtype} Python scalar"
),
)
)
elif (
all_arithmetic_operators in {"__rtruediv__", "__rfloordiv__"}
and (pa.types.is_floating(pa_dtype) or pa.types.is_integer(pa_dtype))
and not pa_version_under6p0
):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowInvalid,
reason="divide by 0",
)
)

mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype)
if mark is not None:
request.node.add_marker(mark)

op_name = all_arithmetic_operators
ser = pd.Series(data)
# pd.Series([ser.iloc[0]] * len(ser)) may not return ArrowExtensionArray
# since ser.iloc[0] is a python scalar
other = pd.Series(pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype))

if pa.types.is_floating(pa_dtype) or (
pa.types.is_integer(pa_dtype) and all_arithmetic_operators != "__truediv__"
):
Expand Down Expand Up @@ -1364,6 +1318,25 @@ def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
@pytest.mark.parametrize("quantile", [0.5, [0.5, 0.5]])
def test_quantile(data, interpolation, quantile, request):
pa_dtype = data.dtype.pyarrow_dtype

data = data.take([0, 0, 0])
ser = pd.Series(data)

if (
pa.types.is_string(pa_dtype)
or pa.types.is_binary(pa_dtype)
or pa.types.is_boolean(pa_dtype)
):
# For string, bytes, and bool, we don't *expect* to have quantile work
# Note this matches the non-pyarrow behavior
if pa_version_under7p0:
msg = r"Function quantile has no kernel matching input types \(.*\)"
else:
msg = r"Function 'quantile' has no kernel matching input types \(.*\)"
with pytest.raises(pa.ArrowNotImplementedError, match=msg):
ser.quantile(q=quantile, interpolation=interpolation)
return

if not (pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype)):
request.node.add_marker(
pytest.mark.xfail(
Expand Down Expand Up @@ -1398,11 +1371,7 @@ def test_quantile(data, interpolation, quantile, request):
)
def test_mode(data_for_grouping, dropna, take_idx, exp_idx, request):
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
if (
pa.types.is_temporal(pa_dtype)
or pa.types.is_string(pa_dtype)
or pa.types.is_binary(pa_dtype)
):
if pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
Expand Down