Skip to content

Commit 2b3c1ec

Browse files
author
Rohan Jain
committed
allow repeat count to be a series
1 parent d77d5e5 commit 2b3c1ec

File tree

2 files changed

+35
-10
lines changed

2 files changed

+35
-10
lines changed

pandas/core/arrays/arrow/array.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -695,22 +695,21 @@ def _evaluate_op_method(self, other, op, arrow_funcs):
695695
other = self._box_pa(other)
696696

697697
if pa.types.is_string(pa_type) or pa.types.is_binary(pa_type):
698-
if op in [operator.add, roperator.radd, operator.mul, roperator.rmul]:
698+
if op in [operator.add, roperator.radd]:
699699
sep = pa.scalar("", type=pa_type)
700700
if op is operator.add:
701701
result = pc.binary_join_element_wise(self._pa_array, other, sep)
702702
elif op is roperator.radd:
703703
result = pc.binary_join_element_wise(other, self._pa_array, sep)
704-
else:
705-
if not (
706-
isinstance(other, pa.Scalar) and pa.types.is_integer(other.type)
707-
):
708-
raise TypeError("Can only string multiply by an integer.")
709-
result = pc.binary_join_element_wise(
710-
*([self._pa_array] * other.as_py()), sep
711-
)
712704
return type(self)(result)
713-
705+
elif op in [operator.mul, roperator.rmul]:
706+
result = type(self)._evaluate_binary_repeat(self._pa_array, other)
707+
return type(self)(result)
708+
elif pa.types.is_integer(pa_type) and (
709+
pa.types.is_string(other.type) or pa.types.is_binary(other.type)
710+
):
711+
result = type(self)._evaluate_binary_repeat(other, self._pa_array)
712+
return type(self)(result)
714713
if (
715714
isinstance(other, pa.Scalar)
716715
and pc.is_null(other).as_py()
@@ -726,6 +725,13 @@ def _evaluate_op_method(self, other, op, arrow_funcs):
726725
result = pc_func(self._pa_array, other)
727726
return type(self)(result)
728727

728+
@staticmethod
729+
def _evaluate_binary_repeat(binary, integral):
730+
if not pa.types.is_integer(integral.type):
731+
raise TypeError("Can only string multiply by an integer.")
732+
pa_integral = pc.if_else(pc.less(integral, 0), 0, integral)
733+
return pc.binary_repeat(binary, pa_integral)
734+
729735
def _logical_method(self, other, op):
730736
# For integer types `^`, `|`, `&` are bitwise operators and return
731737
# integer types. Otherwise these are boolean ops.

pandas/tests/extension/test_arrow.py

+19
Original file line numberDiff line numberDiff line change
@@ -1330,6 +1330,25 @@ def test_arrowdtype_construct_from_string_type_only_one_pyarrow():
13301330
pd.Series(range(3), dtype=invalid)
13311331

13321332

1333+
def test_arrow_string_multiplication():
1334+
binary = pd.Series(["abc", "defg"], dtype=ArrowDtype(pa.string()))
1335+
repeat = pd.Series([2, -2], dtype="int64[pyarrow]")
1336+
result = binary * repeat
1337+
expected = pd.Series(["abcabc", ""], dtype=ArrowDtype(pa.string()))
1338+
tm.assert_series_equal(result, expected)
1339+
reflected_result = repeat * binary
1340+
tm.assert_series_equal(result, reflected_result)
1341+
1342+
1343+
def test_arrow_string_multiplication_scalar_repeat():
1344+
binary = pd.Series(["abc", "defg"], dtype=ArrowDtype(pa.string()))
1345+
result = binary * 2
1346+
expected = pd.Series(["abcabc", "defgdefg"], dtype=ArrowDtype(pa.string()))
1347+
tm.assert_series_equal(result, expected)
1348+
reflected_result = 2 * binary
1349+
tm.assert_series_equal(reflected_result, expected)
1350+
1351+
13331352
@pytest.mark.parametrize(
13341353
"interpolation", ["linear", "lower", "higher", "nearest", "midpoint"]
13351354
)

0 commit comments

Comments
 (0)