Skip to content

Commit 2288135

Browse files
authored
ENH/TST: Add BaseParsingTests tests for ArrowExtensionArray (#47536)
1 parent 2b1184d commit 2288135

File tree

2 files changed

+74
-12
lines changed

2 files changed

+74
-12
lines changed

pandas/core/arrays/arrow/array.py

+52-8
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,23 @@
6363
ArrowExtensionArrayT = TypeVar("ArrowExtensionArrayT", bound="ArrowExtensionArray")
6464

6565

66+
def to_pyarrow_type(
    dtype: ArrowDtype | pa.DataType | Dtype | None,
) -> pa.DataType | None:
    """
    Convert dtype to a pyarrow type instance.

    Parameters
    ----------
    dtype : ArrowDtype, pa.DataType, Dtype or None
        The dtype specification to translate.

    Returns
    -------
    pa.DataType or None
        * For an ``ArrowDtype``, its ``pyarrow_dtype`` attribute.
        * For a ``pa.DataType``, the object itself, unchanged.
        * For any other truthy value, the result of
          ``pa.from_numpy_dtype(dtype)``.
        * For a falsy value (such as ``None``), ``None``.
    """
    if isinstance(dtype, ArrowDtype):
        return dtype.pyarrow_dtype
    if isinstance(dtype, pa.DataType):
        # Already a pyarrow type; nothing to convert.
        return dtype
    if not dtype:
        # No dtype requested; callers receive None.
        return None
    return pa.from_numpy_dtype(dtype)
81+
82+
6683
class ArrowExtensionArray(OpsMixin, ExtensionArray):
6784
"""
6885
Base class for ExtensionArray backed by Arrow ChunkedArray.
@@ -89,13 +106,7 @@ def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy=False):
89106
"""
90107
Construct a new ExtensionArray from a sequence of scalars.
91108
"""
92-
if isinstance(dtype, ArrowDtype):
93-
pa_dtype = dtype.pyarrow_dtype
94-
elif dtype:
95-
pa_dtype = pa.from_numpy_dtype(dtype)
96-
else:
97-
pa_dtype = None
98-
109+
pa_dtype = to_pyarrow_type(dtype)
99110
if isinstance(scalars, cls):
100111
data = scalars._data
101112
if pa_dtype:
@@ -113,7 +124,40 @@ def _from_sequence_of_strings(
113124
"""
114125
Construct a new ExtensionArray from a sequence of strings.
115126
"""
116-
return cls._from_sequence(strings, dtype=dtype, copy=copy)
127+
pa_type = to_pyarrow_type(dtype)
128+
if pa.types.is_timestamp(pa_type):
129+
from pandas.core.tools.datetimes import to_datetime
130+
131+
scalars = to_datetime(strings, errors="raise")
132+
elif pa.types.is_date(pa_type):
133+
from pandas.core.tools.datetimes import to_datetime
134+
135+
scalars = to_datetime(strings, errors="raise").date
136+
elif pa.types.is_duration(pa_type):
137+
from pandas.core.tools.timedeltas import to_timedelta
138+
139+
scalars = to_timedelta(strings, errors="raise")
140+
elif pa.types.is_time(pa_type):
141+
from pandas.core.tools.times import to_time
142+
143+
# "coerce" to allow "null times" (None) to not raise
144+
scalars = to_time(strings, errors="coerce")
145+
elif pa.types.is_boolean(pa_type):
146+
from pandas.core.arrays import BooleanArray
147+
148+
scalars = BooleanArray._from_sequence_of_strings(strings).to_numpy()
149+
elif (
150+
pa.types.is_integer(pa_type)
151+
or pa.types.is_floating(pa_type)
152+
or pa.types.is_decimal(pa_type)
153+
):
154+
from pandas.core.tools.numeric import to_numeric
155+
156+
scalars = to_numeric(strings, errors="raise")
157+
else:
158+
# Let pyarrow try to infer or raise
159+
scalars = strings
160+
return cls._from_sequence(scalars, dtype=pa_type, copy=copy)
117161

118162
def __getitem__(self, item: PositionalIndexer):
119163
"""Select a subset of self.

pandas/tests/extension/test_arrow.py

+22-4
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ def test_setitem_loc_iloc_slice(self, data, using_array_manager, request):
712712
if pa_version_under2p0 and tz not in (None, "UTC"):
713713
request.node.add_marker(
714714
pytest.mark.xfail(
715-
reason=(f"Not supported by pyarrow < 2.0 with timestamp type {tz}")
715+
reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
716716
)
717717
)
718718
elif using_array_manager and pa.types.is_duration(data.dtype.pyarrow_dtype):
@@ -728,7 +728,7 @@ def test_setitem_slice_array(self, data, request):
728728
if pa_version_under2p0 and tz not in (None, "UTC"):
729729
request.node.add_marker(
730730
pytest.mark.xfail(
731-
reason=(f"Not supported by pyarrow < 2.0 with timestamp type {tz}")
731+
reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
732732
)
733733
)
734734
super().test_setitem_slice_array(data)
@@ -742,7 +742,7 @@ def test_setitem_with_expansion_dataframe_column(
742742
if pa_version_under2p0 and tz not in (None, "UTC") and not is_null_slice:
743743
request.node.add_marker(
744744
pytest.mark.xfail(
745-
reason=(f"Not supported by pyarrow < 2.0 with timestamp type {tz}")
745+
reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
746746
)
747747
)
748748
elif (
@@ -780,7 +780,7 @@ def test_setitem_frame_2d_values(self, data, using_array_manager, request):
780780
if pa_version_under2p0 and tz not in (None, "UTC"):
781781
request.node.add_marker(
782782
pytest.mark.xfail(
783-
reason=(f"Not supported by pyarrow < 2.0 with timestamp type {tz}")
783+
reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
784784
)
785785
)
786786
elif using_array_manager and pa.types.is_duration(data.dtype.pyarrow_dtype):
@@ -796,6 +796,24 @@ def test_setitem_preserves_views(self, data):
796796
super().test_setitem_preserves_views(data)
797797

798798

799+
class TestBaseParsing(base.BaseParsingTests):
    """Parsing tests inherited from ``base.BaseParsingTests``."""

    @pytest.mark.parametrize("engine", ["c", "python"])
    def test_EA_types(self, engine, data, request):
        """
        Delegate to the parent ``test_EA_types``, marking some types as xfail.

        Boolean pyarrow types are expected to fail with ``TypeError``
        (GH 47534); timestamp types carrying a timezone are expected to fail
        with ``NotImplementedError``.
        """
        pa_type = data.dtype.pyarrow_dtype

        # Pick the applicable xfail marker (if any) first, then attach it once.
        xfail_marker = None
        if pa.types.is_boolean(pa_type):
            xfail_marker = pytest.mark.xfail(raises=TypeError, reason="GH 47534")
        elif pa.types.is_timestamp(pa_type) and pa_type.tz is not None:
            xfail_marker = pytest.mark.xfail(
                raises=NotImplementedError,
                reason=f"Parameterized types with tz={pa_type.tz} not supported.",
            )

        if xfail_marker is not None:
            request.node.add_marker(xfail_marker)

        super().test_EA_types(engine, data)
815+
816+
799817
def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
800818
with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
801819
ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]")

0 commit comments

Comments
 (0)