Skip to content

Commit 7969eab

Browse files
committed
Handle construction of string ExtensionArray from lists
1 parent 143bc34 commit 7969eab

File tree

4 files changed

+61
-11
lines changed

4 files changed

+61
-11
lines changed

pandas/core/construction.py

+3-1
Original file line number | Diff line number | Diff line change
@@ -470,7 +470,9 @@ def sanitize_array(data, index, dtype=None, copy=False, raise_cast_failure=False
470470

471471
# This is to prevent a mixed-type Series from being cast entirely to
472472
# NumPy string type, e.g. NaN --> '-1#IND'.
473-
if issubclass(subarr.dtype.type, str):
473+
if not (
474+
is_extension_array_dtype(subarr.dtype) or is_extension_array_dtype(dtype)
475+
) and issubclass(subarr.dtype.type, str):
474476
# GH#16605
475477
# If not empty convert the data to dtype
476478
# GH#19853: If data is a scalar, subarr already holds the result

pandas/tests/extension/arrow/bool.py renamed to pandas/tests/extension/arrow/arrays.py

+44-9
Original file line number | Diff line number | Diff line change
@@ -43,18 +43,27 @@ def _is_boolean(self):
4343
return True
4444

4545

46-
class ArrowBoolArray(ExtensionArray):
47-
def __init__(self, values):
48-
if not isinstance(values, pa.ChunkedArray):
49-
raise ValueError
46+
@register_extension_dtype
47+
class ArrowStringDtype(ExtensionDtype):
5048

51-
assert values.type == pa.bool_()
52-
self._data = values
53-
self._dtype = ArrowBoolDtype()
49+
type = str
50+
kind = "U"
51+
name = "arrow_string"
52+
na_value = pa.NULL
5453

55-
def __repr__(self):
56-
return "ArrowBoolArray({})".format(repr(self._data))
54+
@classmethod
55+
def construct_from_string(cls, string):
56+
if string == cls.name:
57+
return cls()
58+
else:
59+
raise TypeError("Cannot construct a '{}' from '{}'".format(cls, string))
5760

61+
@classmethod
62+
def construct_array_type(cls):
63+
return ArrowStringArray
64+
65+
66+
class ArrowExtensionArray(ExtensionArray):
5867
@classmethod
5968
def from_scalars(cls, values):
6069
arr = pa.chunked_array([pa.array(np.asarray(values))])
@@ -142,3 +151,29 @@ def any(self, axis=0, out=None):
142151

143152
def all(self, axis=0, out=None):
144153
return self._data.to_pandas().all()
154+
155+
156+
class ArrowBoolArray(ArrowExtensionArray):
157+
def __init__(self, values):
158+
if not isinstance(values, pa.ChunkedArray):
159+
raise ValueError
160+
161+
assert values.type == pa.bool_()
162+
self._data = values
163+
self._dtype = ArrowBoolDtype()
164+
165+
def __repr__(self):
166+
return "ArrowBoolArray({})".format(repr(self._data))
167+
168+
169+
class ArrowStringArray(ArrowExtensionArray):
170+
def __init__(self, values):
171+
if not isinstance(values, pa.ChunkedArray):
172+
raise ValueError
173+
174+
assert values.type == pa.string()
175+
self._data = values
176+
self._dtype = ArrowStringDtype()
177+
178+
def __repr__(self):
179+
return "ArrowStringArray({})".format(repr(self._data))

pandas/tests/extension/arrow/test_bool.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77

88
pytest.importorskip("pyarrow", minversion="0.10.0")
99

10-
from .bool import ArrowBoolArray, ArrowBoolDtype # isort:skip
10+
from .arrays import ArrowBoolArray, ArrowBoolDtype # isort:skip
1111

1212

1313
@pytest.fixture
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,13 @@
1+
import pytest
2+
3+
import pandas as pd
4+
5+
pytest.importorskip("pyarrow", minversion="0.10.0")
6+
7+
from .arrays import ArrowStringDtype # isort:skip
8+
9+
10+
def test_constructor_from_list():
11+
# GH 27673
12+
result = pd.Series(["E"], dtype=ArrowStringDtype())
13+
assert isinstance(result.dtype, ArrowStringDtype)

0 commit comments

Comments
 (0)