Skip to content

Commit 2fdb16b

Browse files
String dtype: implement sum reduction (#59853)
1 parent 1cdd20e commit 2fdb16b

File tree

16 files changed

+120
-151
lines changed

16 files changed

+120
-151
lines changed

doc/source/whatsnew/v2.3.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ enhancement1
3232
Other enhancements
3333
^^^^^^^^^^^^^^^^^^
3434

35-
-
35+
- The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)
3636
-
3737

3838
.. ---------------------------------------------------------------------------

pandas/core/array_algos/masked_reductions.py

+4
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ def _reductions(
6262
):
6363
return libmissing.NA
6464

65+
if values.dtype == np.dtype(object):
66+
# object dtype does not support `where` without passing an initial value
67+
values = values[~mask]
68+
return func(values, axis=axis, **kwargs)
6569
return func(values, where=~mask, axis=axis, **kwargs)
6670

6771

pandas/core/arrays/arrow/array.py

+32
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
unpack_tuple_and_ellipses,
6969
validate_indices,
7070
)
71+
from pandas.core.nanops import check_below_min_count
7172
from pandas.core.strings.base import BaseStringArrayMethods
7273

7374
from pandas.io._util import _arrow_dtype_mapping
@@ -1705,6 +1706,37 @@ def pyarrow_meth(data, skip_nulls, **kwargs):
17051706
denominator = pc.sqrt_checked(pc.count(self._pa_array))
17061707
return pc.divide_checked(numerator, denominator)
17071708

1709+
elif name == "sum" and (
1710+
pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type)
1711+
):
1712+
1713+
def pyarrow_meth(data, skip_nulls, min_count=0): # type: ignore[misc]
1714+
mask = pc.is_null(data) if data.null_count > 0 else None
1715+
if skip_nulls:
1716+
if min_count > 0 and check_below_min_count(
1717+
(len(data),),
1718+
None if mask is None else mask.to_numpy(),
1719+
min_count,
1720+
):
1721+
return pa.scalar(None, type=data.type)
1722+
if data.null_count > 0:
1723+
# binary_join returns null if there is any null ->
1724+
# have to filter out any nulls
1725+
data = data.filter(pc.invert(mask))
1726+
else:
1727+
if mask is not None or check_below_min_count(
1728+
(len(data),), None, min_count
1729+
):
1730+
return pa.scalar(None, type=data.type)
1731+
1732+
if pa.types.is_large_string(data.type):
1733+
# binary_join only supports string, not large_string
1734+
data = data.cast(pa.string())
1735+
data_list = pa.ListArray.from_arrays(
1736+
[0, len(data)], data.combine_chunks()
1737+
)[0]
1738+
return pc.binary_join(data_list, "")
1739+
17081740
else:
17091741
pyarrow_name = {
17101742
"median": "quantile",

pandas/core/arrays/string_.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -812,8 +812,8 @@ def _reduce(
812812
else:
813813
return nanops.nanall(self._ndarray, skipna=skipna)
814814

815-
if name in ["min", "max"]:
816-
result = getattr(self, name)(skipna=skipna, axis=axis)
815+
if name in ["min", "max", "sum"]:
816+
result = getattr(self, name)(skipna=skipna, axis=axis, **kwargs)
817817
if keepdims:
818818
return self._from_sequence([result], dtype=self.dtype)
819819
return result
@@ -840,6 +840,20 @@ def max(self, axis=None, skipna: bool = True, **kwargs) -> Scalar:
840840
)
841841
return self._wrap_reduction_result(axis, result)
842842

843+
def sum(
844+
self,
845+
*,
846+
axis: AxisInt | None = None,
847+
skipna: bool = True,
848+
min_count: int = 0,
849+
**kwargs,
850+
) -> Scalar:
851+
nv.validate_sum((), kwargs)
852+
result = masked_reductions.sum(
853+
values=self._ndarray, mask=self.isna(), skipna=skipna
854+
)
855+
return self._wrap_reduction_result(axis, result)
856+
843857
def value_counts(self, dropna: bool = True) -> Series:
844858
from pandas.core.algorithms import value_counts_internal as value_counts
845859

pandas/core/arrays/string_arrow.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,11 @@ def _reduce(
435435
return result.astype(np.bool_)
436436
return result
437437

438-
result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
438+
if name in ("min", "max", "sum", "argmin", "argmax"):
439+
result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
440+
else:
441+
raise TypeError(f"Cannot perform reduction '{name}' with string dtype")
442+
439443
if name in ("argmin", "argmax") and isinstance(result, pa.Array):
440444
return self._convert_int_result(result)
441445
elif isinstance(result, pa.Array):

pandas/tests/apply/test_frame_apply.py

-10
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@
44
import numpy as np
55
import pytest
66

7-
from pandas._config import using_string_dtype
8-
9-
from pandas.compat import HAS_PYARROW
10-
117
from pandas.core.dtypes.dtypes import CategoricalDtype
128

139
import pandas as pd
@@ -1218,7 +1214,6 @@ def test_agg_with_name_as_column_name():
12181214
tm.assert_series_equal(result, expected)
12191215

12201216

1221-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
12221217
def test_agg_multiple_mixed():
12231218
# GH 20909
12241219
mdf = DataFrame(
@@ -1247,9 +1242,6 @@ def test_agg_multiple_mixed():
12471242
tm.assert_frame_equal(result, expected)
12481243

12491244

1250-
@pytest.mark.xfail(
1251-
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
1252-
)
12531245
def test_agg_multiple_mixed_raises():
12541246
# GH 20909
12551247
mdf = DataFrame(
@@ -1347,7 +1339,6 @@ def test_named_agg_reduce_axis1_raises(float_frame):
13471339
float_frame.agg(row1=(name1, "sum"), row2=(name2, "max"), axis=axis)
13481340

13491341

1350-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
13511342
def test_nuiscance_columns():
13521343
# GH 15015
13531344
df = DataFrame(
@@ -1524,7 +1515,6 @@ def test_apply_datetime_tz_issue(engine, request):
15241515
tm.assert_series_equal(result, expected)
15251516

15261517

1527-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
15281518
@pytest.mark.parametrize("df", [DataFrame({"A": ["a", None], "B": ["c", "d"]})])
15291519
@pytest.mark.parametrize("method", ["min", "max", "sum"])
15301520
def test_mixed_column_raises(df, method, using_infer_string):

pandas/tests/apply/test_invalid_arg.py

+20-19
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212
import numpy as np
1313
import pytest
1414

15-
from pandas._config import using_string_dtype
16-
17-
from pandas.compat import HAS_PYARROW
1815
from pandas.errors import SpecificationError
1916

2017
from pandas import (
@@ -212,10 +209,6 @@ def transform(row):
212209
data.apply(transform, axis=1)
213210

214211

215-
# we should raise a proper TypeError instead of propagating the pyarrow error
216-
@pytest.mark.xfail(
217-
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
218-
)
219212
@pytest.mark.parametrize(
220213
"df, func, expected",
221214
tm.get_cython_table_params(
@@ -225,21 +218,25 @@ def transform(row):
225218
def test_agg_cython_table_raises_frame(df, func, expected, axis, using_infer_string):
226219
# GH 21224
227220
if using_infer_string:
228-
import pyarrow as pa
221+
if df.dtypes.iloc[0].storage == "pyarrow":
222+
import pyarrow as pa
229223

230-
expected = (expected, pa.lib.ArrowNotImplementedError)
224+
# TODO(infer_string)
225+
# should raise a proper TypeError instead of propagating the pyarrow error
231226

232-
msg = "can't multiply sequence by non-int of type 'str'|has no kernel"
227+
expected = (expected, pa.lib.ArrowNotImplementedError)
228+
else:
229+
expected = (expected, NotImplementedError)
230+
231+
msg = (
232+
"can't multiply sequence by non-int of type 'str'|has no kernel|cannot perform"
233+
)
233234
warn = None if isinstance(func, str) else FutureWarning
234235
with pytest.raises(expected, match=msg):
235236
with tm.assert_produces_warning(warn, match="using DataFrame.cumprod"):
236237
df.agg(func, axis=axis)
237238

238239

239-
# we should raise a proper TypeError instead of propagating the pyarrow error
240-
@pytest.mark.xfail(
241-
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
242-
)
243240
@pytest.mark.parametrize(
244241
"series, func, expected",
245242
chain(
@@ -263,11 +260,15 @@ def test_agg_cython_table_raises_series(series, func, expected, using_infer_stri
263260
msg = r"Cannot convert \['a' 'b' 'c'\] to numeric"
264261

265262
if using_infer_string:
266-
import pyarrow as pa
267-
268-
expected = (expected, pa.lib.ArrowNotImplementedError)
269-
270-
msg = msg + "|does not support|has no kernel"
263+
if series.dtype.storage == "pyarrow":
264+
import pyarrow as pa
265+
266+
# TODO(infer_string)
267+
# should raise a proper TypeError instead of propagating the pyarrow error
268+
expected = (expected, pa.lib.ArrowNotImplementedError)
269+
else:
270+
expected = (expected, NotImplementedError)
271+
msg = msg + "|does not support|has no kernel|Cannot perform|cannot perform"
271272
warn = None if isinstance(func, str) else FutureWarning
272273

273274
with pytest.raises(expected, match=msg):

pandas/tests/arrays/string_/test_string.py

-2
Original file line numberDiff line numberDiff line change
@@ -444,14 +444,12 @@ def test_astype_float(dtype, any_float_dtype):
444444
tm.assert_series_equal(result, expected)
445445

446446

447-
@pytest.mark.xfail(reason="Not implemented StringArray.sum")
448447
def test_reduce(skipna, dtype):
449448
arr = pd.Series(["a", "b", "c"], dtype=dtype)
450449
result = arr.sum(skipna=skipna)
451450
assert result == "abc"
452451

453452

454-
@pytest.mark.xfail(reason="Not implemented StringArray.sum")
455453
def test_reduce_missing(skipna, dtype):
456454
arr = pd.Series([None, "a", None, "b", "c", None], dtype=dtype)
457455
result = arr.sum(skipna=skipna)

pandas/tests/extension/test_arrow.py

+4-21
Original file line numberDiff line numberDiff line change
@@ -461,10 +461,11 @@ def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
461461
pass
462462
else:
463463
return False
464+
elif pa.types.is_binary(pa_dtype) and op_name == "sum":
465+
return False
464466
elif (
465467
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
466468
) and op_name in [
467-
"sum",
468469
"mean",
469470
"median",
470471
"prod",
@@ -563,6 +564,8 @@ def _get_expected_reduction_dtype(self, arr, op_name: str, skipna: bool):
563564
cmp_dtype = "float64[pyarrow]"
564565
elif op_name in ["sum", "prod"] and pa.types.is_boolean(pa_type):
565566
cmp_dtype = "uint64[pyarrow]"
567+
elif op_name == "sum" and pa.types.is_string(pa_type):
568+
cmp_dtype = arr.dtype
566569
else:
567570
cmp_dtype = {
568571
"i": "int64[pyarrow]",
@@ -594,26 +597,6 @@ def test_median_not_approximate(self, typ):
594597
result = pd.Series([1, 2], dtype=f"{typ}[pyarrow]").median()
595598
assert result == 1.5
596599

597-
def test_in_numeric_groupby(self, data_for_grouping):
598-
dtype = data_for_grouping.dtype
599-
if is_string_dtype(dtype):
600-
df = pd.DataFrame(
601-
{
602-
"A": [1, 1, 2, 2, 3, 3, 1, 4],
603-
"B": data_for_grouping,
604-
"C": [1, 1, 1, 1, 1, 1, 1, 1],
605-
}
606-
)
607-
608-
expected = pd.Index(["C"])
609-
msg = re.escape(f"agg function failed [how->sum,dtype->{dtype}")
610-
with pytest.raises(TypeError, match=msg):
611-
df.groupby("A").sum()
612-
result = df.groupby("A").sum(numeric_only=True).columns
613-
tm.assert_index_equal(result, expected)
614-
else:
615-
super().test_in_numeric_groupby(data_for_grouping)
616-
617600
def test_construct_from_string_own_name(self, dtype, request):
618601
pa_dtype = dtype.pyarrow_dtype
619602
if pa.types.is_decimal(pa_dtype):

pandas/tests/extension/test_string.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def _get_expected_exception(
188188

189189
def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
190190
return (
191-
op_name in ["min", "max"]
191+
op_name in ["min", "max", "sum"]
192192
or ser.dtype.na_value is np.nan # type: ignore[union-attr]
193193
and op_name in ("any", "all")
194194
)

0 commit comments

Comments (0)