Skip to content

Commit dd9fda3

Browse files
MarcoGorelli authored and
pmhatre1 committed
BUG: Interchange object data buffer has the wrong dtype / from_dataframe incorrect (pandas-dev#57570)
string
1 parent bf60b12 commit dd9fda3

File tree

4 files changed

+68
-11
lines changed

4 files changed

+68
-11
lines changed

.pre-commit-config.yaml

-2
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,11 @@ repos:
5353
- repo: https://github.com/pre-commit/pre-commit-hooks
5454
rev: v4.5.0
5555
hooks:
56-
- id: check-ast
5756
- id: check-case-conflict
5857
- id: check-toml
5958
- id: check-xml
6059
- id: check-yaml
6160
exclude: ^ci/meta.yaml$
62-
- id: debug-statements
6361
- id: end-of-file-fixer
6462
exclude: \.txt$
6563
- id: mixed-line-ending

doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ Other
289289
- Bug in :func:`tseries.api.guess_datetime_format` would fail to infer time format when "%Y" == "%H%M" (:issue:`57452`)
290290
- Bug in :meth:`DataFrame.sort_index` when passing ``axis="columns"`` and ``ignore_index=True`` and ``ascending=False`` not returning a :class:`RangeIndex` columns (:issue:`57293`)
291291
- Bug in :meth:`DataFrame.where` where using a non-bool type array in the function would return a ``ValueError`` instead of a ``TypeError`` (:issue:`56330`)
292+
- Bug in the DataFrame Interchange Protocol implementation where the dtype associated with the data buffers of string and datetime columns was reported incorrectly (:issue:`54781`)
292293

293294
.. ***DO NOT USE THIS SECTION***
294295

pandas/core/interchange/column.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -278,20 +278,28 @@ def _get_data_buffer(
278278
"""
279279
Return the buffer containing the data and the buffer's associated dtype.
280280
"""
281-
if self.dtype[0] in (
282-
DtypeKind.INT,
283-
DtypeKind.UINT,
284-
DtypeKind.FLOAT,
285-
DtypeKind.BOOL,
286-
DtypeKind.DATETIME,
287-
):
281+
if self.dtype[0] == DtypeKind.DATETIME:
288282
# self.dtype[2] is an ArrowCTypes.TIMESTAMP where the tz will make
289283
# it longer than 4 characters
290-
if self.dtype[0] == DtypeKind.DATETIME and len(self.dtype[2]) > 4:
284+
if len(self.dtype[2]) > 4:
291285
np_arr = self._col.dt.tz_convert(None).to_numpy()
292286
else:
293287
np_arr = self._col.to_numpy()
294288
buffer = PandasBuffer(np_arr, allow_copy=self._allow_copy)
289+
dtype = (
290+
DtypeKind.INT,
291+
64,
292+
ArrowCTypes.INT64,
293+
Endianness.NATIVE,
294+
)
295+
elif self.dtype[0] in (
296+
DtypeKind.INT,
297+
DtypeKind.UINT,
298+
DtypeKind.FLOAT,
299+
DtypeKind.BOOL,
300+
):
301+
np_arr = self._col.to_numpy()
302+
buffer = PandasBuffer(np_arr, allow_copy=self._allow_copy)
295303
dtype = self.dtype
296304
elif self.dtype[0] == DtypeKind.CATEGORICAL:
297305
codes = self._col.values._codes
@@ -314,7 +322,12 @@ def _get_data_buffer(
314322
# Define the dtype for the returned buffer
315323
# TODO: this will need correcting
316324
# https://github.com/pandas-dev/pandas/issues/54781
317-
dtype = self.dtype
325+
dtype = (
326+
DtypeKind.UINT,
327+
8,
328+
ArrowCTypes.UINT8,
329+
Endianness.NATIVE,
330+
) # note: currently only support native endianness
318331
else:
319332
raise NotImplementedError(f"Data type {self._col.dtype} not handled yet")
320333

pandas/tests/interchange/test_impl.py

+45
Original file line numberDiff line numberDiff line change
@@ -435,3 +435,48 @@ def test_empty_dataframe():
435435
result = pd.api.interchange.from_dataframe(dfi, allow_copy=False)
436436
expected = pd.DataFrame({"a": []}, dtype="int8")
437437
tm.assert_frame_equal(result, expected)
438+
439+
440+
@pytest.mark.parametrize(
441+
("data", "expected_dtype", "expected_buffer_dtype"),
442+
[
443+
(
444+
pd.Series(["a", "b", "a"], dtype="category"),
445+
(DtypeKind.CATEGORICAL, 8, "c", "="),
446+
(DtypeKind.INT, 8, "c", "|"),
447+
),
448+
(
449+
pd.Series(
450+
[datetime(2022, 1, 1), datetime(2022, 1, 2), datetime(2022, 1, 3)]
451+
),
452+
(DtypeKind.DATETIME, 64, "tsn:", "="),
453+
(DtypeKind.INT, 64, ArrowCTypes.INT64, "="),
454+
),
455+
(
456+
pd.Series(["a", "bc", None]),
457+
(DtypeKind.STRING, 8, ArrowCTypes.STRING, "="),
458+
(DtypeKind.UINT, 8, ArrowCTypes.UINT8, "="),
459+
),
460+
(
461+
pd.Series([1, 2, 3]),
462+
(DtypeKind.INT, 64, ArrowCTypes.INT64, "="),
463+
(DtypeKind.INT, 64, ArrowCTypes.INT64, "="),
464+
),
465+
(
466+
pd.Series([1.5, 2, 3]),
467+
(DtypeKind.FLOAT, 64, ArrowCTypes.FLOAT64, "="),
468+
(DtypeKind.FLOAT, 64, ArrowCTypes.FLOAT64, "="),
469+
),
470+
],
471+
)
472+
def test_buffer_dtype_categorical(
473+
data: pd.Series,
474+
expected_dtype: tuple[DtypeKind, int, str, str],
475+
expected_buffer_dtype: tuple[DtypeKind, int, str, str],
476+
) -> None:
477+
# https://github.com/pandas-dev/pandas/issues/54781
478+
df = pd.DataFrame({"data": data})
479+
dfi = df.__dataframe__()
480+
col = dfi.get_column_by_name("data")
481+
assert col.dtype == expected_dtype
482+
assert col.get_buffers()["data"][1] == expected_buffer_dtype

0 commit comments

Comments
 (0)