Skip to content

Commit 0d853e7

Browse files
authored
BUG: Series.__mul__ for pyarrow strings (#56368)
* BUG: Series.__mul__ for pyarrow strings * Fix existing tests * Another test
1 parent d36fb98 commit 0d853e7

File tree

4 files changed

+43
-17
lines changed

4 files changed

+43
-17
lines changed

doc/source/whatsnew/v2.2.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,7 @@ Strings
576576
^^^^^^^
577577
- Bug in :func:`pandas.api.types.is_string_dtype` when checking whether an object array with no elements is of the string dtype (:issue:`54661`)
578578
- Bug in :meth:`DataFrame.apply` failing when ``engine="numba"`` and columns or index have ``StringDtype`` (:issue:`56189`)
579+
- Bug in :meth:`Series.__mul__` for :class:`ArrowDtype` with ``pyarrow.string`` dtype and for ``string[pyarrow]`` dtype (:issue:`51970`)
579580
- Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with arguments of type ``tuple[str, ...]`` for ``string[pyarrow]`` (:issue:`54942`)
580581

581582
Interval

pandas/core/arrays/arrow/array.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -668,16 +668,22 @@ def _evaluate_op_method(self, other, op, arrow_funcs):
668668
pa_type = self._pa_array.type
669669
other = self._box_pa(other)
670670

671-
if (pa.types.is_string(pa_type) or pa.types.is_binary(pa_type)) and op in [
672-
operator.add,
673-
roperator.radd,
674-
]:
675-
sep = pa.scalar("", type=pa_type)
676-
if op is operator.add:
677-
result = pc.binary_join_element_wise(self._pa_array, other, sep)
678-
else:
679-
result = pc.binary_join_element_wise(other, self._pa_array, sep)
680-
return type(self)(result)
671+
if pa.types.is_string(pa_type) or pa.types.is_binary(pa_type):
672+
if op in [operator.add, roperator.radd, operator.mul, roperator.rmul]:
673+
sep = pa.scalar("", type=pa_type)
674+
if op is operator.add:
675+
result = pc.binary_join_element_wise(self._pa_array, other, sep)
676+
elif op is roperator.radd:
677+
result = pc.binary_join_element_wise(other, self._pa_array, sep)
678+
else:
679+
if not (
680+
isinstance(other, pa.Scalar) and pa.types.is_integer(other.type)
681+
):
682+
raise TypeError("Can only string multiply by an integer.")
683+
result = pc.binary_join_element_wise(
684+
*([self._pa_array] * other.as_py()), sep
685+
)
686+
return type(self)(result)
681687

682688
if (
683689
isinstance(other, pa.Scalar)

pandas/tests/arrays/string_/test_string.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -176,12 +176,7 @@ def test_add_sequence(dtype):
176176
tm.assert_extension_array_equal(result, expected)
177177

178178

179-
def test_mul(dtype, request, arrow_string_storage):
180-
if dtype.storage in arrow_string_storage:
181-
reason = "unsupported operand type(s) for *: 'ArrowStringArray' and 'int'"
182-
mark = pytest.mark.xfail(raises=NotImplementedError, reason=reason)
183-
request.applymarker(mark)
184-
179+
def test_mul(dtype):
185180
a = pd.array(["a", "b", None], dtype=dtype)
186181
result = a * 2
187182
expected = pd.array(["aa", "bb", None], dtype=dtype)

pandas/tests/extension/test_arrow.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -965,8 +965,16 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
965965
def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request):
966966
pa_dtype = data.dtype.pyarrow_dtype
967967

968-
if all_arithmetic_operators == "__rmod__" and (pa.types.is_binary(pa_dtype)):
968+
if all_arithmetic_operators == "__rmod__" and pa.types.is_binary(pa_dtype):
969969
pytest.skip("Skip testing Python string formatting")
970+
elif all_arithmetic_operators in ("__rmul__", "__mul__") and (
971+
pa.types.is_binary(pa_dtype) or pa.types.is_string(pa_dtype)
972+
):
973+
request.applymarker(
974+
pytest.mark.xfail(
975+
raises=TypeError, reason="Can only string multiply by an integer."
976+
)
977+
)
970978

971979
mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype)
972980
if mark is not None:
@@ -981,6 +989,14 @@ def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
981989
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
982990
):
983991
pytest.skip("Skip testing Python string formatting")
992+
elif all_arithmetic_operators in ("__rmul__", "__mul__") and (
993+
pa.types.is_binary(pa_dtype) or pa.types.is_string(pa_dtype)
994+
):
995+
request.applymarker(
996+
pytest.mark.xfail(
997+
raises=TypeError, reason="Can only string multiply by an integer."
998+
)
999+
)
9841000

9851001
mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype)
9861002
if mark is not None:
@@ -1004,6 +1020,14 @@ def test_arith_series_with_array(self, data, all_arithmetic_operators, request):
10041020
),
10051021
)
10061022
)
1023+
elif all_arithmetic_operators in ("__rmul__", "__mul__") and (
1024+
pa.types.is_binary(pa_dtype) or pa.types.is_string(pa_dtype)
1025+
):
1026+
request.applymarker(
1027+
pytest.mark.xfail(
1028+
raises=TypeError, reason="Can only string multiply by an integer."
1029+
)
1030+
)
10071031

10081032
mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype)
10091033
if mark is not None:

0 commit comments

Comments
 (0)