Skip to content

Commit 9ad2fbb

Browse files
jbrockmendel authored and JulianWgs committed
REF: more explicit dtypes in strings.accessor (pandas-dev#41727)
1 parent 3da2423 commit 9ad2fbb

File tree

1 file changed

+24
-10
lines changed

1 file changed

+24
-10
lines changed

pandas/core/strings/accessor.py

+24 -10
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,10 @@
1313
import numpy as np
1414

1515
import pandas._libs.lib as lib
16-
from pandas._typing import FrameOrSeriesUnion
16+
from pandas._typing import (
17+
DtypeObj,
18+
FrameOrSeriesUnion,
19+
)
1720
from pandas.util._decorators import Appender
1821

1922
from pandas.core.dtypes.common import (
@@ -209,8 +212,12 @@ def _validate(data):
209212
# see _libs/lib.pyx for list of inferred types
210213
allowed_types = ["string", "empty", "bytes", "mixed", "mixed-integer"]
211214

212-
values = getattr(data, "values", data) # Series / Index
213-
values = getattr(values, "categories", values) # categorical / normal
215+
# TODO: avoid kludge for tests.extension.test_numpy
216+
from pandas.core.internals.managers import _extract_array
217+
218+
data = _extract_array(data)
219+
220+
values = getattr(data, "categories", data) # categorical / normal
214221

215222
inferred_dtype = lib.infer_dtype(values, skipna=True)
216223

@@ -242,6 +249,7 @@ def _wrap_result(
242249
expand: bool | None = None,
243250
fill_value=np.nan,
244251
returns_string=True,
252+
returns_bool: bool = False,
245253
):
246254
from pandas import (
247255
Index,
@@ -319,19 +327,25 @@ def cons_row(x):
319327
else:
320328
index = self._orig.index
321329
# This is a mess.
322-
dtype: str | None
323-
if self._is_string and returns_string:
324-
dtype = self._orig.dtype
330+
dtype: DtypeObj | str | None
331+
vdtype = getattr(result, "dtype", None)
332+
if self._is_string:
333+
if is_bool_dtype(vdtype):
334+
dtype = result.dtype
335+
elif returns_string:
336+
dtype = self._orig.dtype
337+
else:
338+
dtype = vdtype
325339
else:
326-
dtype = None
340+
dtype = vdtype
327341

328342
if expand:
329343
cons = self._orig._constructor_expanddim
330344
result = cons(result, columns=name, index=index, dtype=dtype)
331345
else:
332346
# Must be a Series
333347
cons = self._orig._constructor
334-
result = cons(result, name=name, index=index)
348+
result = cons(result, name=name, index=index, dtype=dtype)
335349
result = result.__finalize__(self._orig, method="str")
336350
if name is not None and result.ndim == 1:
337351
# __finalize__ might copy over the original name, but we may
@@ -369,7 +383,7 @@ def _get_series_list(self, others):
369383
if isinstance(others, ABCSeries):
370384
return [others]
371385
elif isinstance(others, ABCIndex):
372-
return [Series(others._values, index=idx)]
386+
return [Series(others._values, index=idx, dtype=others.dtype)]
373387
elif isinstance(others, ABCDataFrame):
374388
return [others[x] for x in others]
375389
elif isinstance(others, np.ndarray) and others.ndim == 2:
@@ -547,7 +561,7 @@ def cat(self, others=None, sep=None, na_rep=None, join="left"):
547561
sep = ""
548562

549563
if isinstance(self._orig, ABCIndex):
550-
data = Series(self._orig, index=self._orig)
564+
data = Series(self._orig, index=self._orig, dtype=self._orig.dtype)
551565
else: # Series
552566
data = self._orig
553567

0 commit comments

Comments
 (0)