Skip to content

Commit 8498cf1

Browse files
authored
More tests for the dataframe protocol (#49)
1 parent 60bedfe commit 8498cf1

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

protocol/pandas_implementation.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -769,16 +769,51 @@ def get_chunks(self, n_chunks : Optional[int] = None) -> Iterable['_PandasDataFr
769769
# Roundtrip testing
770770
# -----------------
771771

772+
def assert_buffer_equal(buffer_dtype: Tuple[_PandasBuffer, Any], pdcol: pd.Series):
    """Check the DLPack-related behaviour of a column's data buffer.

    Parameters
    ----------
    buffer_dtype
        Two-element tuple ``(buffer, dtype)``; it is unpacked here, so any
        other length fails with ``ValueError``.
    pdcol
        The pandas Series the buffer was produced from. Currently unused
        because the dtype comparisons are disabled (see TODO below).
    """
    buffer, _dtype = buffer_dtype

    # Calling ``__dlpack__`` is expected to raise NotImplementedError.
    with pytest.raises(NotImplementedError):
        buffer.__dlpack__()

    assert buffer.__dlpack_device__() == (1, None)

    # TODO: compare ``_dtype[1]`` (bit width) with ``itemsize * 8`` and
    # ``_dtype[2]`` with ``dtype.str`` of the pandas data. These checks are
    # disabled because it seems that `bitwidth` is handled differently for
    # `int` and `category`; a categorical column would presumably need to be
    # compared against ``pdcol.values.codes`` rather than ``pdcol`` itself.
786+
787+
788+
def assert_column_equal(col: _PandasColumn, pdcol: pd.Series):
    """Check that a protocol column agrees with the pandas Series behind it.

    Asserts matching size, a zero offset, a null count equal to
    ``pdcol.isnull().sum()`` and exactly one chunk, then verifies the
    validity-buffer behaviour and the data buffer.
    """
    assert col.size == pdcol.size
    assert col.offset == 0
    assert col.null_count == pdcol.isnull().sum()
    assert col.num_chunks() == 1

    kind = col.dtype[0]
    if kind != _DtypeKind.STRING:
        # Non-string columns are expected to raise RuntimeError when asked
        # for a validity buffer.
        with pytest.raises(RuntimeError):
            col._get_validity_buffer()

    # NOTE(review): assumes the data-buffer check runs for every column,
    # string columns included -- confirm against the repository layout.
    assert_buffer_equal(col._get_data_buffer(), pdcol)
796+
797+
def assert_dataframe_equal(dfo: DataFrameObject, df: pd.DataFrame):
    """Check that a protocol dataframe object matches a pandas DataFrame.

    Compares the column count, row count, chunk count (must be 1) and the
    column names, then checks each column with ``assert_column_equal``.
    """
    expected_names = list(df.columns)

    assert dfo.num_columns() == len(expected_names)
    assert dfo.num_rows() == len(df)
    assert dfo.num_chunks() == 1
    assert dfo.column_names() == expected_names

    for name in expected_names:
        assert_column_equal(dfo.get_column_by_name(name), df[name])
804+
772805
def test_float_only():
    """Convert a float-only frame with ``from_dataframe`` and compare it to the input."""
    columns = {"a": [1.5, 2.5, 3.5], "b": [9.2, 10.5, 11.8]}
    df = pd.DataFrame(data=columns)

    roundtripped = from_dataframe(df)

    assert_dataframe_equal(df.__dataframe__(), df)
    tm.assert_frame_equal(df, roundtripped)
776810

777811

778812
def test_mixed_intfloat():
    """Convert a frame mixing int and float columns with ``from_dataframe`` and compare it to the input."""
    columns = {
        "a": [1, 2, 3],
        "b": [3, 4, 5],
        "c": [1.5, 2.5, 3.5],
        "d": [9, 10, 11],
    }
    df = pd.DataFrame(data=columns)

    roundtripped = from_dataframe(df)

    assert_dataframe_equal(df.__dataframe__(), df)
    tm.assert_frame_equal(df, roundtripped)
783818

784819

@@ -787,6 +822,7 @@ def test_noncontiguous_columns():
787822
df = pd.DataFrame(arr, columns=['a', 'b', 'c'])
788823
assert df['a'].to_numpy().strides == (24,)
789824
df2 = from_dataframe(df) # uses default of allow_copy=True
825+
assert_dataframe_equal(df.__dataframe__(), df)
790826
tm.assert_frame_equal(df, df2)
791827

792828
with pytest.raises(RuntimeError):
@@ -807,6 +843,7 @@ def test_categorical_dtype():
807843
assert col.describe_categorical == (False, True, {0: 1, 1: 2, 2: 5})
808844

809845
df2 = from_dataframe(df)
846+
assert_dataframe_equal(df.__dataframe__(), df)
810847
tm.assert_frame_equal(df, df2)
811848

812849

@@ -822,6 +859,8 @@ def test_string_dtype():
822859
assert col.describe_null == (4, 0)
823860
assert col.num_chunks() == 1
824861

862+
assert_dataframe_equal(df.__dataframe__(), df)
863+
825864
def test_metadata():
826865
df = pd.DataFrame({'A': [1, 2, 3, 4],'B': [1, 2, 3, 4]})
827866

@@ -838,6 +877,7 @@ def test_metadata():
838877
assert col_metadata[key] == expected[key]
839878

840879
df2 = from_dataframe(df)
880+
assert_dataframe_equal(df.__dataframe__(), df)
841881
tm.assert_frame_equal(df, df2)
842882

843883

0 commit comments

Comments
 (0)