diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 7643019ff8c55..29d37599b0785 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -13,7 +13,10 @@ import numpy as np import pandas._libs.lib as lib -from pandas._typing import FrameOrSeriesUnion +from pandas._typing import ( + DtypeObj, + FrameOrSeriesUnion, +) from pandas.util._decorators import Appender from pandas.core.dtypes.common import ( @@ -209,8 +212,12 @@ def _validate(data): # see _libs/lib.pyx for list of inferred types allowed_types = ["string", "empty", "bytes", "mixed", "mixed-integer"] - values = getattr(data, "values", data) # Series / Index - values = getattr(values, "categories", values) # categorical / normal + # TODO: avoid kludge for tests.extension.test_numpy + from pandas.core.internals.managers import _extract_array + + data = _extract_array(data) + + values = getattr(data, "categories", data) # categorical / normal inferred_dtype = lib.infer_dtype(values, skipna=True) @@ -242,6 +249,7 @@ def _wrap_result( expand: bool | None = None, fill_value=np.nan, returns_string=True, + returns_bool: bool = False, ): from pandas import ( Index, @@ -319,11 +327,17 @@ def cons_row(x): else: index = self._orig.index # This is a mess. - dtype: str | None - if self._is_string and returns_string: - dtype = self._orig.dtype + dtype: DtypeObj | str | None + vdtype = getattr(result, "dtype", None) + if self._is_string: + if is_bool_dtype(vdtype): + dtype = result.dtype + elif returns_string: + dtype = self._orig.dtype + else: + dtype = vdtype else: - dtype = None + dtype = vdtype if expand: cons = self._orig._constructor_expanddim @@ -331,7 +345,7 @@ def cons_row(x): else: # Must be a Series cons = self._orig._constructor - result = cons(result, name=name, index=index) + result = cons(result, name=name, index=index, dtype=dtype) result = result.__finalize__(self._orig, method="str") if name is not None and result.ndim == 1: # __finalize__ might copy over the original name, but we may @@ -369,7 +383,7 @@ def _get_series_list(self, others): if isinstance(others, ABCSeries): return [others] elif isinstance(others, ABCIndex): - return [Series(others._values, index=idx)] + return [Series(others._values, index=idx, dtype=others.dtype)] elif isinstance(others, ABCDataFrame): return [others[x] for x in others] elif isinstance(others, np.ndarray) and others.ndim == 2: @@ -547,7 +561,7 @@ def cat(self, others=None, sep=None, na_rep=None, join="left"): sep = "" if isinstance(self._orig, ABCIndex): - data = Series(self._orig, index=self._orig) + data = Series(self._orig, index=self._orig, dtype=self._orig.dtype) else: # Series data = self._orig