Skip to content

Commit a9ad632

Browse files
ENH: fix arrow roundtrip for ExtensionDtypes in absence of pandas metadata (#34275)
1 parent d916188 commit a9ad632

File tree

3 files changed

+51
-0
lines changed

3 files changed

+51
-0
lines changed

pandas/core/arrays/_arrow_utils.py

+10
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ def __eq__(self, other):
7070
def __hash__(self):
7171
return hash((str(self), self.freq))
7272

73+
def to_pandas_dtype(self):
74+
import pandas as pd
75+
76+
return pd.PeriodDtype(freq=self.freq)
77+
7378
# register the type with a dummy instance
7479
_period_type = ArrowPeriodType("D")
7580
pyarrow.register_extension_type(_period_type)
@@ -119,6 +124,11 @@ def __eq__(self, other):
119124
def __hash__(self):
120125
return hash((str(self), str(self.subtype), self.closed))
121126

127+
def to_pandas_dtype(self):
128+
import pandas as pd
129+
130+
return pd.IntervalDtype(self.subtype.to_pandas_dtype())
131+
122132
# register the type with a dummy instance
123133
_interval_type = ArrowIntervalType(pyarrow.int64(), "left")
124134
pyarrow.register_extension_type(_interval_type)

pandas/tests/arrays/interval/test_interval.py

+23
Original file line numberDiff line numberDiff line change
@@ -237,3 +237,26 @@ def test_arrow_table_roundtrip(breaks):
237237
result = table2.to_pandas()
238238
expected = pd.concat([df, df], ignore_index=True)
239239
tm.assert_frame_equal(result, expected)
240+
241+
242+
@pyarrow_skip
243+
@pytest.mark.parametrize(
244+
"breaks",
245+
[[0.0, 1.0, 2.0, 3.0], pd.date_range("2017", periods=4, freq="D")],
246+
ids=["float", "datetime64[ns]"],
247+
)
248+
def test_arrow_table_roundtrip_without_metadata(breaks):
249+
import pyarrow as pa
250+
251+
arr = IntervalArray.from_breaks(breaks)
252+
arr[1] = None
253+
df = pd.DataFrame({"a": arr})
254+
255+
table = pa.table(df)
256+
# remove the metadata
257+
table = table.replace_schema_metadata()
258+
assert table.schema.metadata is None
259+
260+
result = table.to_pandas()
261+
assert isinstance(result["a"].dtype, pd.IntervalDtype)
262+
tm.assert_frame_equal(result, df)

pandas/tests/arrays/test_period.py

+18
Original file line numberDiff line numberDiff line change
@@ -414,3 +414,21 @@ def test_arrow_table_roundtrip():
414414
result = table2.to_pandas()
415415
expected = pd.concat([df, df], ignore_index=True)
416416
tm.assert_frame_equal(result, expected)
417+
418+
419+
@pyarrow_skip
420+
def test_arrow_table_roundtrip_without_metadata():
421+
import pyarrow as pa
422+
423+
arr = PeriodArray([1, 2, 3], freq="H")
424+
arr[1] = pd.NaT
425+
df = pd.DataFrame({"a": arr})
426+
427+
table = pa.table(df)
428+
# remove the metadata
429+
table = table.replace_schema_metadata()
430+
assert table.schema.metadata is None
431+
432+
result = table.to_pandas()
433+
assert isinstance(result["a"].dtype, PeriodDtype)
434+
tm.assert_frame_equal(result, df)

0 commit comments

Comments
 (0)