From 16ec4e667c473c966cc616a7961f5af748e67b57 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 20 May 2020 15:11:39 +0200 Subject: [PATCH] ENH: fix arrow roundtrip for ExtensionDtypes in absence of pandas metadata --- pandas/core/arrays/_arrow_utils.py | 10 ++++++++ pandas/tests/arrays/interval/test_interval.py | 23 +++++++++++++++++++ pandas/tests/arrays/test_period.py | 18 +++++++++++++++ 3 files changed, 51 insertions(+) diff --git a/pandas/core/arrays/_arrow_utils.py b/pandas/core/arrays/_arrow_utils.py index e0d33bebeb421..4a33e0e841f7f 100644 --- a/pandas/core/arrays/_arrow_utils.py +++ b/pandas/core/arrays/_arrow_utils.py @@ -70,6 +70,11 @@ def __eq__(self, other): def __hash__(self): return hash((str(self), self.freq)) + def to_pandas_dtype(self): + import pandas as pd + + return pd.PeriodDtype(freq=self.freq) + # register the type with a dummy instance _period_type = ArrowPeriodType("D") pyarrow.register_extension_type(_period_type) @@ -119,6 +124,11 @@ def __eq__(self, other): def __hash__(self): return hash((str(self), str(self.subtype), self.closed)) + def to_pandas_dtype(self): + import pandas as pd + + return pd.IntervalDtype(self.subtype.to_pandas_dtype()) + # register the type with a dummy instance _interval_type = ArrowIntervalType(pyarrow.int64(), "left") pyarrow.register_extension_type(_interval_type) diff --git a/pandas/tests/arrays/interval/test_interval.py b/pandas/tests/arrays/interval/test_interval.py index fef11f0ff3bb2..d517eaaec68d2 100644 --- a/pandas/tests/arrays/interval/test_interval.py +++ b/pandas/tests/arrays/interval/test_interval.py @@ -237,3 +237,26 @@ def test_arrow_table_roundtrip(breaks): result = table2.to_pandas() expected = pd.concat([df, df], ignore_index=True) tm.assert_frame_equal(result, expected) + + +@pyarrow_skip +@pytest.mark.parametrize( + "breaks", + [[0.0, 1.0, 2.0, 3.0], pd.date_range("2017", periods=4, freq="D")], + ids=["float", "datetime64[ns]"], +) +def test_arrow_table_roundtrip_without_metadata(breaks): + import pyarrow as pa + + arr = IntervalArray.from_breaks(breaks) + arr[1] = None + df = pd.DataFrame({"a": arr}) + + table = pa.table(df) + # remove the metadata + table = table.replace_schema_metadata() + assert table.schema.metadata is None + + result = table.to_pandas() + assert isinstance(result["a"].dtype, pd.IntervalDtype) + tm.assert_frame_equal(result, df) diff --git a/pandas/tests/arrays/test_period.py b/pandas/tests/arrays/test_period.py index d3ced2f1b1f07..27e6334788284 100644 --- a/pandas/tests/arrays/test_period.py +++ b/pandas/tests/arrays/test_period.py @@ -414,3 +414,21 @@ def test_arrow_table_roundtrip(): result = table2.to_pandas() expected = pd.concat([df, df], ignore_index=True) tm.assert_frame_equal(result, expected) + + +@pyarrow_skip +def test_arrow_table_roundtrip_without_metadata(): + import pyarrow as pa + + arr = PeriodArray([1, 2, 3], freq="H") + arr[1] = pd.NaT + df = pd.DataFrame({"a": arr}) + + table = pa.table(df) + # remove the metadata + table = table.replace_schema_metadata() + assert table.schema.metadata is None + + result = table.to_pandas() + assert isinstance(result["a"].dtype, PeriodDtype) + tm.assert_frame_equal(result, df)