diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index e57ba92267855..0024ec0210696 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -265,6 +265,7 @@ MultiIndex I/O ^^^ - Bug in :func:`read_sas` caused fragmentation of :class:`DataFrame` and raised :class:`.errors.PerformanceWarning` (:issue:`48595`) +- Bug when pickling a subset of PyArrow-backed data that would serialize the entire data instead of the subset (:issue:`42600`) - Bug in :func:`read_csv` for a single-line csv with fewer columns than ``names`` raised :class:`.errors.ParserError` with ``engine="c"`` (:issue:`47566`) - diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index ea33bf58bacda..f18664915d015 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -351,6 +351,17 @@ def __pos__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT: def __abs__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT: return type(self)(pc.abs_checked(self._data)) + # GH 42600: __getstate__/__setstate__ not necessary once + # https://issues.apache.org/jira/browse/ARROW-10739 is addressed + def __getstate__(self): + state = self.__dict__.copy() + state["_data"] = self._data.combine_chunks() + return state + + def __setstate__(self, state) -> None: + state["_data"] = pa.chunked_array(state["_data"]) + self.__dict__.update(state) + def _cmp_method(self, other, op): from pandas.arrays import BooleanArray diff --git a/pandas/tests/arrays/string_/test_string_arrow.py b/pandas/tests/arrays/string_/test_string_arrow.py index 8a6c2b0586a0c..4f0c4daa3c64f 100644 --- a/pandas/tests/arrays/string_/test_string_arrow.py +++ b/pandas/tests/arrays/string_/test_string_arrow.py @@ -1,3 +1,4 @@ +import pickle import re import numpy as np @@ -197,3 +198,20 @@ def test_setitem_invalid_indexer_raises(): with pytest.raises(ValueError, match=None): arr[[0, 1]] = ["foo", "bar", "baz"] + + +@skip_if_no_pyarrow +def 
test_pickle_roundtrip(): + # GH 42600 + expected = pd.Series(range(10), dtype="string[pyarrow]") + expected_sliced = expected.head(2) + full_pickled = pickle.dumps(expected) + sliced_pickled = pickle.dumps(expected_sliced) + + assert len(full_pickled) > len(sliced_pickled) + + result = pickle.loads(full_pickled) + tm.assert_series_equal(result, expected) + + result_sliced = pickle.loads(sliced_pickled) + tm.assert_series_equal(result_sliced, expected_sliced) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 9a6b24583c525..8979c145a223c 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -20,6 +20,7 @@ BytesIO, StringIO, ) +import pickle import numpy as np import pytest @@ -1347,3 +1348,19 @@ def test_is_bool_dtype(): result = s[data] expected = s[np.asarray(data)] tm.assert_series_equal(result, expected) + + +def test_pickle_roundtrip(data): + # GH 42600 + expected = pd.Series(data) + expected_sliced = expected.head(2) + full_pickled = pickle.dumps(expected) + sliced_pickled = pickle.dumps(expected_sliced) + + assert len(full_pickled) > len(sliced_pickled) + + result = pickle.loads(full_pickled) + tm.assert_series_equal(result, expected) + + result_sliced = pickle.loads(sliced_pickled) + tm.assert_series_equal(result_sliced, expected_sliced)