Skip to content

Handle construction of string ExtensionArray from lists #27674

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 20 additions & 23 deletions pandas/core/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,30 +468,27 @@ def sanitize_array(data, index, dtype=None, copy=False, raise_cast_failure=False
else:
subarr = com.asarray_tuplesafe(data, dtype=dtype)

# This is to prevent mixed-type Series getting all casted to
# NumPy string type, e.g. NaN --> '-1#IND'.
if issubclass(subarr.dtype.type, str):
# GH#16605
# If not empty convert the data to dtype
# GH#19853: If data is a scalar, subarr has already the result
if not lib.is_scalar(data):
if not np.all(isna(data)):
data = np.array(data, dtype=dtype, copy=False)
subarr = np.array(data, dtype=object, copy=copy)

if (
not (is_extension_array_dtype(subarr.dtype) or is_extension_array_dtype(dtype))
and is_object_dtype(subarr.dtype)
and not is_object_dtype(dtype)
):
inferred = lib.infer_dtype(subarr, skipna=False)
if inferred == "period":
from pandas.core.arrays import period_array
if not (is_extension_array_dtype(subarr.dtype) or is_extension_array_dtype(dtype)):
# This is to prevent mixed-type Series getting all casted to
# NumPy string type, e.g. NaN --> '-1#IND'.
if issubclass(subarr.dtype.type, str):
# GH#16605
# If not empty convert the data to dtype
# GH#19853: If data is a scalar, subarr has already the result
if not lib.is_scalar(data):
if not np.all(isna(data)):
data = np.array(data, dtype=dtype, copy=False)
subarr = np.array(data, dtype=object, copy=copy)

try:
subarr = period_array(subarr)
except IncompatibleFrequency:
pass
if is_object_dtype(subarr.dtype) and not is_object_dtype(dtype):
inferred = lib.infer_dtype(subarr, skipna=False)
if inferred == "period":
from pandas.core.arrays import period_array

try:
subarr = period_array(subarr)
except IncompatibleFrequency:
pass

return subarr

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,27 @@ def _is_boolean(self):
return True


class ArrowBoolArray(ExtensionArray):
def __init__(self, values):
if not isinstance(values, pa.ChunkedArray):
raise ValueError
@register_extension_dtype
class ArrowStringDtype(ExtensionDtype):

assert values.type == pa.bool_()
self._data = values
self._dtype = ArrowBoolDtype()
type = str
kind = "U"
name = "arrow_string"
na_value = pa.NULL

@classmethod
def construct_from_string(cls, string):
    """Build the dtype from its string name.

    Parameters
    ----------
    string : str
        Candidate dtype name; only a value equal to ``cls.name`` is accepted.

    Returns
    -------
    An instance of ``cls``.

    Raises
    ------
    TypeError
        If ``string`` does not equal ``cls.name``.
    """
    # Guard clause: reject anything that is not exactly this dtype's name.
    if string != cls.name:
        raise TypeError("Cannot construct a '{}' from '{}'".format(cls, string))
    return cls()

@classmethod
def construct_array_type(cls):
    """Return the array class associated with this dtype (``ArrowStringArray``)."""
    array_type = ArrowStringArray
    return array_type

def __repr__(self):
return "ArrowBoolArray({})".format(repr(self._data))

class ArrowExtensionArray(ExtensionArray):
@classmethod
def from_scalars(cls, values):
arr = pa.chunked_array([pa.array(np.asarray(values))])
Expand All @@ -69,6 +78,9 @@ def from_array(cls, arr):
def _from_sequence(cls, scalars, dtype=None, copy=False):
return cls.from_scalars(scalars)

def __repr__(self):
    """Return ``<ClassName>(<repr of self._data>)``.

    The concrete class name is used so that subclasses are labelled
    with their own name.
    """
    cls_name = type(self).__name__
    data_repr = repr(self._data)
    return "{cls}({data})".format(cls=cls_name, data=data_repr)

def __getitem__(self, item):
if pd.api.types.is_scalar(item):
return self._data.to_pandas()[item]
Expand Down Expand Up @@ -142,3 +154,23 @@ def any(self, axis=0, out=None):

def all(self, axis=0, out=None):
    """Return the result of ``all()`` on the pandas conversion of the data.

    ``axis`` and ``out`` are accepted in the signature but are not used
    by this implementation.
    """
    as_pandas = self._data.to_pandas()
    return as_pandas.all()


class ArrowBoolArray(ArrowExtensionArray):
    """Boolean extension array wrapping a ``pa.ChunkedArray``."""

    def __init__(self, values):
        """Wrap an existing boolean chunked array.

        Parameters
        ----------
        values : pa.ChunkedArray
            Chunked array whose ``type`` equals ``pa.bool_()``.

        Raises
        ------
        ValueError
            If ``values`` is not a ``pa.ChunkedArray`` or its type is not
            ``pa.bool_()``.
        """
        if not isinstance(values, pa.ChunkedArray):
            raise ValueError(
                "values must be a pa.ChunkedArray, got {}".format(
                    type(values).__name__
                )
            )

        # Explicit raise instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would silently accept wrong-typed data.
        if not (values.type == pa.bool_()):
            raise ValueError(
                "values must have type {}, got {}".format(pa.bool_(), values.type)
            )

        self._data = values
        self._dtype = ArrowBoolDtype()


class ArrowStringArray(ArrowExtensionArray):
    """String extension array wrapping a ``pa.ChunkedArray``."""

    def __init__(self, values):
        """Wrap an existing string chunked array.

        Parameters
        ----------
        values : pa.ChunkedArray
            Chunked array whose ``type`` equals ``pa.string()``.

        Raises
        ------
        ValueError
            If ``values`` is not a ``pa.ChunkedArray`` or its type is not
            ``pa.string()``.
        """
        if not isinstance(values, pa.ChunkedArray):
            raise ValueError(
                "values must be a pa.ChunkedArray, got {}".format(
                    type(values).__name__
                )
            )

        # Explicit raise instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would silently accept wrong-typed data.
        if not (values.type == pa.string()):
            raise ValueError(
                "values must have type {}, got {}".format(pa.string(), values.type)
            )

        self._data = values
        self._dtype = ArrowStringDtype()
2 changes: 1 addition & 1 deletion pandas/tests/extension/arrow/test_bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

pytest.importorskip("pyarrow", minversion="0.10.0")

from .bool import ArrowBoolArray, ArrowBoolDtype # isort:skip
from .arrays import ArrowBoolArray, ArrowBoolDtype # isort:skip


@pytest.fixture
Expand Down
13 changes: 13 additions & 0 deletions pandas/tests/extension/arrow/test_string.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

import pandas as pd

pytest.importorskip("pyarrow", minversion="0.10.0")

from .arrays import ArrowStringDtype # isort:skip


def test_constructor_from_list():
    """A Series built from a list with an ArrowStringDtype keeps that dtype."""
    # GH 27673
    dtype = ArrowStringDtype()
    ser = pd.Series(["E"], dtype=dtype)
    assert isinstance(ser.dtype, ArrowStringDtype)