From 33b86adb3fb0b63f6142d04c19ad4b2217ca8c6f Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Thu, 13 Oct 2022 16:06:24 -0700 Subject: [PATCH 1/2] BUG: pickling subset of Arrow-backed data would serialize the entire data --- doc/source/whatsnew/v2.0.0.rst | 2 +- pandas/core/arrays/arrow/array.py | 11 +++++++++++ .../tests/arrays/string_/test_string_arrow.py | 18 ++++++++++++++++++ pandas/tests/extension/test_arrow.py | 17 +++++++++++++++++ 4 files changed, 47 insertions(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index 9d1e0c7485092..7d93ce3a58c1b 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -240,7 +240,7 @@ MultiIndex I/O ^^^ - Bug in :func:`read_sas` caused fragmentation of :class:`DataFrame` and raised :class:`.errors.PerformanceWarning` (:issue:`48595`) -- +- Bug where pickling a subset of PyArrow-backed data would serialize the entire data instead of the subset (:issue:`42600`) Period ^^^^^^ diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index f6f933b1b9917..da7b90935beb2 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -373,6 +373,17 @@ def __pos__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT: def __abs__(self: ArrowExtensionArrayT) -> ArrowExtensionArrayT: return type(self)(pc.abs_checked(self._data)) + # GH 42600: __getstate__/__setstate__ not necessary once + # https://issues.apache.org/jira/browse/ARROW-10739 is addressed + def __getstate__(self): + state = self.__dict__.copy() + state["_data"] = self._data.combine_chunks() + return state + + def __setstate__(self, state) -> None: + state["_data"] = pa.chunked_array(state["_data"]) + self.__dict__.update(state) + def _cmp_method(self, other, op): from pandas.arrays import BooleanArray diff --git a/pandas/tests/arrays/string_/test_string_arrow.py 
b/pandas/tests/arrays/string_/test_string_arrow.py index f43cf298857a0..ce11892394d20 100644 --- a/pandas/tests/arrays/string_/test_string_arrow.py +++ b/pandas/tests/arrays/string_/test_string_arrow.py @@ -1,3 +1,4 @@ +import pickle import re import numpy as np @@ -197,3 +198,20 @@ def test_setitem_invalid_indexer_raises(): with pytest.raises(ValueError, match=None): arr[[0, 1]] = ["foo", "bar", "baz"] + + +@skip_if_no_pyarrow +def test_pickle_roundtrip(): + # GH 42600 + expected = pd.Series(range(10), dtype="string[pyarrow]") + expected_sliced = expected.head(2) + full_pickled = pickle.dumps(expected) + sliced_pickled = pickle.dumps(expected_sliced) + + assert len(full_pickled) > len(sliced_pickled) + + result = pickle.loads(full_pickled) + tm.assert_series_equal(result, expected) + + result_sliced = pickle.loads(sliced_pickled) + tm.assert_series_equal(result_sliced, expected_sliced) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 92e4dbaea4eea..0136e0ff0da5e 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -16,6 +16,7 @@ time, timedelta, ) +import pickle import numpy as np import pytest @@ -1765,3 +1766,19 @@ def test_is_bool_dtype(): result = s[data] expected = s[np.asarray(data)] tm.assert_series_equal(result, expected) + + +def test_pickle_roundtrip(data_missing): + # GH 42600 + expected = pd.Series(data_missing) + expected_sliced = expected.head(2) + full_pickled = pickle.dumps(expected) + sliced_pickled = pickle.dumps(expected_sliced) + + assert len(full_pickled) > len(sliced_pickled) + + result = pickle.loads(full_pickled) + tm.assert_series_equal(result, expected) + + result_sliced = pickle.loads(sliced_pickled) + tm.assert_series_equal(result_sliced, expected_sliced) From 34bac1e8c7d890e4163e7c11c9a91f75e8f5c27a Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Thu, 13 Oct 2022 16:08:32 -0700 Subject: [PATCH 
2/2] Use data --- pandas/tests/extension/test_arrow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 0136e0ff0da5e..57c97889cef64 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -1768,9 +1768,9 @@ def test_is_bool_dtype(): tm.assert_series_equal(result, expected) -def test_pickle_roundtrip(data_missing): +def test_pickle_roundtrip(data): # GH 42600 - expected = pd.Series(data_missing) + expected = pd.Series(data) expected_sliced = expected.head(2) full_pickled = pickle.dumps(expected) sliced_pickled = pickle.dumps(expected_sliced)