From 0887dd21e2305b16f24f9c180b0e9e74a0bea86c Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 12 Feb 2020 16:04:31 +0100 Subject: [PATCH] Backport PR #31918: BUG: fix parquet roundtrip with unsigned integer dtypes --- doc/source/whatsnew/v1.0.2.rst | 2 ++ pandas/core/arrays/integer.py | 4 ++++ pandas/tests/arrays/test_integer.py | 17 +++++++++++++++-- pandas/tests/io/test_parquet.py | 13 ++++++++----- 4 files changed, 29 insertions(+), 7 deletions(-) diff --git a/doc/source/whatsnew/v1.0.2.rst b/doc/source/whatsnew/v1.0.2.rst index 6d99668684a3b..44125ee30911f 100644 --- a/doc/source/whatsnew/v1.0.2.rst +++ b/doc/source/whatsnew/v1.0.2.rst @@ -33,6 +33,8 @@ Bug fixes **I/O** - Using ``pd.NA`` with :meth:`DataFrame.to_json` now correctly outputs a null value instead of an empty object (:issue:`31615`) +- Fixed bug in parquet roundtrip with nullable unsigned integer dtypes (:issue:`31896`). + **Experimental dtypes** diff --git a/pandas/core/arrays/integer.py b/pandas/core/arrays/integer.py index 9a0f5794e7607..96fdd8ee3c679 100644 --- a/pandas/core/arrays/integer.py +++ b/pandas/core/arrays/integer.py @@ -94,6 +94,10 @@ def __from_arrow__(self, array): import pyarrow from pandas.core.arrays._arrow_utils import pyarrow_array_to_numpy_and_mask + pyarrow_type = pyarrow.from_numpy_dtype(self.type) + if not array.type.equals(pyarrow_type): + array = array.cast(pyarrow_type) + if isinstance(array, pyarrow.Array): chunks = [array] else: diff --git a/pandas/tests/arrays/test_integer.py b/pandas/tests/arrays/test_integer.py index 857b793e9e9a8..2a6b6718cc149 100644 --- a/pandas/tests/arrays/test_integer.py +++ b/pandas/tests/arrays/test_integer.py @@ -1016,9 +1016,9 @@ def test_arrow_array(data): assert arr.equals(expected) -@td.skip_if_no("pyarrow", min_version="0.15.1.dev") +@td.skip_if_no("pyarrow", min_version="0.16.0") def test_arrow_roundtrip(data): - # roundtrip possible from arrow 1.0.0 + # roundtrip possible from arrow 0.16.0 import pyarrow as pa df = pd.DataFrame({"a": data}) @@ -1028,6 +1028,19 @@ def test_arrow_roundtrip(data): tm.assert_frame_equal(result, df) +@td.skip_if_no("pyarrow", min_version="0.16.0") +def test_arrow_from_arrow_uint(): + # https://github.com/pandas-dev/pandas/issues/31896 + # possible mismatch in types + import pyarrow as pa + + dtype = pd.UInt32Dtype() + result = dtype.__from_arrow__(pa.array([1, 2, 3, 4, None], type="int64")) + expected = pd.array([1, 2, 3, 4, None], dtype="UInt32") + + tm.assert_extension_array_equal(result, expected) + + @pytest.mark.parametrize( "pandasmethname, kwargs", [ diff --git a/pandas/tests/io/test_parquet.py b/pandas/tests/io/test_parquet.py index d51c712ed5abd..7bcc354f53be0 100644 --- a/pandas/tests/io/test_parquet.py +++ b/pandas/tests/io/test_parquet.py @@ -533,25 +533,28 @@ def test_additional_extension_arrays(self, pa): df = pd.DataFrame( { "a": pd.Series([1, 2, 3], dtype="Int64"), - "b": pd.Series(["a", None, "c"], dtype="string"), + "b": pd.Series([1, 2, 3], dtype="UInt32"), + "c": pd.Series(["a", None, "c"], dtype="string"), } ) - if LooseVersion(pyarrow.__version__) >= LooseVersion("0.15.1.dev"): + if LooseVersion(pyarrow.__version__) >= LooseVersion("0.16.0"): expected = df else: # de-serialized as plain int / object - expected = df.assign(a=df.a.astype("int64"), b=df.b.astype("object")) + expected = df.assign( + a=df.a.astype("int64"), b=df.b.astype("int64"), c=df.c.astype("object") + ) check_round_trip(df, pa, expected=expected) df = pd.DataFrame({"a": pd.Series([1, 2, 3, None], dtype="Int64")}) - if LooseVersion(pyarrow.__version__) >= LooseVersion("0.15.1.dev"): + if LooseVersion(pyarrow.__version__) >= LooseVersion("0.16.0"): expected = df else: # if missing values in integer, currently de-serialized as float expected = df.assign(a=df.a.astype("float64")) check_round_trip(df, pa, expected=expected) - @td.skip_if_no("pyarrow", min_version="0.15.1.dev") + @td.skip_if_no("pyarrow", min_version="0.16.0") def test_additional_extension_types(self, pa): # test additional ExtensionArrays that are supported through the # __arrow_array__ protocol + by defining a custom ExtensionType