Skip to content

TST: ArrowExtensionArray with string and binary types #49172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 19, 2022
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v2.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ Conversion
- Bug in :meth:`DataFrame.eval` incorrectly raising an ``AttributeError`` when there are negative values in function call (:issue:`46471`)
- Bug in :meth:`Series.convert_dtypes` not converting dtype to nullable dtype when :class:`Series` contains ``NA`` and has dtype ``object`` (:issue:`48791`)
- Bug where any :class:`ExtensionDtype` subclass with ``kind="M"`` would be interpreted as a timezone type (:issue:`34986`)
-
- Bug in :class:`.arrays.ArrowExtensionArray` that would raise ``NotImplementedError`` when passed a sequence of strings or binary (:issue:`49172`)

Strings
^^^^^^^
Expand Down
7 changes: 6 additions & 1 deletion pandas/_testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,11 @@
SIGNED_INT_PYARROW_DTYPES = [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
ALL_INT_PYARROW_DTYPES = UNSIGNED_INT_PYARROW_DTYPES + SIGNED_INT_PYARROW_DTYPES

# pa.float16 doesn't seem supported
# https://github.com/apache/arrow/blob/master/python/pyarrow/src/arrow/python/helpers.cc#L86
FLOAT_PYARROW_DTYPES = [pa.float32(), pa.float64()]
STRING_PYARROW_DTYPES = [pa.string(), pa.utf8()]
STRING_PYARROW_DTYPES = [pa.string()]
BINARY_PYARROW_DTYPES = [pa.binary()]

TIME_PYARROW_DTYPES = [
pa.time32("s"),
Expand All @@ -225,6 +228,8 @@
ALL_PYARROW_DTYPES = (
ALL_INT_PYARROW_DTYPES
+ FLOAT_PYARROW_DTYPES
+ STRING_PYARROW_DTYPES
+ BINARY_PYARROW_DTYPES
+ TIME_PYARROW_DTYPES
+ DATE_PYARROW_DTYPES
+ DATETIME_PYARROW_DTYPES
Expand Down
9 changes: 7 additions & 2 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,13 @@ def _from_sequence_of_strings(
Construct a new ExtensionArray from a sequence of strings.
"""
pa_type = to_pyarrow_type(dtype)
if pa_type is None:
# Let pyarrow try to infer or raise
if (
pa_type is None
or pa.types.is_binary(pa_type)
or pa.types.is_string(pa_type)
):
# pa_type is None: Let pa.array infer
# pa_type is string/binary: scalars already correct type
scalars = strings
elif pa.types.is_timestamp(pa_type):
from pandas.core.tools.datetimes import to_datetime
Expand Down
110 changes: 100 additions & 10 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
time,
timedelta,
)
from io import (
BytesIO,
StringIO,
)

import numpy as np
import pytest
Expand Down Expand Up @@ -90,6 +94,10 @@ def data(dtype):
+ [None]
+ [time(0, 5), time(5, 0)]
)
elif pa.types.is_string(pa_dtype):
data = ["a", "b"] * 4 + [None] + ["1", "2"] * 44 + [None] + ["!", ">"]
elif pa.types.is_binary(pa_dtype):
data = [b"a", b"b"] * 4 + [None] + [b"1", b"2"] * 44 + [None] + [b"!", b">"]
else:
raise NotImplementedError
return pd.array(data, dtype=dtype)
Expand Down Expand Up @@ -155,6 +163,14 @@ def data_for_grouping(dtype):
A = time(0, 0)
B = time(0, 12)
C = time(12, 12)
elif pa.types.is_string(pa_dtype):
A = "a"
B = "b"
C = "c"
elif pa.types.is_binary(pa_dtype):
A = b"a"
B = b"b"
C = b"c"
else:
raise NotImplementedError
return pd.array([B, B, None, None, A, A, B, C], dtype=dtype)
Expand Down Expand Up @@ -203,17 +219,30 @@ def na_value():


class TestBaseCasting(base.BaseCastingTests):
pass
def test_astype_str(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_binary(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"For {pa_dtype} .astype(str) decodes.",
)
)
super().test_astype_str(data)


class TestConstructors(base.BaseConstructorsTests):
def test_from_dtype(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_timestamp(pa_dtype) and pa_dtype.tz:
if (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz) or pa.types.is_string(
pa_dtype
):
if pa.types.is_string(pa_dtype):
reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')"
else:
reason = f"pyarrow.type_for_alias cannot infer {pa_dtype}"
request.node.add_marker(
pytest.mark.xfail(
raises=NotImplementedError,
reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
reason=reason,
)
)
super().test_from_dtype(data)
Expand Down Expand Up @@ -302,7 +331,7 @@ class TestGetitemTests(base.BaseGetitemTests):
reason=(
"data.dtype.type return pyarrow.DataType "
"but this (intentionally) returns "
"Python scalars or pd.Na"
"Python scalars or pd.NA"
)
)
def test_getitem_scalar(self, data):
Expand Down Expand Up @@ -361,7 +390,11 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request):
or pa.types.is_boolean(pa_dtype)
) and not (
all_numeric_reductions in {"min", "max"}
and (pa.types.is_temporal(pa_dtype) and not pa.types.is_duration(pa_dtype))
and (
(pa.types.is_temporal(pa_dtype) and not pa.types.is_duration(pa_dtype))
or pa.types.is_string(pa_dtype)
or pa.types.is_binary(pa_dtype)
)
):
request.node.add_marker(xfail_mark)
elif pa.types.is_boolean(pa_dtype) and all_numeric_reductions in {
Expand Down Expand Up @@ -494,6 +527,16 @@ def test_construct_from_string_own_name(self, dtype, request):
reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
)
)
elif pa.types.is_string(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=TypeError,
reason=(
"Still support StringDtype('pyarrow') "
"over ArrowDtype(pa.string())"
),
)
)
super().test_construct_from_string_own_name(dtype)

def test_is_dtype_from_name(self, dtype, request):
Expand All @@ -505,6 +548,15 @@ def test_is_dtype_from_name(self, dtype, request):
reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
)
)
elif pa.types.is_string(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=(
"Still support StringDtype('pyarrow') "
"over ArrowDtype(pa.string())"
),
)
)
super().test_is_dtype_from_name(dtype)

def test_construct_from_string(self, dtype, request):
Expand All @@ -516,6 +568,16 @@ def test_construct_from_string(self, dtype, request):
reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
)
)
elif pa.types.is_string(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=TypeError,
reason=(
"Still support StringDtype('pyarrow') "
"over ArrowDtype(pa.string())"
),
)
)
super().test_construct_from_string(dtype)

def test_construct_from_string_another_type_raises(self, dtype):
Expand All @@ -533,6 +595,8 @@ def test_get_common_dtype(self, dtype, request):
and (pa_dtype.unit != "ns" or pa_dtype.tz is not None)
)
or (pa.types.is_duration(pa_dtype) and pa_dtype.unit != "ns")
or pa.types.is_string(pa_dtype)
or pa.types.is_binary(pa_dtype)
):
request.node.add_marker(
pytest.mark.xfail(
Expand Down Expand Up @@ -592,7 +656,21 @@ def test_EA_types(self, engine, data, request):
reason=f"Parameterized types with tz={pa_dtype.tz} not supported.",
)
)
super().test_EA_types(engine, data)
elif pa.types.is_binary(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(reason="CSV parsers don't correctly handle binary")
)
df = pd.DataFrame({"with_dtype": pd.Series(data, dtype=str(data.dtype))})
csv_output = df.to_csv(index=False, na_rep=np.nan)
if pa.types.is_binary(pa_dtype):
csv_output = BytesIO(csv_output)
else:
csv_output = StringIO(csv_output)
result = pd.read_csv(
csv_output, dtype={"with_dtype": str(data.dtype)}, engine=engine
)
expected = df
self.assert_frame_equal(result, expected)


class TestBaseUnaryOps(base.BaseUnaryOpsTests):
Expand Down Expand Up @@ -899,7 +977,11 @@ def test_arith_series_with_scalar(
or all_arithmetic_operators in ("__sub__", "__rsub__")
and pa.types.is_temporal(pa_dtype)
)
if all_arithmetic_operators in {
if all_arithmetic_operators == "__rmod__" and (
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
):
pytest.skip("Skip testing Python string formatting")
elif all_arithmetic_operators in {
"__mod__",
"__rmod__",
}:
Expand Down Expand Up @@ -965,7 +1047,11 @@ def test_arith_frame_with_scalar(
or all_arithmetic_operators in ("__sub__", "__rsub__")
and pa.types.is_temporal(pa_dtype)
)
if all_arithmetic_operators in {
if all_arithmetic_operators == "__rmod__" and (
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
):
pytest.skip("Skip testing Python string formatting")
elif all_arithmetic_operators in {
"__mod__",
"__rmod__",
}:
Expand Down Expand Up @@ -1224,7 +1310,11 @@ def test_quantile(data, interpolation, quantile, request):
)
def test_mode(data_for_grouping, dropna, take_idx, exp_idx, request):
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
if pa.types.is_temporal(pa_dtype):
if (
pa.types.is_temporal(pa_dtype)
or pa.types.is_string(pa_dtype)
or pa.types.is_binary(pa_dtype)
):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
Expand Down