diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py
index 180ed51e7fd2b..5d505691b13d2 100644
--- a/pandas/core/arrays/string_arrow.py
+++ b/pandas/core/arrays/string_arrow.py
@@ -831,3 +831,36 @@ def _str_lower(self):
 
     def _str_upper(self):
         return type(self)(pc.utf8_upper(self._data))
+
+    def _str_strip(self, to_strip=None):
+        # Trim both ends using the ``pc`` kernel when it is available
+        # (whitespace when ``to_strip`` is None, otherwise the given characters);
+        # fall back to the parent class implementation when it is not.
+        if to_strip is None:
+            if hasattr(pc, "utf8_trim_whitespace"):
+                return type(self)(pc.utf8_trim_whitespace(self._data))
+        elif hasattr(pc, "utf8_trim"):
+            return type(self)(pc.utf8_trim(self._data, characters=to_strip))
+        return super()._str_strip(to_strip)
+
+    def _str_lstrip(self, to_strip=None):
+        # Trim the left end using the ``pc`` kernel when it is available
+        # (whitespace when ``to_strip`` is None, otherwise the given characters);
+        # fall back to the parent class implementation when it is not.
+        if to_strip is None:
+            if hasattr(pc, "utf8_ltrim_whitespace"):
+                return type(self)(pc.utf8_ltrim_whitespace(self._data))
+        elif hasattr(pc, "utf8_ltrim"):
+            return type(self)(pc.utf8_ltrim(self._data, characters=to_strip))
+        return super()._str_lstrip(to_strip)
+
+    def _str_rstrip(self, to_strip=None):
+        # Trim the right end using the ``pc`` kernel when it is available
+        # (whitespace when ``to_strip`` is None, otherwise the given characters);
+        # fall back to the parent class implementation when it is not.
+        if to_strip is None:
+            if hasattr(pc, "utf8_rtrim_whitespace"):
+                return type(self)(pc.utf8_rtrim_whitespace(self._data))
+        elif hasattr(pc, "utf8_rtrim"):
+            return type(self)(pc.utf8_rtrim(self._data, characters=to_strip))
+        return super()._str_rstrip(to_strip)
diff --git a/pandas/tests/strings/test_strings.py b/pandas/tests/strings/test_strings.py
index 2a52b3ba3f9e1..f218d5333b415 100644
--- a/pandas/tests/strings/test_strings.py
+++ b/pandas/tests/strings/test_strings.py
@@ -570,19 +570,19 @@ def test_slice_replace():
     tm.assert_series_equal(result, exp)
 
 
-def test_strip_lstrip_rstrip():
-    values = Series([" aa ", " bb \n", np.nan, "cc "])
+def test_strip_lstrip_rstrip(any_string_dtype):
+    values = Series([" aa ", " bb \n", np.nan, "cc "], dtype=any_string_dtype)
 
     result = values.str.strip()
-    exp = Series(["aa", "bb", np.nan, "cc"])
+    exp = Series(["aa", "bb", np.nan, "cc"], dtype=any_string_dtype)
     tm.assert_series_equal(result, exp)
 
     result = values.str.lstrip()
-    exp = Series(["aa ", "bb \n", np.nan, "cc "])
+    exp = Series(["aa ", "bb \n", np.nan, "cc "], dtype=any_string_dtype)
     tm.assert_series_equal(result, exp)
 
     result = values.str.rstrip()
-    exp = Series([" aa", " bb", np.nan, "cc"])
+    exp = Series([" aa", " bb", np.nan, "cc"], dtype=any_string_dtype)
     tm.assert_series_equal(result, exp)
 
 
@@ -609,19 +609,19 @@ def test_strip_lstrip_rstrip_mixed():
     tm.assert_almost_equal(rs, xp)
 
 
-def test_strip_lstrip_rstrip_args():
-    values = Series(["xxABCxx", "xx BNSD", "LDFJH xx"])
+def test_strip_lstrip_rstrip_args(any_string_dtype):
+    values = Series(["xxABCxx", "xx BNSD", "LDFJH xx"], dtype=any_string_dtype)
 
     rs = values.str.strip("x")
-    xp = Series(["ABC", " BNSD", "LDFJH "])
+    xp = Series(["ABC", " BNSD", "LDFJH "], dtype=any_string_dtype)
     tm.assert_series_equal(rs, xp)
 
     rs = values.str.lstrip("x")
-    xp = Series(["ABCxx", " BNSD", "LDFJH xx"])
+    xp = Series(["ABCxx", " BNSD", "LDFJH xx"], dtype=any_string_dtype)
     tm.assert_series_equal(rs, xp)
 
     rs = values.str.rstrip("x")
-    xp = Series(["xxABC", "xx BNSD", "LDFJH "])
+    xp = Series(["xxABC", "xx BNSD", "LDFJH "], dtype=any_string_dtype)
     tm.assert_series_equal(rs, xp)
 
 