Skip to content

Commit 1d59c7a

Browse files
add common base class, BaseStringArray
1 parent 4a37470 commit 1d59c7a

File tree

4 files changed

+17
-16
lines changed

4 files changed

+17
-16
lines changed

pandas/core/arrays/string_.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
IntegerArray,
4444
PandasArray,
4545
)
46+
from pandas.core.arrays.base import ExtensionArray
4647
from pandas.core.arrays.floating import FloatingDtype
4748
from pandas.core.arrays.integer import _IntegerDtype
4849
from pandas.core.construction import extract_array
@@ -52,8 +53,6 @@
5253
if TYPE_CHECKING:
5354
import pyarrow
5455

55-
from pandas.core.arrays.string_arrow import ArrowStringArray
56-
5756

5857
@register_extension_dtype
5958
class StringDtype(ExtensionDtype):
@@ -172,7 +171,7 @@ def __hash__(self) -> int:
172171
# "ExtensionDtype"
173172
def construct_array_type( # type: ignore[override]
174173
self,
175-
) -> type_t[StringArray | ArrowStringArray]:
174+
) -> type_t[BaseStringArray]:
176175
"""
177176
Return the array type associated with this dtype.
178177
@@ -195,7 +194,7 @@ def __str__(self):
195194

196195
def __from_arrow__(
197196
self, array: pyarrow.Array | pyarrow.ChunkedArray
198-
) -> StringArray | ArrowStringArray:
197+
) -> BaseStringArray:
199198
"""
200199
Construct StringArray from pyarrow Array/ChunkedArray.
201200
"""
@@ -225,7 +224,11 @@ def __from_arrow__(
225224
return StringArray(np.array([], dtype="object"))
226225

227226

228-
class StringArray(PandasArray):
227+
class BaseStringArray(ExtensionArray):
228+
pass
229+
230+
231+
class StringArray(BaseStringArray, PandasArray):
229232
"""
230233
Extension array for string data.
231234

pandas/core/arrays/string_arrow.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@
4747
from pandas.core.arrays.boolean import BooleanDtype
4848
from pandas.core.arrays.integer import Int64Dtype
4949
from pandas.core.arrays.numeric import NumericDtype
50-
from pandas.core.arrays.string_ import StringDtype
50+
from pandas.core.arrays.string_ import (
51+
BaseStringArray,
52+
StringDtype,
53+
)
5154
from pandas.core.indexers import (
5255
check_array_indexer,
5356
validate_indices,
@@ -86,7 +89,7 @@ def _chk_pyarrow_available() -> None:
8689
# fallback for the ones that pyarrow doesn't yet support
8790

8891

89-
class ArrowStringArray(OpsMixin, ExtensionArray, ObjectStringArrayMixin):
92+
class ArrowStringArray(OpsMixin, BaseStringArray, ObjectStringArrayMixin):
9093
"""
9194
Extension array for string data in a ``pyarrow.ChunkedArray``.
9295

pandas/core/dtypes/cast.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -420,18 +420,14 @@ def maybe_cast_to_extension_array(
420420
-------
421421
ExtensionArray or obj
422422
"""
423-
from pandas.core.arrays.string_ import StringArray
424-
from pandas.core.arrays.string_arrow import ArrowStringArray
423+
from pandas.core.arrays.string_ import BaseStringArray
425424

426425
assert isinstance(cls, type), f"must pass a type: {cls}"
427426
assertion_msg = f"must pass a subclass of ExtensionArray: {cls}"
428427
assert issubclass(cls, ABCExtensionArray), assertion_msg
429428

430429
# Everything can be converted to StringArrays, but we may not want to convert
431-
if (
432-
issubclass(cls, (StringArray, ArrowStringArray))
433-
and lib.infer_dtype(obj) != "string"
434-
):
430+
if issubclass(cls, BaseStringArray) and lib.infer_dtype(obj) != "string":
435431
return obj
436432

437433
try:

pandas/core/strings/object_array.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,7 @@ def scalar_rep(x):
173173

174174
return self._str_map(scalar_rep, dtype=str)
175175
else:
176-
from pandas.core.arrays.string_ import StringArray
177-
from pandas.core.arrays.string_arrow import ArrowStringArray
176+
from pandas.core.arrays.string_ import BaseStringArray
178177

179178
def rep(x, r):
180179
if x is libmissing.NA:
@@ -186,7 +185,7 @@ def rep(x, r):
186185

187186
repeats = np.asarray(repeats, dtype=object)
188187
result = libops.vec_binop(np.asarray(self), repeats, rep)
189-
if isinstance(self, (StringArray, ArrowStringArray)):
188+
if isinstance(self, BaseStringArray):
190189
# Not going through map, so we have to do this here.
191190
result = type(self)._from_sequence(result)
192191
return result

0 commit comments

Comments
 (0)