Skip to content

Commit f661a2f

Browse files
TomAugspurgerKevin D Smith
authored and
Kevin D Smith
committed
REF: Dispatch string methods to ExtensionArray (pandas-dev#36357)
1 parent 9666fff commit f661a2f

File tree

10 files changed

+3942
-3716
lines changed

10 files changed

+3942
-3716
lines changed

ci/code_checks.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ if [[ -z "$CHECK" || "$CHECK" == "doctests" ]]; then
335335
RET=$(($RET + $?)) ; echo $MSG "DONE"
336336

337337
MSG='Doctests strings.py' ; echo $MSG
338-
pytest -q --doctest-modules pandas/core/strings.py
338+
pytest -q --doctest-modules pandas/core/strings/
339339
RET=$(($RET + $?)) ; echo $MSG "DONE"
340340

341341
# Directories

pandas/core/arrays/categorical.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from pandas.core.missing import interpolate_2d
5252
from pandas.core.ops.common import unpack_zerodim_and_defer
5353
from pandas.core.sorting import nargsort
54+
from pandas.core.strings.object_array import ObjectStringArrayMixin
5455

5556
from pandas.io.formats import console
5657

@@ -176,7 +177,7 @@ def contains(cat, key, container):
176177
return any(loc_ in container for loc_ in loc)
177178

178179

179-
class Categorical(NDArrayBackedExtensionArray, PandasObject):
180+
class Categorical(NDArrayBackedExtensionArray, PandasObject, ObjectStringArrayMixin):
180181
"""
181182
Represent a categorical variable in classic R / S-plus fashion.
182183
@@ -2305,6 +2306,25 @@ def replace(self, to_replace, value, inplace: bool = False):
23052306
if not inplace:
23062307
return cat
23072308

2309+
# ------------------------------------------------------------------------
2310+
# String methods interface
2311+
def _str_map(self, f, na_value=np.nan, dtype=np.dtype(object)):
    """
    Apply the callable ``f`` to the string values of this Categorical.

    As an optimization, ``f`` is evaluated once per category rather than
    once per element; the per-element result is then rebuilt by taking
    from the mapped categories with ``self.codes``.

    Parameters
    ----------
    f : callable
        Function forwarded to ``PandasArray._str_map``.
    na_value : scalar, default np.nan
        Forwarded to ``PandasArray._str_map`` and also used as the
        ``fill_value`` of the final take.
        NOTE(review): presumably this is what missing entries (code -1)
        end up as -- confirm against ``take_1d``.
    dtype : numpy dtype, default object
        Forwarded unchanged to ``PandasArray._str_map``.

    Returns
    -------
    The result of ``take_1d``; per the original author's note this is the
    same type the object-dtype implementation returns.
    """
    # Imported here rather than at module level, as in the original.
    # NOTE(review): presumably to avoid a circular import -- confirm.
    from pandas.core.arrays import PandasArray

    # Map every distinct category exactly once ...
    mapped_categories = PandasArray(self.categories.to_numpy())._str_map(
        f, na_value, dtype
    )
    # ... then expand back to one value per element using the codes.
    return take_1d(mapped_categories, self.codes, fill_value=na_value)
2321+
2322+
def _str_get_dummies(self, sep="|"):
    """
    Compute string dummies by delegating to the ``PandasArray`` implementation.

    The separator ``sep`` may not be present in the categories, so no
    category-level shortcut is attempted here: the values are cast to
    ``str`` and handed to ``PandasArray._str_get_dummies``.

    Parameters
    ----------
    sep : str, default "|"
        Separator forwarded to ``PandasArray._str_get_dummies``.

    Returns
    -------
    Whatever ``PandasArray._str_get_dummies`` returns for the string-cast
    values.
    """
    # Imported here rather than at module level, as in the original.
    from pandas.core.arrays import PandasArray

    as_strings = self.astype(str)
    return PandasArray(as_strings)._str_get_dummies(sep)
2327+
23082328

23092329
# The Series.cat accessor
23102330

pandas/core/arrays/numpy_.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pandas.core.array_algos import masked_reductions
1717
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
1818
from pandas.core.arrays.base import ExtensionOpsMixin
19+
from pandas.core.strings.object_array import ObjectStringArrayMixin
1920

2021

2122
class PandasDtype(ExtensionDtype):
@@ -114,7 +115,10 @@ def itemsize(self) -> int:
114115

115116

116117
class PandasArray(
117-
NDArrayBackedExtensionArray, ExtensionOpsMixin, NDArrayOperatorsMixin
118+
NDArrayBackedExtensionArray,
119+
ExtensionOpsMixin,
120+
NDArrayOperatorsMixin,
121+
ObjectStringArrayMixin,
118122
):
119123
"""
120124
A pandas ExtensionArray for NumPy data.
@@ -376,6 +380,10 @@ def arithmetic_method(self, other):
376380

377381
_create_comparison_method = _create_arithmetic_method
378382

383+
# ------------------------------------------------------------------------
384+
# String methods interface
385+
_str_na_value = np.nan
386+
379387

380388
PandasArray._add_arithmetic_ops()
381389
PandasArray._add_comparison_ops()

pandas/core/arrays/string_.py

+60-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,14 @@
66
from pandas._libs import lib, missing as libmissing
77

88
from pandas.core.dtypes.base import ExtensionDtype, register_extension_dtype
9-
from pandas.core.dtypes.common import pandas_dtype
10-
from pandas.core.dtypes.inference import is_array_like
9+
from pandas.core.dtypes.common import (
10+
is_array_like,
11+
is_bool_dtype,
12+
is_integer_dtype,
13+
is_object_dtype,
14+
is_string_dtype,
15+
pandas_dtype,
16+
)
1117

1218
from pandas import compat
1319
from pandas.core import ops
@@ -347,6 +353,58 @@ def _add_arithmetic_ops(cls):
347353
cls.__rmul__ = cls._create_arithmetic_method(ops.rmul)
348354

349355
_create_comparison_method = _create_arithmetic_method
356+
# ------------------------------------------------------------------------
357+
# String methods interface
358+
_str_na_value = StringDtype.na_value
359+
360+
def _str_map(self, f, na_value=None, dtype=None):
    """
    Apply ``f`` to the elements of this array via ``lib.map_infer_mask``.

    Parameters
    ----------
    f : callable
        Function passed to ``lib.map_infer_mask``.
    na_value : scalar, optional
        Value used for missing entries. Defaults to ``self.dtype.na_value``.
    dtype : dtype, optional
        Requested result dtype. Defaults to ``StringDtype()``.

    Returns
    -------
    IntegerArray or BooleanArray
        When ``dtype`` is an integer or boolean dtype.
    StringArray
        When ``dtype`` is a string dtype that is not object dtype.
    object
        Otherwise, the raw return value of ``lib.map_infer_mask``.
    """
    from pandas.arrays import BooleanArray, IntegerArray, StringArray
    from pandas.core.arrays.string_ import StringDtype

    dtype = StringDtype() if dtype is None else dtype
    na_value = self.dtype.na_value if na_value is None else na_value

    missing = isna(self)
    values = np.asarray(self)

    wants_integer = is_integer_dtype(dtype)
    if wants_integer or is_bool_dtype(dtype):
        constructor: Union[Type[IntegerArray], Type[BooleanArray]] = (
            IntegerArray if wants_integer else BooleanArray
        )

        fill_is_missing = isna(na_value)
        # A missing fill value is swapped for the placeholder 1 in the raw
        # result; those positions stay flagged in ``missing``, which is
        # handed to the masked-array constructor below.
        fill = 1 if fill_is_missing else na_value
        result = lib.map_infer_mask(
            values,
            f,
            missing.view("uint8"),
            convert=False,
            na_value=fill,
            dtype=np.dtype(dtype),
        )

        if not fill_is_missing:
            # The caller supplied a concrete fill value, so no position is
            # reported as missing in the returned array.
            missing[:] = False

        return constructor(result, missing)

    if is_string_dtype(dtype) and not is_object_dtype(dtype):
        # i.e. StringDtype
        result = lib.map_infer_mask(
            values, f, missing.view("uint8"), convert=False, na_value=na_value
        )
        return StringArray(result)

    # The result type is object. This branch is reached when either
    # -> the result type is known to be truly object (e.g. .encode returns
    #    bytes or .findall returns a list), or
    # -> the result type is unknown, e.g. `.get` can return anything.
    return lib.map_infer_mask(values, f, missing.view("uint8"))
350408

351409

352410
StringArray._add_arithmetic_ops()

0 commit comments

Comments
 (0)