Skip to content

Commit a0d3b44

Browse files
author
MarcoGorelli
committed
round-trip categorical pyarrow
1 parent 11d75d8 commit a0d3b44

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

doc/source/whatsnew/v2.0.1.rst

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ Bug fixes
2626
~~~~~~~~~
2727
- Bug in :attr:`Series.dt.days` that would overflow ``int32`` number of days (:issue:`52391`)
2828
- Bug in :class:`arrays.DatetimeArray` constructor returning an incorrect unit when passed a non-nanosecond numpy datetime array (:issue:`52555`)
29+
- Bug in :func:`api.interchange.from_dataframe` was unnecessarily raising on categorical dtypes (:issue:`49889`)
2930
- Bug in :func:`pandas.testing.assert_series_equal` where ``check_dtype=False`` would still raise for datetime or timedelta types with different resolutions (:issue:`52449`)
3031
- Bug in :func:`read_csv` casting PyArrow datetimes to NumPy when ``dtype_backend="pyarrow"`` and ``parse_dates`` is set causing a performance bottleneck in the process (:issue:`52546`)
3132
- Bug in :func:`to_datetime` and :func:`to_timedelta` when trying to convert numeric data with a :class:`ArrowDtype` (:issue:`52425`)

pandas/core/interchange/from_dataframe.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import numpy as np
88

99
import pandas as pd
10-
from pandas.core.interchange.column import PandasColumn
1110
from pandas.core.interchange.dataframe_protocol import (
1211
Buffer,
1312
Column,
@@ -182,7 +181,7 @@ def categorical_column_to_series(col: Column) -> tuple[pd.Series, Any]:
182181

183182
cat_column = categorical["categories"]
184183
# for mypy/pyright
185-
assert isinstance(cat_column, PandasColumn), "categories must be a PandasColumn"
184+
assert hasattr(cat_column, "_col"), "categories must have a `.col` attribute"
186185
categories = np.array(cat_column._col)
187186
buffers = col.get_buffers()
188187

pandas/tests/interchange/test_impl.py

+14
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import random
33

44
import numpy as np
5+
import pyarrow as pa
56
import pytest
67

78
from pandas._libs.tslibs import iNaT
@@ -74,6 +75,19 @@ def test_categorical_dtype(data):
7475
tm.assert_frame_equal(df, from_dataframe(df.__dataframe__()))
7576

7677

78+
def test_categorical_pyarrow():
79+
# GH 49889
80+
arr = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", "Sat", "Sun"]
81+
table = pa.table({"weekday": pa.array(arr).dictionary_encode()})
82+
exchange_df = table.__dataframe__()
83+
result = from_dataframe(exchange_df)
84+
weekday = pd.Categorical(
85+
arr, categories=["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
86+
)
87+
expected = pd.DataFrame({"weekday": weekday})
88+
tm.assert_frame_equal(result, expected)
89+
90+
7791
@pytest.mark.parametrize(
7892
"data", [int_data, uint_data, float_data, bool_data, datetime_data]
7993
)

0 commit comments

Comments
 (0)