Skip to content

Commit 9b0d3bc

Browse files
committed
fixup tests!
1 parent ff5069d commit 9b0d3bc

File tree

4 files changed

+74
-6
lines changed

4 files changed

+74
-6
lines changed

pandas/core/interchange/buffer.py

+1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def __init__(
108108
self._buffer = arr.buffers()[1]
109109
self._length = len(arr)
110110
self._dlpack = arr.__dlpack__
111+
self._is_validity = is_validity
111112

112113
@property
113114
def bufsize(self) -> int:

pandas/core/interchange/column.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ def describe_null(self):
198198
null_value = 1
199199
return column_null_dtype, null_value
200200
if isinstance(self._col.dtype, ArrowDtype):
201+
if all(
202+
chunk.buffers()[0] is None for chunk in self._col.array._pa_array.chunks
203+
):
204+
return ColumnNullType.NON_NULLABLE, None
201205
return ColumnNullType.USE_BITMASK, 0
202206
kind = self.dtype[0]
203207
try:
@@ -364,7 +368,7 @@ def _get_validity_buffer(self) -> tuple[PandasBuffer, Any]:
364368
arr = self._col.array
365369
dtype = (DtypeKind.BOOL, 1, ArrowCTypes.BOOL, Endianness.NATIVE)
366370
if all(chunk.buffers()[0] is None for chunk in arr._pa_array.chunks):
367-
return None, dtype
371+
return None
368372
buffer = PandasBufferPyarrow(
369373
arr._pa_array, is_validity=True, allow_copy=self._allow_copy
370374
)

pandas/core/interchange/from_dataframe.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -298,9 +298,9 @@ def string_column_to_ndarray(col: Column) -> tuple[np.ndarray, Any]:
298298

299299
null_pos = None
300300
if null_kind in (ColumnNullType.USE_BITMASK, ColumnNullType.USE_BYTEMASK):
301-
assert buffers["validity"], "Validity buffers cannot be empty for masks"
302-
valid_buff, valid_dtype = buffers["validity"]
303-
if valid_buff is not None:
301+
validity = buffers["validity"]
302+
if validity is not None:
303+
valid_buff, valid_dtype = validity
304304
null_pos = buffer_to_ndarray(
305305
valid_buff, valid_dtype, offset=col.offset, length=col.size()
306306
)
@@ -520,13 +520,14 @@ def set_nulls(
520520
np.ndarray or pd.Series
521521
Data with the nulls being set.
522522
"""
523+
if validity is None:
524+
return data
523525
null_kind, sentinel_val = col.describe_null
524526
null_pos = None
525527

526528
if null_kind == ColumnNullType.USE_SENTINEL:
527529
null_pos = pd.Series(data) == sentinel_val
528530
elif null_kind in (ColumnNullType.USE_BITMASK, ColumnNullType.USE_BYTEMASK):
529-
assert validity, "Expected to have a validity buffer for the mask"
530531
valid_buff, valid_dtype = validity
531532
if valid_buff is not None:
532533
null_pos = buffer_to_ndarray(

pandas/tests/interchange/test_impl.py

+63-1
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ def test_non_str_names_w_duplicates():
463463
),
464464
],
465465
)
466-
def test_pandas_nullable_w_missing_values(
466+
def test_pandas_nullable_with_missing_values(
467467
data: list, dtype: str, expected_dtype: str
468468
) -> None:
469469
# https://github.com/pandas-dev/pandas/issues/57643
@@ -481,6 +481,68 @@ def test_pandas_nullable_w_missing_values(
481481
assert result[2].as_py() is None
482482

483483

484+
@pytest.mark.parametrize(
485+
("data", "dtype", "expected_dtype"),
486+
[
487+
([1, 2, 3], "Int64", "int64"),
488+
([1, 2, 3], "Int64[pyarrow]", "int64"),
489+
([1, 2, 3], "Int8", "int8"),
490+
([1, 2, 3], "Int8[pyarrow]", "int8"),
491+
(
492+
[1, 2, 3],
493+
"UInt64",
494+
"uint64",
495+
),
496+
(
497+
[1, 2, 3],
498+
"UInt64[pyarrow]",
499+
"uint64",
500+
),
501+
([1.0, 2.25, 5.0], "Float32", "float32"),
502+
([1.0, 2.25, 5.0], "Float32[pyarrow]", "float32"),
503+
([True, False, False], "boolean", "bool"),
504+
([True, False, False], "boolean[pyarrow]", "bool"),
505+
(["much ado", "about", "nothing"], "string[pyarrow_numpy]", "large_string"),
506+
(["much ado", "about", "nothing"], "string[pyarrow]", "large_string"),
507+
(
508+
[datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 3)],
509+
"timestamp[ns][pyarrow]",
510+
"timestamp[ns]",
511+
),
512+
(
513+
[datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 3)],
514+
"timestamp[us][pyarrow]",
515+
"timestamp[us]",
516+
),
517+
(
518+
[
519+
datetime(2020, 1, 1, tzinfo=timezone.utc),
520+
datetime(2020, 1, 2, tzinfo=timezone.utc),
521+
datetime(2020, 1, 3, tzinfo=timezone.utc),
522+
],
523+
"timestamp[us, Asia/Kathmandu][pyarrow]",
524+
"timestamp[us, tz=Asia/Kathmandu]",
525+
),
526+
],
527+
)
528+
def test_pandas_nullable_without_missing_values(
529+
data: list, dtype: str, expected_dtype: str
530+
) -> None:
531+
# https://github.com/pandas-dev/pandas/issues/57643
532+
pa = pytest.importorskip("pyarrow", "11.0.0")
533+
import pyarrow.interchange as pai
534+
535+
if expected_dtype == "timestamp[us, tz=Asia/Kathmandu]":
536+
expected_dtype = pa.timestamp("us", "Asia/Kathmandu")
537+
538+
df = pd.DataFrame({"a": data}, dtype=dtype)
539+
result = pai.from_dataframe(df.__dataframe__())["a"]
540+
assert result.type == expected_dtype
541+
assert result[0].as_py() == data[0]
542+
assert result[1].as_py() == data[1]
543+
assert result[2].as_py() == data[2]
544+
545+
484546
def test_empty_dataframe():
485547
# https://github.com/pandas-dev/pandas/issues/56700
486548
df = pd.DataFrame({"a": []}, dtype="int8")

0 commit comments

Comments
 (0)