Skip to content

Commit 631e6a4

Browse files
committed
🏷️ typing
1 parent 458bd17 commit 631e6a4

File tree

2 files changed

+24
-9
lines changed

2 files changed

+24
-9
lines changed

pandas/core/interchange/column.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from __future__ import annotations
22

3-
from typing import Any
3+
from typing import (
4+
TYPE_CHECKING,
5+
Any,
6+
)
47

58
import numpy as np
69

@@ -33,6 +36,9 @@
3336
dtype_to_arrow_c_fmt,
3437
)
3538

39+
if TYPE_CHECKING:
40+
from pandas.core.interchange.dataframe_protocol import Buffer
41+
3642
_NP_KINDS = {
3743
"i": DtypeKind.INT,
3844
"u": DtypeKind.UINT,
@@ -296,7 +302,7 @@ def get_buffers(self) -> ColumnBuffers:
296302

297303
def _get_data_buffer(
298304
self,
299-
) -> tuple[PandasBuffer, Any]: # Any is for self.dtype tuple
305+
) -> tuple[Buffer, tuple[DtypeKind, int, str, str]]:
300306
"""
301307
Return the buffer containing the data and the buffer's associated dtype.
302308
"""
@@ -307,7 +313,7 @@ def _get_data_buffer(
307313
np_arr = self._col.dt.tz_convert(None).to_numpy()
308314
else:
309315
np_arr = self._col.to_numpy()
310-
buffer = PandasBuffer(np_arr, allow_copy=self._allow_copy)
316+
buffer: Buffer = PandasBuffer(np_arr, allow_copy=self._allow_copy)
311317
dtype = (
312318
DtypeKind.INT,
313319
64,
@@ -324,7 +330,9 @@ def _get_data_buffer(
324330
arr = self._col.array
325331
if isinstance(self._col.dtype, ArrowDtype):
326332
buffer = PandasBufferPyarrow(
327-
arr._pa_array, is_validity=False, allow_copy=self._allow_copy
333+
arr._pa_array, # type: ignore[attr-defined]
334+
is_validity=False,
335+
allow_copy=self._allow_copy,
328336
)
329337
if self.dtype[0] == DtypeKind.BOOL:
330338
dtype = (
@@ -371,7 +379,7 @@ def _get_data_buffer(
371379

372380
return buffer, dtype
373381

374-
def _get_validity_buffer(self) -> tuple[PandasBuffer, Any]:
382+
def _get_validity_buffer(self) -> tuple[Buffer, Any] | None:
375383
"""
376384
Return the buffer containing the mask values indicating missing data and
377385
the buffer's associated dtype.
@@ -382,10 +390,15 @@ def _get_validity_buffer(self) -> tuple[PandasBuffer, Any]:
382390
if isinstance(self._col.dtype, ArrowDtype):
383391
arr = self._col.array
384392
dtype = (DtypeKind.BOOL, 1, ArrowCTypes.BOOL, Endianness.NATIVE)
385-
if all(chunk.buffers()[0] is None for chunk in arr._pa_array.chunks):
393+
if all(
394+
chunk.buffers()[0] is None
395+
for chunk in arr._pa_array.chunks # type: ignore[attr-defined]
396+
):
386397
return None
387-
buffer = PandasBufferPyarrow(
388-
arr._pa_array, is_validity=True, allow_copy=self._allow_copy
398+
buffer: Buffer = PandasBufferPyarrow(
399+
arr._pa_array, # type: ignore[attr-defined]
400+
is_validity=True,
401+
allow_copy=self._allow_copy,
389402
)
390403
return buffer, dtype
391404

pandas/tests/interchange/test_impl.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,9 @@ def test_string_validity_buffer() -> None:
552552
def test_string_validity_buffer_no_missing() -> None:
553553
# https://github.com/pandas-dev/pandas/issues/57762
554554
df = pd.DataFrame({"a": ["x", None]}, dtype="large_string[pyarrow]")
555-
result = df.__dataframe__().get_column_by_name("a").get_buffers()["validity"][1]
555+
validity = df.__dataframe__().get_column_by_name("a").get_buffers()["validity"]
556+
assert validity is not None
557+
result = validity[1]
556558
expected = (DtypeKind.BOOL, 1, ArrowCTypes.BOOL, "=")
557559
assert result == expected
558560

0 commit comments

Comments
 (0)