From 974a4507a7314160c9c597c6ae1bdc568c19a006 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 6 Sep 2022 13:39:50 -0700 Subject: [PATCH 1/2] Backport PR #48264: BUG: ArrowExtensionArray._from_* accepts pyarrow arrays --- pandas/core/arrays/arrow/array.py | 20 +++--- pandas/core/tools/times.py | 25 ++++---- pandas/tests/extension/test_arrow.py | 94 ++++++++++++++++++++++++++++ pandas/tests/tools/test_to_time.py | 7 ++- 4 files changed, 126 insertions(+), 20 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 1f7939011a1f1..cfae5b4cae681 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -224,11 +224,13 @@ def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy=False): Construct a new ExtensionArray from a sequence of scalars. """ pa_dtype = to_pyarrow_type(dtype) - if isinstance(scalars, cls): - data = scalars._data + is_cls = isinstance(scalars, cls) + if is_cls or isinstance(scalars, (pa.Array, pa.ChunkedArray)): + if is_cls: + scalars = scalars._data if pa_dtype: - data = data.cast(pa_dtype) - return cls(data) + scalars = scalars.cast(pa_dtype) + return cls(scalars) else: return cls( pa.chunked_array(pa.array(scalars, type=pa_dtype, from_pandas=True)) @@ -242,7 +244,10 @@ def _from_sequence_of_strings( Construct a new ExtensionArray from a sequence of strings. """ pa_type = to_pyarrow_type(dtype) - if pa.types.is_timestamp(pa_type): + if pa_type is None: + # Let pyarrow try to infer or raise + scalars = strings + elif pa.types.is_timestamp(pa_type): from pandas.core.tools.datetimes import to_datetime scalars = to_datetime(strings, errors="raise") @@ -272,8 +277,9 @@ def _from_sequence_of_strings( scalars = to_numeric(strings, errors="raise") else: - # Let pyarrow try to infer or raise - scalars = strings + raise NotImplementedError( + f"Converting strings to {pa_type} is not implemented." + ) return cls._from_sequence(scalars, dtype=pa_type, copy=copy) def __getitem__(self, item: PositionalIndexer): diff --git a/pandas/core/tools/times.py b/pandas/core/tools/times.py index 030cee3f678f4..87667921bf75a 100644 --- a/pandas/core/tools/times.py +++ b/pandas/core/tools/times.py @@ -80,17 +80,20 @@ def _convert_listlike(arg, format): format_found = False for element in arg: time_object = None - for time_format in formats: - try: - time_object = datetime.strptime(element, time_format).time() - if not format_found: - # Put the found format in front - fmt = formats.pop(formats.index(time_format)) - formats.insert(0, fmt) - format_found = True - break - except (ValueError, TypeError): - continue + try: + time_object = time.fromisoformat(element) + except (ValueError, TypeError): + for time_format in formats: + try: + time_object = datetime.strptime(element, time_format).time() + if not format_found: + # Put the found format in front + fmt = formats.pop(formats.index(time_format)) + formats.insert(0, fmt) + format_found = True + break + except (ValueError, TypeError): + continue if time_object is not None: times.append(time_object) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 43c52ef8848e2..d853effab6ea3 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -21,6 +21,8 @@ import pytest from pandas.compat import ( + is_ci_environment, + is_platform_windows, pa_version_under2p0, pa_version_under3p0, pa_version_under4p0, @@ -35,6 +37,8 @@ pa = pytest.importorskip("pyarrow", minversion="1.0.1") +from pandas.core.arrays.arrow.array import ArrowExtensionArray + from pandas.core.arrays.arrow.dtype import ArrowDtype # isort:skip @@ -222,6 +226,96 @@ def test_from_dtype(self, data, request): ) super().test_from_dtype(data) + def test_from_sequence_pa_array(self, data, request): + # https://github.com/pandas-dev/pandas/pull/47034#discussion_r955500784 + # data._data = pa.ChunkedArray + if pa_version_under3p0: + request.node.add_marker( + pytest.mark.xfail( + reason="ChunkedArray has no attribute combine_chunks", + ) + ) + result = type(data)._from_sequence(data._data) + tm.assert_extension_array_equal(result, data) + assert isinstance(result._data, pa.ChunkedArray) + + result = type(data)._from_sequence(data._data.combine_chunks()) + tm.assert_extension_array_equal(result, data) + assert isinstance(result._data, pa.ChunkedArray) + + def test_from_sequence_pa_array_notimplemented(self, request): + if pa_version_under6p0: + request.node.add_marker( + pytest.mark.xfail( + raises=AttributeError, + reason="month_day_nano_interval not implemented by pyarrow.", + ) + ) + with pytest.raises(NotImplementedError, match="Converting strings to"): + ArrowExtensionArray._from_sequence_of_strings( + ["12-1"], dtype=pa.month_day_nano_interval() + ) + + def test_from_sequence_of_strings_pa_array(self, data, request): + pa_dtype = data.dtype.pyarrow_dtype + if pa_version_under3p0: + request.node.add_marker( + pytest.mark.xfail( + reason="ChunkedArray has no attribute combine_chunks", + ) + ) + elif pa.types.is_time64(pa_dtype) and pa_dtype.equals("time64[ns]"): + request.node.add_marker( + pytest.mark.xfail( + reason="Nanosecond time parsing not supported.", + ) + ) + elif pa.types.is_duration(pa_dtype): + request.node.add_marker( + pytest.mark.xfail( + raises=pa.ArrowNotImplementedError, + reason=f"pyarrow doesn't support parsing {pa_dtype}", + ) + ) + elif pa.types.is_boolean(pa_dtype): + request.node.add_marker( + pytest.mark.xfail( + reason="Iterating over ChunkedArray[bool] returns PyArrow scalars.", + ) + ) + elif pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None: + if pa_version_under7p0: + request.node.add_marker( + pytest.mark.xfail( + raises=pa.ArrowNotImplementedError, + reason=f"pyarrow doesn't support string cast from {pa_dtype}", + ) + ) + elif is_platform_windows() and is_ci_environment(): + request.node.add_marker( + pytest.mark.xfail( + raises=pa.ArrowInvalid, + reason=( + "TODO: Set ARROW_TIMEZONE_DATABASE environment variable " + "on CI to path to the tzdata for pyarrow." + ), + ) + ) + elif pa_version_under6p0 and pa.types.is_temporal(pa_dtype): + request.node.add_marker( + pytest.mark.xfail( + raises=pa.ArrowNotImplementedError, + reason=f"pyarrow doesn't support string cast from {pa_dtype}", + ) + ) + pa_array = data._data.cast(pa.string()) + result = type(data)._from_sequence_of_strings(pa_array, dtype=data.dtype) + tm.assert_extension_array_equal(result, data) + + pa_array = pa_array.combine_chunks() + result = type(data)._from_sequence_of_strings(pa_array, dtype=data.dtype) + tm.assert_extension_array_equal(result, data) + @pytest.mark.xfail( raises=NotImplementedError, reason="pyarrow.ChunkedArray backing is 1D." diff --git a/pandas/tests/tools/test_to_time.py b/pandas/tests/tools/test_to_time.py index a8316e0f3970c..c80b1e080a1d1 100644 --- a/pandas/tests/tools/test_to_time.py +++ b/pandas/tests/tools/test_to_time.py @@ -4,6 +4,8 @@ import numpy as np import pytest +from pandas.compat import PY311 + from pandas import Series import pandas._testing as tm from pandas.core.tools.datetimes import to_time as to_time_alias @@ -40,8 +42,9 @@ def test_parsers_time(self, time_string): def test_odd_format(self): new_string = "14.15" msg = r"Cannot convert arg \['14\.15'\] to a time" - with pytest.raises(ValueError, match=msg): - to_time(new_string) + if not PY311: + with pytest.raises(ValueError, match=msg): + to_time(new_string) assert to_time(new_string, format="%H.%M") == time(14, 15) def test_arraylike(self): From 54084408099b3dfd946b9e7f7d28b00aa82d131a Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 6 Sep 2022 15:06:46 -0700 Subject: [PATCH 2/2] Add missing import --- pandas/tests/extension/test_arrow.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index d853effab6ea3..9100b67edbe69 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -27,6 +27,7 @@ pa_version_under3p0, pa_version_under4p0, pa_version_under6p0, + pa_version_under7p0, pa_version_under8p0, pa_version_under9p0, )