Skip to content

Commit a712c50

Browse files
authored
Fix series.str.startswith(tuple) (#48587)
* accept both str and tuple[str, ...] in series.str.(starts|ends)with also add type hints and update doc strings to note pat accepts tuple * parametrize test_startswith() and test_endswith() to include pat as tuple * change na type hint to Scalar | None + add tuple usage examples
1 parent b5632fb commit a712c50

File tree

3 files changed

+40
-17
lines changed

3 files changed

+40
-17
lines changed

pandas/core/strings/accessor.py

+31-10
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pandas._typing import (
2020
DtypeObj,
2121
F,
22+
Scalar,
2223
)
2324
from pandas.util._decorators import (
2425
Appender,
@@ -2288,16 +2289,19 @@ def count(self, pat, flags=0):
22882289
return self._wrap_result(result, returns_string=False)
22892290

22902291
@forbid_nonstring_types(["bytes"])
2291-
def startswith(self, pat, na=None):
2292+
def startswith(
2293+
self, pat: str | tuple[str, ...], na: Scalar | None = None
2294+
) -> Series | Index:
22922295
"""
22932296
Test if the start of each string element matches a pattern.
22942297
22952298
Equivalent to :meth:`str.startswith`.
22962299
22972300
Parameters
22982301
----------
2299-
pat : str
2300-
Character sequence. Regular expressions are not accepted.
2302+
pat : str or tuple[str, ...]
2303+
Character sequence or tuple of strings. Regular expressions are not
2304+
accepted.
23012305
na : object, default NaN
23022306
Object shown if element tested is not a string. The default depends
23032307
on dtype of the array. For object-dtype, ``numpy.nan`` is used.
@@ -2332,6 +2336,13 @@ def startswith(self, pat, na=None):
23322336
3 NaN
23332337
dtype: object
23342338
2339+
>>> s.str.startswith(('b', 'B'))
2340+
0 True
2341+
1 True
2342+
2 False
2343+
3 NaN
2344+
dtype: object
2345+
23352346
Specifying `na` to be `False` instead of `NaN`.
23362347
23372348
>>> s.str.startswith('b', na=False)
@@ -2341,23 +2352,26 @@ def startswith(self, pat, na=None):
23412352
3 False
23422353
dtype: bool
23432354
"""
2344-
if not isinstance(pat, str):
2345-
msg = f"expected a string object, not {type(pat).__name__}"
2355+
if not isinstance(pat, (str, tuple)):
2356+
msg = f"expected a string or tuple, not {type(pat).__name__}"
23462357
raise TypeError(msg)
23472358
result = self._data.array._str_startswith(pat, na=na)
23482359
return self._wrap_result(result, returns_string=False)
23492360

23502361
@forbid_nonstring_types(["bytes"])
2351-
def endswith(self, pat, na=None):
2362+
def endswith(
2363+
self, pat: str | tuple[str, ...], na: Scalar | None = None
2364+
) -> Series | Index:
23522365
"""
23532366
Test if the end of each string element matches a pattern.
23542367
23552368
Equivalent to :meth:`str.endswith`.
23562369
23572370
Parameters
23582371
----------
2359-
pat : str
2360-
Character sequence. Regular expressions are not accepted.
2372+
pat : str or tuple[str, ...]
2373+
Character sequence or tuple of strings. Regular expressions are not
2374+
accepted.
23612375
na : object, default NaN
23622376
Object shown if element tested is not a string. The default depends
23632377
on dtype of the array. For object-dtype, ``numpy.nan`` is used.
@@ -2392,6 +2406,13 @@ def endswith(self, pat, na=None):
23922406
3 NaN
23932407
dtype: object
23942408
2409+
>>> s.str.endswith(('t', 'T'))
2410+
0 True
2411+
1 False
2412+
2 True
2413+
3 NaN
2414+
dtype: object
2415+
23952416
Specifying `na` to be `False` instead of `NaN`.
23962417
23972418
>>> s.str.endswith('t', na=False)
@@ -2401,8 +2422,8 @@ def endswith(self, pat, na=None):
24012422
3 False
24022423
dtype: bool
24032424
"""
2404-
if not isinstance(pat, str):
2405-
msg = f"expected a string object, not {type(pat).__name__}"
2425+
if not isinstance(pat, (str, tuple)):
2426+
msg = f"expected a string or tuple, not {type(pat).__name__}"
24062427
raise TypeError(msg)
24072428
result = self._data.array._str_endswith(pat, na=na)
24082429
return self._wrap_result(result, returns_string=False)

pandas/tests/strings/test_find_replace.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -291,21 +291,22 @@ def test_contains_nan(any_string_dtype):
291291
# --------------------------------------------------------------------------------------
292292

293293

294+
@pytest.mark.parametrize("pat", ["foo", ("foo", "baz")])
294295
@pytest.mark.parametrize("dtype", [None, "category"])
295296
@pytest.mark.parametrize("null_value", [None, np.nan, pd.NA])
296297
@pytest.mark.parametrize("na", [True, False])
297-
def test_startswith(dtype, null_value, na):
298+
def test_startswith(pat, dtype, null_value, na):
298299
# add category dtype parametrizations for GH-36241
299300
values = Series(
300301
["om", null_value, "foo_nom", "nom", "bar_foo", null_value, "foo"],
301302
dtype=dtype,
302303
)
303304

304-
result = values.str.startswith("foo")
305+
result = values.str.startswith(pat)
305306
exp = Series([False, np.nan, True, False, False, np.nan, True])
306307
tm.assert_series_equal(result, exp)
307308

308-
result = values.str.startswith("foo", na=na)
309+
result = values.str.startswith(pat, na=na)
309310
exp = Series([False, na, True, False, False, na, True])
310311
tm.assert_series_equal(result, exp)
311312

@@ -351,21 +352,22 @@ def test_startswith_nullable_string_dtype(nullable_string_dtype, na):
351352
# --------------------------------------------------------------------------------------
352353

353354

355+
@pytest.mark.parametrize("pat", ["foo", ("foo", "baz")])
354356
@pytest.mark.parametrize("dtype", [None, "category"])
355357
@pytest.mark.parametrize("null_value", [None, np.nan, pd.NA])
356358
@pytest.mark.parametrize("na", [True, False])
357-
def test_endswith(dtype, null_value, na):
359+
def test_endswith(pat, dtype, null_value, na):
358360
# add category dtype parametrizations for GH-36241
359361
values = Series(
360362
["om", null_value, "foo_nom", "nom", "bar_foo", null_value, "foo"],
361363
dtype=dtype,
362364
)
363365

364-
result = values.str.endswith("foo")
366+
result = values.str.endswith(pat)
365367
exp = Series([False, np.nan, False, False, True, np.nan, True])
366368
tm.assert_series_equal(result, exp)
367369

368-
result = values.str.endswith("foo", na=na)
370+
result = values.str.endswith(pat, na=na)
369371
exp = Series([False, na, False, False, True, na, True])
370372
tm.assert_series_equal(result, exp)
371373

pandas/tests/strings/test_strings.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
def test_startswith_endswith_non_str_patterns(pattern):
2727
# GH3485
2828
ser = Series(["foo", "bar"])
29-
msg = f"expected a string object, not {type(pattern).__name__}"
29+
msg = f"expected a string or tuple, not {type(pattern).__name__}"
3030
with pytest.raises(TypeError, match=msg):
3131
ser.str.startswith(pattern)
3232
with pytest.raises(TypeError, match=msg):

0 commit comments

Comments
 (0)