Skip to content

REF: more explicit dtypes in strings.accessor #41727

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 9, 2021
34 changes: 24 additions & 10 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
import numpy as np

import pandas._libs.lib as lib
from pandas._typing import FrameOrSeriesUnion
from pandas._typing import (
DtypeObj,
FrameOrSeriesUnion,
)
from pandas.util._decorators import Appender

from pandas.core.dtypes.common import (
Expand Down Expand Up @@ -209,8 +212,12 @@ def _validate(data):
# see _libs/lib.pyx for list of inferred types
allowed_types = ["string", "empty", "bytes", "mixed", "mixed-integer"]

values = getattr(data, "values", data) # Series / Index
values = getattr(values, "categories", values) # categorical / normal
# TODO: avoid kludge for tests.extension.test_numpy
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

umm cn you avoid this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can revert to use the getattr pattern, either way is an anti-pattern

from pandas.core.internals.managers import _extract_array

data = _extract_array(data)

values = getattr(data, "categories", data) # categorical / normal

inferred_dtype = lib.infer_dtype(values, skipna=True)

Expand Down Expand Up @@ -242,6 +249,7 @@ def _wrap_result(
expand: bool | None = None,
fill_value=np.nan,
returns_string=True,
returns_bool: bool = False,
):
from pandas import (
Index,
Expand Down Expand Up @@ -319,19 +327,25 @@ def cons_row(x):
else:
index = self._orig.index
# This is a mess.
dtype: str | None
if self._is_string and returns_string:
dtype = self._orig.dtype
dtype: DtypeObj | str | None
vdtype = getattr(result, "dtype", None)
if self._is_string:
if is_bool_dtype(vdtype):
dtype = result.dtype
elif returns_string:
dtype = self._orig.dtype
else:
dtype = vdtype
else:
dtype = None
dtype = vdtype

if expand:
cons = self._orig._constructor_expanddim
result = cons(result, columns=name, index=index, dtype=dtype)
else:
# Must be a Series
cons = self._orig._constructor
result = cons(result, name=name, index=index)
result = cons(result, name=name, index=index, dtype=dtype)
result = result.__finalize__(self._orig, method="str")
if name is not None and result.ndim == 1:
# __finalize__ might copy over the original name, but we may
Expand Down Expand Up @@ -369,7 +383,7 @@ def _get_series_list(self, others):
if isinstance(others, ABCSeries):
return [others]
elif isinstance(others, ABCIndex):
return [Series(others._values, index=idx)]
return [Series(others._values, index=idx, dtype=others.dtype)]
elif isinstance(others, ABCDataFrame):
return [others[x] for x in others]
elif isinstance(others, np.ndarray) and others.ndim == 2:
Expand Down Expand Up @@ -547,7 +561,7 @@ def cat(self, others=None, sep=None, na_rep=None, join="left"):
sep = ""

if isinstance(self._orig, ABCIndex):
data = Series(self._orig, index=self._orig)
data = Series(self._orig, index=self._orig, dtype=self._orig.dtype)
else: # Series
data = self._orig

Expand Down