Skip to content

Commit 70d4697

Browse files
[ArrowStringArray] TYP: add annotations to str.replace (#41603)
1 parent 04d0e48 commit 70d4697

File tree

3 files changed: 47 additions (+47) and 26 deletions (-26).

pandas/core/strings/accessor.py

+19-12
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from __future__ import annotations
22

33
import codecs
4+
from collections.abc import Callable # noqa: PDF001
45
from functools import wraps
56
import re
67
from typing import (
78
TYPE_CHECKING,
89
Hashable,
9-
Pattern,
1010
)
1111
import warnings
1212

@@ -1217,7 +1217,15 @@ def fullmatch(self, pat, case=True, flags=0, na=None):
12171217
return self._wrap_result(result, fill_value=na, returns_string=False)
12181218

12191219
@forbid_nonstring_types(["bytes"])
1220-
def replace(self, pat, repl, n=-1, case=None, flags=0, regex=None):
1220+
def replace(
1221+
self,
1222+
pat: str | re.Pattern,
1223+
repl: str | Callable,
1224+
n: int = -1,
1225+
case: bool | None = None,
1226+
flags: int = 0,
1227+
regex: bool | None = None,
1228+
):
12211229
r"""
12221230
Replace each occurrence of pattern/regex in the Series/Index.
12231231
@@ -1358,14 +1366,10 @@ def replace(self, pat, repl, n=-1, case=None, flags=0, regex=None):
13581366

13591367
is_compiled_re = is_re(pat)
13601368
if regex:
1361-
if is_compiled_re:
1362-
if (case is not None) or (flags != 0):
1363-
raise ValueError(
1364-
"case and flags cannot be set when pat is a compiled regex"
1365-
)
1366-
elif case is None:
1367-
# not a compiled regex, set default case
1368-
case = True
1369+
if is_compiled_re and (case is not None or flags != 0):
1370+
raise ValueError(
1371+
"case and flags cannot be set when pat is a compiled regex"
1372+
)
13691373

13701374
elif is_compiled_re:
13711375
raise ValueError(
@@ -1374,6 +1378,9 @@ def replace(self, pat, repl, n=-1, case=None, flags=0, regex=None):
13741378
elif callable(repl):
13751379
raise ValueError("Cannot use a callable replacement when regex=False")
13761380

1381+
if case is None:
1382+
case = True
1383+
13771384
result = self._data.array._str_replace(
13781385
pat, repl, n=n, case=case, flags=flags, regex=regex
13791386
)
@@ -3044,14 +3051,14 @@ def _result_dtype(arr):
30443051
return object
30453052

30463053

3047-
def _get_single_group_name(regex: Pattern) -> Hashable:
3054+
def _get_single_group_name(regex: re.Pattern) -> Hashable:
30483055
if regex.groupindex:
30493056
return next(iter(regex.groupindex))
30503057
else:
30513058
return None
30523059

30533060

3054-
def _get_group_names(regex: Pattern) -> list[Hashable]:
3061+
def _get_group_names(regex: re.Pattern) -> list[Hashable]:
30553062
"""
30563063
Get named groups from compiled regex.
30573064

pandas/core/strings/base.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

33
import abc
4-
from typing import Pattern
4+
from collections.abc import Callable # noqa: PDF001
5+
import re
56

67
import numpy as np
78

@@ -51,7 +52,15 @@ def _str_endswith(self, pat, na=None):
5152
pass
5253

5354
@abc.abstractmethod
54-
def _str_replace(self, pat, repl, n=-1, case=None, flags=0, regex=True):
55+
def _str_replace(
56+
self,
57+
pat: str | re.Pattern,
58+
repl: str | Callable,
59+
n: int = -1,
60+
case: bool = True,
61+
flags: int = 0,
62+
regex: bool = True,
63+
):
5564
pass
5665

5766
@abc.abstractmethod
@@ -67,7 +76,7 @@ def _str_match(
6776
@abc.abstractmethod
6877
def _str_fullmatch(
6978
self,
70-
pat: str | Pattern,
79+
pat: str | re.Pattern,
7180
case: bool = True,
7281
flags: int = 0,
7382
na: Scalar = np.nan,

pandas/core/strings/object_array.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

3+
from collections.abc import Callable # noqa: PDF001
34
import re
45
import textwrap
5-
from typing import Pattern
66
import unicodedata
77

88
import numpy as np
@@ -15,10 +15,7 @@
1515
Scalar,
1616
)
1717

18-
from pandas.core.dtypes.common import (
19-
is_re,
20-
is_scalar,
21-
)
18+
from pandas.core.dtypes.common import is_scalar
2219
from pandas.core.dtypes.missing import isna
2320

2421
from pandas.core.strings.base import BaseStringArrayMethods
@@ -135,15 +132,23 @@ def _str_endswith(self, pat, na=None):
135132
f = lambda x: x.endswith(pat)
136133
return self._str_map(f, na_value=na, dtype=np.dtype(bool))
137134

138-
def _str_replace(self, pat, repl, n=-1, case: bool = True, flags=0, regex=True):
139-
is_compiled_re = is_re(pat)
140-
135+
def _str_replace(
136+
self,
137+
pat: str | re.Pattern,
138+
repl: str | Callable,
139+
n: int = -1,
140+
case: bool = True,
141+
flags: int = 0,
142+
regex: bool = True,
143+
):
141144
if case is False:
142145
# add case flag, if provided
143146
flags |= re.IGNORECASE
144147

145-
if regex and (is_compiled_re or len(pat) > 1 or flags or callable(repl)):
146-
if not is_compiled_re:
148+
if regex and (
149+
isinstance(pat, re.Pattern) or len(pat) > 1 or flags or callable(repl)
150+
):
151+
if not isinstance(pat, re.Pattern):
147152
pat = re.compile(pat, flags=flags)
148153

149154
n = n if n >= 0 else 0
@@ -195,7 +200,7 @@ def _str_match(
195200

196201
def _str_fullmatch(
197202
self,
198-
pat: str | Pattern,
203+
pat: str | re.Pattern,
199204
case: bool = True,
200205
flags: int = 0,
201206
na: Scalar = None,

0 commit comments

Comments
 (0)