diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index c2b2454c1b858..028b88cb0b1d8 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -1,12 +1,12 @@ from __future__ import annotations import codecs +from collections.abc import Callable # noqa: PDF001 from functools import wraps import re from typing import ( TYPE_CHECKING, Hashable, - Pattern, ) import warnings @@ -1217,7 +1217,15 @@ def fullmatch(self, pat, case=True, flags=0, na=None): return self._wrap_result(result, fill_value=na, returns_string=False) @forbid_nonstring_types(["bytes"]) - def replace(self, pat, repl, n=-1, case=None, flags=0, regex=None): + def replace( + self, + pat: str | re.Pattern, + repl: str | Callable, + n: int = -1, + case: bool | None = None, + flags: int = 0, + regex: bool | None = None, + ): r""" Replace each occurrence of pattern/regex in the Series/Index. @@ -1358,14 +1366,10 @@ def replace(self, pat, repl, n=-1, case=None, flags=0, regex=None): is_compiled_re = is_re(pat) if regex: - if is_compiled_re: - if (case is not None) or (flags != 0): - raise ValueError( - "case and flags cannot be set when pat is a compiled regex" - ) - elif case is None: - # not a compiled regex, set default case - case = True + if is_compiled_re and (case is not None or flags != 0): + raise ValueError( + "case and flags cannot be set when pat is a compiled regex" + ) elif is_compiled_re: raise ValueError( @@ -1374,6 +1378,9 @@ def replace(self, pat, repl, n=-1, case=None, flags=0, regex=None): elif callable(repl): raise ValueError("Cannot use a callable replacement when regex=False") + if case is None: + case = True + result = self._data.array._str_replace( pat, repl, n=n, case=case, flags=flags, regex=regex ) @@ -3044,14 +3051,14 @@ def _result_dtype(arr): return object -def _get_single_group_name(regex: Pattern) -> Hashable: +def _get_single_group_name(regex: re.Pattern) -> Hashable: if regex.groupindex: return next(iter(regex.groupindex)) else: return None -def _get_group_names(regex: Pattern) -> list[Hashable]: +def _get_group_names(regex: re.Pattern) -> list[Hashable]: """ Get named groups from compiled regex. diff --git a/pandas/core/strings/base.py b/pandas/core/strings/base.py index add156efc0263..730870b448cb2 100644 --- a/pandas/core/strings/base.py +++ b/pandas/core/strings/base.py @@ -1,7 +1,8 @@ from __future__ import annotations import abc -from typing import Pattern +from collections.abc import Callable # noqa: PDF001 +import re import numpy as np @@ -51,7 +52,15 @@ def _str_endswith(self, pat, na=None): pass @abc.abstractmethod - def _str_replace(self, pat, repl, n=-1, case=None, flags=0, regex=True): + def _str_replace( + self, + pat: str | re.Pattern, + repl: str | Callable, + n: int = -1, + case: bool = True, + flags: int = 0, + regex: bool = True, + ): pass @abc.abstractmethod @@ -67,7 +76,7 @@ def _str_match( @abc.abstractmethod def _str_fullmatch( self, - pat: str | Pattern, + pat: str | re.Pattern, case: bool = True, flags: int = 0, na: Scalar = np.nan, diff --git a/pandas/core/strings/object_array.py b/pandas/core/strings/object_array.py index 401e0217d5adf..fb9fd77d21732 100644 --- a/pandas/core/strings/object_array.py +++ b/pandas/core/strings/object_array.py @@ -1,8 +1,8 @@ from __future__ import annotations +from collections.abc import Callable # noqa: PDF001 import re import textwrap -from typing import Pattern import unicodedata import numpy as np @@ -15,10 +15,7 @@ Scalar, ) -from pandas.core.dtypes.common import ( - is_re, - is_scalar, -) +from pandas.core.dtypes.common import is_scalar from pandas.core.dtypes.missing import isna from pandas.core.strings.base import BaseStringArrayMethods @@ -135,15 +132,23 @@ def _str_endswith(self, pat, na=None): f = lambda x: x.endswith(pat) return self._str_map(f, na_value=na, dtype=np.dtype(bool)) - def _str_replace(self, pat, repl, n=-1, case: bool = True, flags=0, regex=True): - is_compiled_re = is_re(pat) - + def _str_replace( + self, + pat: str | re.Pattern, + repl: str | Callable, + n: int = -1, + case: bool = True, + flags: int = 0, + regex: bool = True, + ): if case is False: # add case flag, if provided flags |= re.IGNORECASE - if regex and (is_compiled_re or len(pat) > 1 or flags or callable(repl)): - if not is_compiled_re: + if regex and ( + isinstance(pat, re.Pattern) or len(pat) > 1 or flags or callable(repl) + ): + if not isinstance(pat, re.Pattern): pat = re.compile(pat, flags=flags) n = n if n >= 0 else 0 @@ -195,7 +200,7 @@ def _str_match( def _str_fullmatch( self, - pat: str | Pattern, + pat: str | re.Pattern, case: bool = True, flags: int = 0, na: Scalar = None,