Skip to content

Commit c8f040d

Browse files
xhochyjorisvandenbossche
authored andcommitted
BUG: Handle construction of string ExtensionArray from lists (#27674)
1 parent 360ae1c commit c8f040d

File tree

4 files changed

+75
-33
lines changed

4 files changed

+75
-33
lines changed

pandas/core/construction.py

+20-23
Original file line numberDiff line numberDiff line change
@@ -468,30 +468,27 @@ def sanitize_array(data, index, dtype=None, copy=False, raise_cast_failure=False
468468
else:
469469
subarr = com.asarray_tuplesafe(data, dtype=dtype)
470470

471-
# This is to prevent mixed-type Series getting all casted to
472-
# NumPy string type, e.g. NaN --> '-1#IND'.
473-
if issubclass(subarr.dtype.type, str):
474-
# GH#16605
475-
# If not empty convert the data to dtype
476-
# GH#19853: If data is a scalar, subarr has already the result
477-
if not lib.is_scalar(data):
478-
if not np.all(isna(data)):
479-
data = np.array(data, dtype=dtype, copy=False)
480-
subarr = np.array(data, dtype=object, copy=copy)
481-
482-
if (
483-
not (is_extension_array_dtype(subarr.dtype) or is_extension_array_dtype(dtype))
484-
and is_object_dtype(subarr.dtype)
485-
and not is_object_dtype(dtype)
486-
):
487-
inferred = lib.infer_dtype(subarr, skipna=False)
488-
if inferred == "period":
489-
from pandas.core.arrays import period_array
471+
if not (is_extension_array_dtype(subarr.dtype) or is_extension_array_dtype(dtype)):
472+
# This is to prevent mixed-type Series getting all casted to
473+
# NumPy string type, e.g. NaN --> '-1#IND'.
474+
if issubclass(subarr.dtype.type, str):
475+
# GH#16605
476+
# If not empty convert the data to dtype
477+
# GH#19853: If data is a scalar, subarr has already the result
478+
if not lib.is_scalar(data):
479+
if not np.all(isna(data)):
480+
data = np.array(data, dtype=dtype, copy=False)
481+
subarr = np.array(data, dtype=object, copy=copy)
490482

491-
try:
492-
subarr = period_array(subarr)
493-
except IncompatibleFrequency:
494-
pass
483+
if is_object_dtype(subarr.dtype) and not is_object_dtype(dtype):
484+
inferred = lib.infer_dtype(subarr, skipna=False)
485+
if inferred == "period":
486+
from pandas.core.arrays import period_array
487+
488+
try:
489+
subarr = period_array(subarr)
490+
except IncompatibleFrequency:
491+
pass
495492

496493
return subarr
497494

pandas/tests/extension/arrow/bool.py renamed to pandas/tests/extension/arrow/arrays.py

+41-9
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,27 @@ def _is_boolean(self):
4343
return True
4444

4545

46-
class ArrowBoolArray(ExtensionArray):
47-
def __init__(self, values):
48-
if not isinstance(values, pa.ChunkedArray):
49-
raise ValueError
46+
@register_extension_dtype
47+
class ArrowStringDtype(ExtensionDtype):
5048

51-
assert values.type == pa.bool_()
52-
self._data = values
53-
self._dtype = ArrowBoolDtype()
49+
type = str
50+
kind = "U"
51+
name = "arrow_string"
52+
na_value = pa.NULL
53+
54+
@classmethod
55+
def construct_from_string(cls, string):
56+
if string == cls.name:
57+
return cls()
58+
else:
59+
raise TypeError("Cannot construct a '{}' from '{}'".format(cls, string))
60+
61+
@classmethod
62+
def construct_array_type(cls):
63+
return ArrowStringArray
5464

55-
def __repr__(self):
56-
return "ArrowBoolArray({})".format(repr(self._data))
5765

66+
class ArrowExtensionArray(ExtensionArray):
5867
@classmethod
5968
def from_scalars(cls, values):
6069
arr = pa.chunked_array([pa.array(np.asarray(values))])
@@ -69,6 +78,9 @@ def from_array(cls, arr):
6978
def _from_sequence(cls, scalars, dtype=None, copy=False):
7079
return cls.from_scalars(scalars)
7180

81+
def __repr__(self):
82+
return "{cls}({data})".format(cls=type(self).__name__, data=repr(self._data))
83+
7284
def __getitem__(self, item):
7385
if pd.api.types.is_scalar(item):
7486
return self._data.to_pandas()[item]
@@ -142,3 +154,23 @@ def any(self, axis=0, out=None):
142154

143155
def all(self, axis=0, out=None):
144156
return self._data.to_pandas().all()
157+
158+
159+
class ArrowBoolArray(ArrowExtensionArray):
160+
def __init__(self, values):
161+
if not isinstance(values, pa.ChunkedArray):
162+
raise ValueError
163+
164+
assert values.type == pa.bool_()
165+
self._data = values
166+
self._dtype = ArrowBoolDtype()
167+
168+
169+
class ArrowStringArray(ArrowExtensionArray):
170+
def __init__(self, values):
171+
if not isinstance(values, pa.ChunkedArray):
172+
raise ValueError
173+
174+
assert values.type == pa.string()
175+
self._data = values
176+
self._dtype = ArrowStringDtype()

pandas/tests/extension/arrow/test_bool.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
pytest.importorskip("pyarrow", minversion="0.10.0")
99

10-
from .bool import ArrowBoolArray, ArrowBoolDtype # isort:skip
10+
from .arrays import ArrowBoolArray, ArrowBoolDtype # isort:skip
1111

1212

1313
@pytest.fixture
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import pytest
2+
3+
import pandas as pd
4+
5+
pytest.importorskip("pyarrow", minversion="0.10.0")
6+
7+
from .arrays import ArrowStringDtype # isort:skip
8+
9+
10+
def test_constructor_from_list():
11+
# GH 27673
12+
result = pd.Series(["E"], dtype=ArrowStringDtype())
13+
assert isinstance(result.dtype, ArrowStringDtype)

0 commit comments

Comments
 (0)