Skip to content

Commit bedd8f0

Browse files
authored
ENH/TST: Add Reduction tests for ArrowExtensionArray (#47730)
1 parent 8f04a8e commit bedd8f0

File tree

3 files changed

+159
-3
lines changed

3 files changed

+159
-3
lines changed

pandas/core/arrays/arrow/array.py

+63
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,69 @@ def _concat_same_type(
688688
arr = pa.chunked_array(chunks)
689689
return cls(arr)
690690

691+
def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
    """
    Return a scalar result of performing the reduction operation.

    The reduction is dispatched to the ``pyarrow.compute`` function that
    corresponds to ``name``; ``sem`` has no such function and is computed
    here from ``stddev`` and ``count``.

    Parameters
    ----------
    name : str
        Name of the function, supported values are:
        { any, all, min, max, sum, mean, median, prod,
        std, var, sem, kurt, skew }.
    skipna : bool, default True
        If True, skip null values. Forwarded to pyarrow as ``skip_nulls``.
    **kwargs
        Additional keyword arguments passed to the reduction function.
        Currently, `ddof` is the only supported kwarg.

    Returns
    -------
    scalar
        The pyarrow result converted with ``as_py()``, or
        ``self.dtype.na_value`` when the pyarrow result is null.

    Raises
    ------
    TypeError
        If ``pyarrow.compute`` has no function for ``name`` (the call is
        then delegated to the parent ``_reduce``, which is expected to
        raise), or if the pyarrow function raises ``AttributeError``,
        ``NotImplementedError`` or ``TypeError`` for this dtype / pyarrow
        version.
    """
    if name == "sem":

        # The helper must accept ``skip_nulls`` because it is invoked below
        # exactly like a real pyarrow.compute function.
        def pyarrow_meth(data, skip_nulls, **kwargs):
            # Standard error of the mean: std(ddof) / sqrt(n), where n is
            # the number of non-null values. ``ddof`` only affects the
            # standard deviation, so it is not subtracted from the count.
            numerator = pc.stddev(data, skip_nulls=skip_nulls, **kwargs)
            # pc.count counts only valid (non-null) values by default.
            denominator = pc.sqrt_checked(pc.count(data))
            return pc.divide_checked(numerator, denominator)

    else:
        # Translate pandas reduction names to their pyarrow.compute names;
        # names without an entry are looked up unchanged.
        pyarrow_name = {
            "median": "approximate_median",
            "prod": "product",
            "std": "stddev",
            "var": "variance",
        }.get(name, name)
        # error: Incompatible types in assignment
        # (expression has type "Optional[Any]", variable has type
        # "Callable[[Any, Any, KwArg(Any)], Any]")
        pyarrow_meth = getattr(pc, pyarrow_name, None)  # type: ignore[assignment]
        if pyarrow_meth is None:
            # Let ExtensionArray._reduce raise the TypeError
            return super()._reduce(name, skipna=skipna, **kwargs)
    try:
        result = pyarrow_meth(self._data, skip_nulls=skipna, **kwargs)
    except (AttributeError, NotImplementedError, TypeError) as err:
        # Normalise pyarrow failures into a single TypeError that names the
        # dtype, the reduction and the installed pyarrow version.
        msg = (
            f"'{type(self).__name__}' with dtype {self.dtype} "
            f"does not support reduction '{name}' with pyarrow "
            f"version {pa.__version__}. '{name}' may be supported by "
            f"upgrading pyarrow."
        )
        raise TypeError(msg) from err
    # A null pyarrow scalar maps to the dtype's NA value rather than None.
    if pc.is_null(result).as_py():
        return self.dtype.na_value
    return result.as_py()
753+
691754
def __setitem__(self, key: int | slice | np.ndarray, value: Any) -> None:
692755
"""Set one or more values inplace.
693756

pandas/tests/arrays/string_/test_string.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
import numpy as np
66
import pytest
77

8-
from pandas.compat import pa_version_under2p0
8+
from pandas.compat import (
9+
pa_version_under2p0,
10+
pa_version_under6p0,
11+
)
912
from pandas.errors import PerformanceWarning
1013
import pandas.util._test_decorators as td
1114

@@ -375,7 +378,7 @@ def test_reduce_missing(skipna, dtype):
375378
@pytest.mark.parametrize("method", ["min", "max"])
376379
@pytest.mark.parametrize("skipna", [True, False])
377380
def test_min_max(method, skipna, dtype, request):
378-
if dtype.storage == "pyarrow":
381+
if dtype.storage == "pyarrow" and pa_version_under6p0:
379382
reason = "'ArrowStringArray' object has no attribute 'max'"
380383
mark = pytest.mark.xfail(raises=TypeError, reason=reason)
381384
request.node.add_marker(mark)
@@ -392,7 +395,7 @@ def test_min_max(method, skipna, dtype, request):
392395
@pytest.mark.parametrize("method", ["min", "max"])
393396
@pytest.mark.parametrize("box", [pd.Series, pd.array])
394397
def test_min_max_numpy(method, box, dtype, request):
395-
if dtype.storage == "pyarrow":
398+
if dtype.storage == "pyarrow" and (pa_version_under6p0 or box is pd.array):
396399
if box is pd.array:
397400
reason = "'<=' not supported between instances of 'str' and 'NoneType'"
398401
else:

pandas/tests/extension/test_arrow.py

+90
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from pandas.compat import (
2525
pa_version_under2p0,
2626
pa_version_under3p0,
27+
pa_version_under6p0,
2728
pa_version_under8p0,
2829
)
2930

@@ -303,6 +304,95 @@ def test_loc_iloc_frame_single_dtype(self, request, using_array_manager, data):
303304
super().test_loc_iloc_frame_single_dtype(data)
304305

305306

307+
class TestBaseNumericReduce(base.BaseNumericReduceTests):
    """Numeric reduction tests for pyarrow-backed extension arrays."""

    def check_reduce(self, ser, op_name, skipna):
        """Compare ``ser.<op_name>`` with the same reduction on a Float64 copy.

        Boolean data is skipped; integer and floating data are cast to
        ``"Float64"`` to build the expected value; any other dtype is
        compared against itself.
        """
        arrow_type = ser.dtype.pyarrow_dtype
        # Reduce the arrow-backed series first, before any skip or cast.
        actual = getattr(ser, op_name)(skipna=skipna)
        if pa.types.is_boolean(arrow_type):
            # Can't convert if ser contains NA
            pytest.skip(
                "pandas boolean data with NA does not fully support all reductions"
            )
        reference = ser
        if pa.types.is_integer(arrow_type) or pa.types.is_floating(arrow_type):
            reference = ser.astype("Float64")
        expected = getattr(reference, op_name)(skipna=skipna)
        tm.assert_almost_equal(actual, expected)

    @pytest.mark.parametrize("skipna", [True, False])
    def test_reduce_series(self, data, all_numeric_reductions, skipna, request):
        """Run the base reduction test, marking known-unsupported cases xfail.

        At most one xfail marker is attached; the conditions are checked in
        order and the first match wins.
        """
        pa_dtype = data.dtype.pyarrow_dtype
        is_int_or_float = pa.types.is_integer(pa_dtype) or pa.types.is_floating(
            pa_dtype
        )
        is_bool = pa.types.is_boolean(pa_dtype)
        is_sum_or_mean = all_numeric_reductions in {"sum", "mean"}

        not_implemented = pytest.mark.xfail(
            raises=TypeError,
            reason=(
                f"{all_numeric_reductions} is not implemented in "
                f"pyarrow={pa.__version__} for {pa_dtype}"
            ),
        )

        marker = None
        if all_numeric_reductions in {"skew", "kurt"}:
            marker = not_implemented
        elif pa_version_under6p0 and all_numeric_reductions in {
            "median",
            "var",
            "std",
            "prod",
            "max",
            "min",
        }:
            marker = not_implemented
        elif is_sum_or_mean and pa_version_under2p0:
            marker = not_implemented
        elif (
            is_sum_or_mean
            and skipna is False
            and pa_version_under6p0
            and is_int_or_float
        ):
            marker = pytest.mark.xfail(
                raises=AssertionError,
                reason=(
                    f"{all_numeric_reductions} with skip_nulls={skipna} did not "
                    f"return NA for {pa_dtype} with pyarrow={pa.__version__}"
                ),
            )
        elif not (
            is_int_or_float
            or is_bool
            or (
                # min/max are exempt for temporal types other than duration
                all_numeric_reductions in {"min", "max"}
                and pa.types.is_temporal(pa_dtype)
                and not pa.types.is_duration(pa_dtype)
            )
        ):
            marker = not_implemented
        elif is_bool and all_numeric_reductions in {"std", "var", "median"}:
            marker = not_implemented

        if marker is not None:
            request.node.add_marker(marker)
        super().test_reduce_series(data, all_numeric_reductions, skipna)
371+
372+
373+
class TestBaseBooleanReduce(base.BaseBooleanReduceTests):
    """Boolean (``any`` / ``all``) reduction tests for pyarrow-backed arrays."""

    @pytest.mark.parametrize("skipna", [True, False])
    def test_reduce_series(
        self, data, all_boolean_reductions, skipna, na_value, request
    ):
        """Check that ``any`` reduces to True and ``all`` reduces to False.

        The test is marked xfail (TypeError) for non-boolean pyarrow dtypes
        and whenever ``pa_version_under3p0`` is set.
        """
        pa_dtype = data.dtype.pyarrow_dtype
        not_implemented = pytest.mark.xfail(
            raises=TypeError,
            reason=(
                f"{all_boolean_reductions} is not implemented in "
                f"pyarrow={pa.__version__} for {pa_dtype}"
            ),
        )
        # Both conditions attach the same marker, so one check suffices.
        if pa_version_under3p0 or not pa.types.is_boolean(pa_dtype):
            request.node.add_marker(not_implemented)

        reduce_op = getattr(pd.Series(data), all_boolean_reductions)
        result = reduce_op(skipna=skipna)
        expected = all_boolean_reductions == "any"
        assert result is expected
394+
395+
306396
class TestBaseGroupby(base.BaseGroupbyTests):
307397
def test_groupby_agg_extension(self, data_for_grouping, request):
308398
tz = getattr(data_for_grouping.dtype.pyarrow_dtype, "tz", None)

0 commit comments

Comments
 (0)