
Commit a04754e

TST: ArrowExtensionArray with string and binary types (#49172)
1 parent 91de4cc commit a04754e

4 files changed: +114 −14 lines changed


doc/source/whatsnew/v2.0.0.rst

+1 −1

@@ -219,7 +219,7 @@ Conversion
 - Bug in :meth:`DataFrame.eval` incorrectly raising an ``AttributeError`` when there are negative values in function call (:issue:`46471`)
 - Bug in :meth:`Series.convert_dtypes` not converting dtype to nullable dtype when :class:`Series` contains ``NA`` and has dtype ``object`` (:issue:`48791`)
 - Bug where any :class:`ExtensionDtype` subclass with ``kind="M"`` would be interpreted as a timezone type (:issue:`34986`)
--
+- Bug in :class:`.arrays.ArrowExtensionArray` that would raise ``NotImplementedError`` when passed a sequence of strings or binary (:issue:`49172`)

 Strings
 ^^^^^^^
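For reference, a minimal sketch of the kind of construction this whatsnew entry refers to. It assumes pandas with this change and pyarrow installed; the use of the public pd.ArrowDtype spelling is illustrative and not taken from the commit itself.

# Minimal sketch: ArrowExtensionArray backed by string / binary pyarrow types.
import pyarrow as pa
import pandas as pd

strings = pd.array(["a", "b", None], dtype=pd.ArrowDtype(pa.string()))
binary = pd.array([b"a", b"b", None], dtype=pd.ArrowDtype(pa.binary()))
print(type(strings).__name__)  # ArrowExtensionArray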

pandas/_testing/__init__.py

+6 −1

@@ -201,8 +201,11 @@
 SIGNED_INT_PYARROW_DTYPES = [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
 ALL_INT_PYARROW_DTYPES = UNSIGNED_INT_PYARROW_DTYPES + SIGNED_INT_PYARROW_DTYPES

+# pa.float16 doesn't seem supported
+# https://github.com/apache/arrow/blob/master/python/pyarrow/src/arrow/python/helpers.cc#L86
 FLOAT_PYARROW_DTYPES = [pa.float32(), pa.float64()]
-STRING_PYARROW_DTYPES = [pa.string(), pa.utf8()]
+STRING_PYARROW_DTYPES = [pa.string()]
+BINARY_PYARROW_DTYPES = [pa.binary()]

 TIME_PYARROW_DTYPES = [
     pa.time32("s"),
@@ -225,6 +228,8 @@
 ALL_PYARROW_DTYPES = (
     ALL_INT_PYARROW_DTYPES
     + FLOAT_PYARROW_DTYPES
+    + STRING_PYARROW_DTYPES
+    + BINARY_PYARROW_DTYPES
     + TIME_PYARROW_DTYPES
     + DATE_PYARROW_DTYPES
     + DATETIME_PYARROW_DTYPES
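These module-level lists feed the parametrized fixtures in the Arrow extension tests. A rough, illustrative sketch of that usage follows; the fixture below is not copied from the repository, and the list name and parametrization details are assumptions.

# Illustrative sketch only: a list of pyarrow dtypes driving a pytest fixture.
import pyarrow as pa
import pytest
import pandas as pd

PYARROW_DTYPES = [pa.string(), pa.binary(), pa.float64()]  # stand-in for ALL_PYARROW_DTYPES

@pytest.fixture(params=PYARROW_DTYPES, ids=str)
def dtype(request):
    # Each parametrized case wraps one pyarrow type in an ArrowDtype.
    return pd.ArrowDtype(request.param)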

pandas/core/arrays/arrow/array.py

+7 −2

@@ -220,8 +220,13 @@ def _from_sequence_of_strings(
         Construct a new ExtensionArray from a sequence of strings.
         """
         pa_type = to_pyarrow_type(dtype)
-        if pa_type is None:
-            # Let pyarrow try to infer or raise
+        if (
+            pa_type is None
+            or pa.types.is_binary(pa_type)
+            or pa.types.is_string(pa_type)
+        ):
+            # pa_type is None: Let pa.array infer
+            # pa_type is string/binary: scalars already correct type
             scalars = strings
         elif pa.types.is_timestamp(pa_type):
             from pandas.core.tools.datetimes import to_datetime
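The main consumer of _from_sequence_of_strings is the CSV parsing machinery, which hands parsed text to the extension dtype's string constructor. A hedged sketch of that user-facing path; passing the ArrowDtype object in the dtype mapping is one way to request the Arrow-backed column, and the exact dtype repr may differ by version.

# Sketch, assuming pandas with this change and pyarrow installed.
from io import StringIO

import pyarrow as pa
import pandas as pd

csv = StringIO("col\na\nb\n")
# read_csv routes parsed strings through the extension array's string constructor.
df = pd.read_csv(csv, dtype={"col": pd.ArrowDtype(pa.string())})
print(df["col"].dtype)  # ArrowDtype wrapping pa.string()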

pandas/tests/extension/test_arrow.py

+100 −10

@@ -16,6 +16,10 @@
     time,
     timedelta,
 )
+from io import (
+    BytesIO,
+    StringIO,
+)

 import numpy as np
 import pytest
@@ -90,6 +94,10 @@ def data(dtype):
             + [None]
             + [time(0, 5), time(5, 0)]
         )
+    elif pa.types.is_string(pa_dtype):
+        data = ["a", "b"] * 4 + [None] + ["1", "2"] * 44 + [None] + ["!", ">"]
+    elif pa.types.is_binary(pa_dtype):
+        data = [b"a", b"b"] * 4 + [None] + [b"1", b"2"] * 44 + [None] + [b"!", b">"]
     else:
         raise NotImplementedError
     return pd.array(data, dtype=dtype)
@@ -155,6 +163,14 @@ def data_for_grouping(dtype):
         A = time(0, 0)
         B = time(0, 12)
         C = time(12, 12)
+    elif pa.types.is_string(pa_dtype):
+        A = "a"
+        B = "b"
+        C = "c"
+    elif pa.types.is_binary(pa_dtype):
+        A = b"a"
+        B = b"b"
+        C = b"c"
     else:
         raise NotImplementedError
     return pd.array([B, B, None, None, A, A, B, C], dtype=dtype)
@@ -203,17 +219,30 @@ def na_value():


 class TestBaseCasting(base.BaseCastingTests):
-    pass
+    def test_astype_str(self, data, request):
+        pa_dtype = data.dtype.pyarrow_dtype
+        if pa.types.is_binary(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    reason=f"For {pa_dtype} .astype(str) decodes.",
+                )
+            )
+        super().test_astype_str(data)


 class TestConstructors(base.BaseConstructorsTests):
     def test_from_dtype(self, data, request):
         pa_dtype = data.dtype.pyarrow_dtype
-        if pa.types.is_timestamp(pa_dtype) and pa_dtype.tz:
+        if (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz) or pa.types.is_string(
+            pa_dtype
+        ):
+            if pa.types.is_string(pa_dtype):
+                reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')"
+            else:
+                reason = f"pyarrow.type_for_alias cannot infer {pa_dtype}"
             request.node.add_marker(
                 pytest.mark.xfail(
-                    raises=NotImplementedError,
-                    reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
+                    reason=reason,
                 )
             )
         super().test_from_dtype(data)
@@ -302,7 +331,7 @@ class TestGetitemTests(base.BaseGetitemTests):
         reason=(
             "data.dtype.type return pyarrow.DataType "
             "but this (intentionally) returns "
-            "Python scalars or pd.Na"
+            "Python scalars or pd.NA"
         )
     )
     def test_getitem_scalar(self, data):
@@ -361,7 +390,11 @@ def test_reduce_series(self, data, all_numeric_reductions, skipna, request):
             or pa.types.is_boolean(pa_dtype)
         ) and not (
             all_numeric_reductions in {"min", "max"}
-            and (pa.types.is_temporal(pa_dtype) and not pa.types.is_duration(pa_dtype))
+            and (
+                (pa.types.is_temporal(pa_dtype) and not pa.types.is_duration(pa_dtype))
+                or pa.types.is_string(pa_dtype)
+                or pa.types.is_binary(pa_dtype)
+            )
         ):
             request.node.add_marker(xfail_mark)
         elif pa.types.is_boolean(pa_dtype) and all_numeric_reductions in {
@@ -494,6 +527,16 @@ def test_construct_from_string_own_name(self, dtype, request):
                     reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
                 )
             )
+        elif pa.types.is_string(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    raises=TypeError,
+                    reason=(
+                        "Still support StringDtype('pyarrow') "
+                        "over ArrowDtype(pa.string())"
+                    ),
+                )
+            )
         super().test_construct_from_string_own_name(dtype)

     def test_is_dtype_from_name(self, dtype, request):
@@ -505,6 +548,15 @@ def test_is_dtype_from_name(self, dtype, request):
                     reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
                 )
             )
+        elif pa.types.is_string(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    reason=(
+                        "Still support StringDtype('pyarrow') "
+                        "over ArrowDtype(pa.string())"
+                    ),
+                )
+            )
         super().test_is_dtype_from_name(dtype)

     def test_construct_from_string(self, dtype, request):
@@ -516,6 +568,16 @@ def test_construct_from_string(self, dtype, request):
                     reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
                 )
             )
+        elif pa.types.is_string(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(
+                    raises=TypeError,
+                    reason=(
+                        "Still support StringDtype('pyarrow') "
+                        "over ArrowDtype(pa.string())"
+                    ),
+                )
+            )
         super().test_construct_from_string(dtype)

     def test_construct_from_string_another_type_raises(self, dtype):
@@ -533,6 +595,8 @@ def test_get_common_dtype(self, dtype, request):
                 and (pa_dtype.unit != "ns" or pa_dtype.tz is not None)
             )
             or (pa.types.is_duration(pa_dtype) and pa_dtype.unit != "ns")
+            or pa.types.is_string(pa_dtype)
+            or pa.types.is_binary(pa_dtype)
         ):
             request.node.add_marker(
                 pytest.mark.xfail(
@@ -592,7 +656,21 @@ def test_EA_types(self, engine, data, request):
                     reason=f"Parameterized types with tz={pa_dtype.tz} not supported.",
                 )
             )
-        super().test_EA_types(engine, data)
+        elif pa.types.is_binary(pa_dtype):
+            request.node.add_marker(
+                pytest.mark.xfail(reason="CSV parsers don't correctly handle binary")
+            )
+        df = pd.DataFrame({"with_dtype": pd.Series(data, dtype=str(data.dtype))})
+        csv_output = df.to_csv(index=False, na_rep=np.nan)
+        if pa.types.is_binary(pa_dtype):
+            csv_output = BytesIO(csv_output)
+        else:
+            csv_output = StringIO(csv_output)
+        result = pd.read_csv(
+            csv_output, dtype={"with_dtype": str(data.dtype)}, engine=engine
+        )
+        expected = df
+        self.assert_frame_equal(result, expected)


 class TestBaseUnaryOps(base.BaseUnaryOpsTests):
@@ -899,7 +977,11 @@ def test_arith_series_with_scalar(
             or all_arithmetic_operators in ("__sub__", "__rsub__")
             and pa.types.is_temporal(pa_dtype)
         )
-        if all_arithmetic_operators in {
+        if all_arithmetic_operators == "__rmod__" and (
+            pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
+        ):
+            pytest.skip("Skip testing Python string formatting")
+        elif all_arithmetic_operators in {
             "__mod__",
             "__rmod__",
         }:
@@ -965,7 +1047,11 @@ def test_arith_frame_with_scalar(
             or all_arithmetic_operators in ("__sub__", "__rsub__")
             and pa.types.is_temporal(pa_dtype)
         )
-        if all_arithmetic_operators in {
+        if all_arithmetic_operators == "__rmod__" and (
+            pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
+        ):
+            pytest.skip("Skip testing Python string formatting")
+        elif all_arithmetic_operators in {
             "__mod__",
             "__rmod__",
         }:
@@ -1224,7 +1310,11 @@ def test_quantile(data, interpolation, quantile, request):
 )
 def test_mode(data_for_grouping, dropna, take_idx, exp_idx, request):
     pa_dtype = data_for_grouping.dtype.pyarrow_dtype
-    if pa.types.is_temporal(pa_dtype):
+    if (
+        pa.types.is_temporal(pa_dtype)
+        or pa.types.is_string(pa_dtype)
+        or pa.types.is_binary(pa_dtype)
+    ):
         request.node.add_marker(
             pytest.mark.xfail(
                 raises=pa.ArrowNotImplementedError,
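Many of the hunks above share one pattern: the base extension test is reused as-is, and individual pyarrow dtypes are marked as expected failures at runtime through the request fixture. A minimal, self-contained illustration of that pattern (not taken from the repository):

# Minimal illustration of runtime xfail marking via the `request` fixture.
import pytest

@pytest.mark.parametrize("value", ["text", b"bytes"])
def test_is_text(value, request):
    if isinstance(value, bytes):
        # Only this parametrization is expected to fail; the others must still pass.
        request.node.add_marker(pytest.mark.xfail(reason="bytes are not str"))
    assert isinstance(value, str)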
