
Commit 4f189a4

[backport 2.3.x] String dtype: implement sum reduction (#59853) (#60157)
String dtype: implement sum reduction (#59853) (cherry picked from commit 2fdb16b)
1 parent e620e9d commit 4f189a4

15 files changed: +121 −150 lines changed

doc/source/whatsnew/v2.3.0.rst (+1 −1)

@@ -32,7 +32,7 @@ enhancement1
 Other enhancements
 ^^^^^^^^^^^^^^^^^^
 
--
+- The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)
 -
 
 .. ---------------------------------------------------------------------------
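
For context, a minimal sketch of the behaviour this entry announces, mirroring the expectations in the tests updated further down (the exact repr of the missing-value result depends on the chosen string storage):

    import pandas as pd

    ser = pd.Series(["a", "b", None], dtype="string")
    ser.sum()              # "ab" -- missing values are skipped by default
    ser.sum(skipna=False)  # <NA>

Before this change the same call raised for StringDtype columns (see the previously xfailed tests removed below).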

pandas/core/array_algos/masked_reductions.py (+4)

@@ -62,6 +62,10 @@ def _reductions(
         ):
             return libmissing.NA
 
+        if values.dtype == np.dtype(object):
+            # object dtype does not support `where` without passing an initial
+            values = values[~mask]
+            return func(values, axis=axis, **kwargs)
         return func(values, where=~mask, axis=axis, **kwargs)
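
The new branch sidesteps NumPy's `where=` masking, which for object-dtype arrays needs an explicit `initial` value; dropping the masked entries first and reducing normally gives the same result. A minimal sketch with illustrative values (not taken from the source):

    import numpy as np

    values = np.array(["a", "b", "c"], dtype=object)
    mask = np.array([False, True, False])  # True marks missing entries

    # Equivalent of the object-dtype branch: filter, then reduce without `where=`.
    np.sum(values[~mask])  # -> "ac"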

pandas/core/arrays/arrow/array.py (+32)

@@ -69,6 +69,7 @@
     unpack_tuple_and_ellipses,
     validate_indices,
 )
+from pandas.core.nanops import check_below_min_count
 from pandas.core.strings.base import BaseStringArrayMethods
 
 from pandas.io._util import _arrow_dtype_mapping

@@ -1694,6 +1695,37 @@ def pyarrow_meth(data, skip_nulls, **kwargs):
                 denominator = pc.sqrt_checked(pc.count(self._pa_array))
                 return pc.divide_checked(numerator, denominator)
 
+        elif name == "sum" and (
+            pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type)
+        ):
+
+            def pyarrow_meth(data, skip_nulls, min_count=0):  # type: ignore[misc]
+                mask = pc.is_null(data) if data.null_count > 0 else None
+                if skip_nulls:
+                    if min_count > 0 and check_below_min_count(
+                        (len(data),),
+                        None if mask is None else mask.to_numpy(),
+                        min_count,
+                    ):
+                        return pa.scalar(None, type=data.type)
+                    if data.null_count > 0:
+                        # binary_join returns null if there is any null ->
+                        # have to filter out any nulls
+                        data = data.filter(pc.invert(mask))
+                else:
+                    if mask is not None or check_below_min_count(
+                        (len(data),), None, min_count
+                    ):
+                        return pa.scalar(None, type=data.type)
+
+                if pa.types.is_large_string(data.type):
+                    # binary_join only supports string, not large_string
+                    data = data.cast(pa.string())
+                data_list = pa.ListArray.from_arrays(
+                    [0, len(data)], data.combine_chunks()
+                )[0]
+                return pc.binary_join(data_list, "")
+
         else:
             pyarrow_name = {
                 "median": "quantile",

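The added branch concatenates a pyarrow string column by wrapping it in a single-element ListArray and calling `pc.binary_join`. A standalone sketch of that trick on an illustrative chunked array (nulls have to be filtered first because `binary_join` returns null if any list element is null, and large_string data is cast to string because `binary_join` has no large_string kernel):

    import pyarrow as pa
    import pyarrow.compute as pc

    data = pa.chunked_array([["foo", None, "bar"]])

    # binary_join returns null if any element is null, so drop nulls first.
    data = data.filter(pc.invert(pc.is_null(data)))

    # Wrap the whole column into one list value and join with an empty separator.
    data_list = pa.ListArray.from_arrays([0, len(data)], data.combine_chunks())[0]
    pc.binary_join(data_list, "")  # -> <pyarrow.StringScalar: 'foobar'>
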
pandas/core/arrays/string_.py (+16 −2)

@@ -812,8 +812,8 @@ def _reduce(
             else:
                 return nanops.nanall(self._ndarray, skipna=skipna)
 
-        if name in ["min", "max"]:
-            result = getattr(self, name)(skipna=skipna, axis=axis)
+        if name in ["min", "max", "sum"]:
+            result = getattr(self, name)(skipna=skipna, axis=axis, **kwargs)
             if keepdims:
                 return self._from_sequence([result], dtype=self.dtype)
             return result

@@ -839,6 +839,20 @@ def max(self, axis=None, skipna: bool = True, **kwargs) -> Scalar:
         )
         return self._wrap_reduction_result(axis, result)
 
+    def sum(
+        self,
+        *,
+        axis: AxisInt | None = None,
+        skipna: bool = True,
+        min_count: int = 0,
+        **kwargs,
+    ) -> Scalar:
+        nv.validate_sum((), kwargs)
+        result = masked_reductions.sum(
+            values=self._ndarray, mask=self.isna(), skipna=skipna
+        )
+        return self._wrap_reduction_result(axis, result)
+
     def value_counts(self, dropna: bool = True) -> Series:
         from pandas.core.algorithms import value_counts_internal as value_counts
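
With this method in place the NumPy-backed StringArray reduces through `masked_reductions.sum` instead of raising. A hedged usage sketch (the "python" storage is spelled out to pin the array class; note that the diff accepts `min_count` but does not forward it to `masked_reductions.sum`):

    import pandas as pd

    arr = pd.array(["x", None, "y"], dtype=pd.StringDtype("python"))
    arr.sum()              # "xy"
    arr.sum(skipna=False)  # <NA>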

pandas/core/arrays/string_arrow.py (+5 −1)

@@ -430,7 +430,11 @@ def _reduce(
                 return result.astype(np.bool_)
             return result
 
-        result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
+        if name in ("min", "max", "sum", "argmin", "argmax"):
+            result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
+        else:
+            raise TypeError(f"Cannot perform reduction '{name}' with string dtype")
+
         if name in ("argmin", "argmax") and isinstance(result, pa.Array):
             return self._convert_int_result(result)
         elif isinstance(result, pa.Array):
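
With the explicit whitelist, the Arrow-backed string array raises a TypeError for unsupported reductions instead of falling through to `_reduce_calc`. A hedged sketch of both paths (assuming pyarrow is installed):

    import pandas as pd

    ser = pd.Series(["a", "b"], dtype=pd.StringDtype("pyarrow"))
    ser.sum()   # "ab"
    ser.prod()  # expected to raise TypeError: Cannot perform reduction 'prod' with string dtype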

pandas/tests/apply/test_frame_apply.py (−10)

@@ -4,10 +4,6 @@
 import numpy as np
 import pytest
 
-from pandas._config import using_string_dtype
-
-from pandas.compat import HAS_PYARROW
-
 from pandas.core.dtypes.dtypes import CategoricalDtype
 
 import pandas as pd

@@ -1173,7 +1169,6 @@ def test_agg_with_name_as_column_name():
     tm.assert_series_equal(result, expected)
 
 
-@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
 def test_agg_multiple_mixed():
     # GH 20909
     mdf = DataFrame(

@@ -1202,9 +1197,6 @@ def test_agg_multiple_mixed():
     tm.assert_frame_equal(result, expected)
 
 
-@pytest.mark.xfail(
-    using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
-)
 def test_agg_multiple_mixed_raises():
     # GH 20909
     mdf = DataFrame(

@@ -1294,7 +1286,6 @@ def test_agg_reduce(axis, float_frame):
     tm.assert_frame_equal(result, expected)
 
 
-@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
 def test_nuiscance_columns():
     # GH 15015
     df = DataFrame(

@@ -1471,7 +1462,6 @@ def test_apply_datetime_tz_issue(engine, request):
     tm.assert_series_equal(result, expected)
 
 
-@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
 @pytest.mark.parametrize("df", [DataFrame({"A": ["a", None], "B": ["c", "d"]})])
 @pytest.mark.parametrize("method", ["min", "max", "sum"])
 def test_mixed_column_raises(df, method, using_infer_string):
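
These xfail markers could be dropped because mixed-dtype aggregations that sum a string column now succeed. A hedged sketch in the spirit of test_agg_multiple_mixed (column names and values are illustrative, and the printed frame is only indicative):

    import pandas as pd

    df = pd.DataFrame(
        {
            "A": [1, 2, 3],
            "C": pd.Series(["foo", "bar", "baz"], dtype="string"),
        }
    )
    df.agg(["min", "sum"])
    #      A          C
    # min  1        bar
    # sum  6  foobarbaz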

pandas/tests/apply/test_invalid_arg.py (+20 −19)

@@ -12,9 +12,6 @@
 import numpy as np
 import pytest
 
-from pandas._config import using_string_dtype
-
-from pandas.compat import HAS_PYARROW
 from pandas.errors import SpecificationError
 
 from pandas import (

@@ -212,10 +209,6 @@ def transform(row):
     data.apply(transform, axis=1)
 
 
-# we should raise a proper TypeError instead of propagating the pyarrow error
-@pytest.mark.xfail(
-    using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
-)
 @pytest.mark.parametrize(
     "df, func, expected",
     tm.get_cython_table_params(

@@ -225,21 +218,25 @@ def transform(row):
 def test_agg_cython_table_raises_frame(df, func, expected, axis, using_infer_string):
     # GH 21224
     if using_infer_string:
-        import pyarrow as pa
+        if df.dtypes.iloc[0].storage == "pyarrow":
+            import pyarrow as pa
 
-        expected = (expected, pa.lib.ArrowNotImplementedError)
+            # TODO(infer_string)
+            # should raise a proper TypeError instead of propagating the pyarrow error
 
-    msg = "can't multiply sequence by non-int of type 'str'|has no kernel"
+            expected = (expected, pa.lib.ArrowNotImplementedError)
+        else:
+            expected = (expected, NotImplementedError)
+
+    msg = (
+        "can't multiply sequence by non-int of type 'str'|has no kernel|cannot perform"
+    )
     warn = None if isinstance(func, str) else FutureWarning
     with pytest.raises(expected, match=msg):
         with tm.assert_produces_warning(warn, match="using DataFrame.cumprod"):
             df.agg(func, axis=axis)
 
 
-# we should raise a proper TypeError instead of propagating the pyarrow error
-@pytest.mark.xfail(
-    using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
-)
 @pytest.mark.parametrize(
     "series, func, expected",
     chain(

@@ -263,11 +260,15 @@ def test_agg_cython_table_raises_series(series, func, expected, using_infer_stri
     msg = r"Cannot convert \['a' 'b' 'c'\] to numeric"
 
     if using_infer_string:
-        import pyarrow as pa
-
-        expected = (expected, pa.lib.ArrowNotImplementedError)
-
-        msg = msg + "|does not support|has no kernel"
+        if series.dtype.storage == "pyarrow":
+            import pyarrow as pa
+
+            # TODO(infer_string)
+            # should raise a proper TypeError instead of propagating the pyarrow error
+            expected = (expected, pa.lib.ArrowNotImplementedError)
+        else:
+            expected = (expected, NotImplementedError)
+        msg = msg + "|does not support|has no kernel|Cannot perform|cannot perform"
     warn = None if isinstance(func, str) else FutureWarning
 
     with pytest.raises(expected, match=msg):
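
The widened `expected`/`msg` values reflect that the error now depends on the string storage: pyarrow-backed columns still propagate an Arrow error, while NumPy-backed columns raise a TypeError/NotImplementedError whose message matches the new "cannot perform" pattern. A hedged sketch of what the updated test asserts for the NumPy-backed storage:

    import pandas as pd
    import pytest

    ser = pd.Series(["a", "b", "c"], dtype=pd.StringDtype("python"))

    # An incompatible cython-table func on a string Series should still raise;
    # the exact class and message depend on the storage, as encoded above.
    with pytest.raises((TypeError, NotImplementedError)):
        ser.agg("cumprod")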

pandas/tests/arrays/string_/test_string.py (−2)

@@ -444,15 +444,13 @@ def test_astype_float(dtype, any_float_dtype):
 
 
 @pytest.mark.parametrize("skipna", [True, False])
-@pytest.mark.xfail(reason="Not implemented StringArray.sum")
 def test_reduce(skipna, dtype):
     arr = pd.Series(["a", "b", "c"], dtype=dtype)
     result = arr.sum(skipna=skipna)
     assert result == "abc"
 
 
 @pytest.mark.parametrize("skipna", [True, False])
-@pytest.mark.xfail(reason="Not implemented StringArray.sum")
 def test_reduce_missing(skipna, dtype):
     arr = pd.Series([None, "a", None, "b", "c", None], dtype=dtype)
     result = arr.sum(skipna=skipna)

pandas/tests/extension/test_arrow.py (+5 −21)

@@ -459,10 +459,11 @@ def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
                 pass
             else:
                 return False
+        elif pa.types.is_binary(pa_dtype) and op_name == "sum":
+            return False
         elif (
             pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
         ) and op_name in [
-            "sum",
             "mean",
             "median",
             "prod",

@@ -553,6 +554,7 @@ def test_reduce_series_boolean(
         return super().test_reduce_series_boolean(data, all_boolean_reductions, skipna)
 
     def _get_expected_reduction_dtype(self, arr, op_name: str, skipna: bool):
+        pa_type = arr._pa_array.type
         if op_name in ["max", "min"]:
             cmp_dtype = arr.dtype
         elif arr.dtype.name == "decimal128(7, 3)[pyarrow]":

@@ -562,6 +564,8 @@ def _get_expected_reduction_dtype(self, arr, op_name: str, skipna: bool):
             cmp_dtype = "float64[pyarrow]"
         elif op_name in ["median", "var", "std", "mean", "skew"]:
             cmp_dtype = "float64[pyarrow]"
+        elif op_name == "sum" and pa.types.is_string(pa_type):
+            cmp_dtype = arr.dtype
         else:
             cmp_dtype = {
                 "i": "int64[pyarrow]",

@@ -585,26 +589,6 @@ def test_median_not_approximate(self, typ):
         result = pd.Series([1, 2], dtype=f"{typ}[pyarrow]").median()
         assert result == 1.5
 
-    def test_in_numeric_groupby(self, data_for_grouping):
-        dtype = data_for_grouping.dtype
-        if is_string_dtype(dtype):
-            df = pd.DataFrame(
-                {
-                    "A": [1, 1, 2, 2, 3, 3, 1, 4],
-                    "B": data_for_grouping,
-                    "C": [1, 1, 1, 1, 1, 1, 1, 1],
-                }
-            )
-
-            expected = pd.Index(["C"])
-            msg = re.escape(f"agg function failed [how->sum,dtype->{dtype}")
-            with pytest.raises(TypeError, match=msg):
-                df.groupby("A").sum()
-            result = df.groupby("A").sum(numeric_only=True).columns
-            tm.assert_index_equal(result, expected)
-        else:
-            super().test_in_numeric_groupby(data_for_grouping)
-
     def test_construct_from_string_own_name(self, dtype, request):
         pa_dtype = dtype.pyarrow_dtype
         if pa.types.is_decimal(pa_dtype):

pandas/tests/extension/test_string.py (+1 −1)

@@ -191,7 +191,7 @@ def _get_expected_exception(
 
     def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
         return (
-            op_name in ["min", "max"]
+            op_name in ["min", "max", "sum"]
             or ser.dtype.na_value is np.nan  # type: ignore[union-attr]
             and op_name in ("any", "all")
         )
