@@ -769,16 +769,51 @@ def get_chunks(self, n_chunks : Optional[int] = None) -> Iterable['_PandasDataFr
769
769
# Roundtrip testing
770
770
# -----------------
771
771
772
def assert_buffer_equal(buffer_dtype: Tuple[_PandasBuffer, Any], pdcol: pd.Series):
    """Check the basic invariants of a ``(buffer, dtype)`` pair.

    Parameters
    ----------
    buffer_dtype : tuple
        Two-element tuple holding the buffer object and its dtype
        description.
    pdcol : pd.Series
        The pandas column the buffer was produced from.  Currently unused,
        because the dtype comparisons below are disabled.
    """
    buf, _dtype = buffer_dtype

    # The buffer is expected to refuse DLPack export ...
    with pytest.raises(NotImplementedError):
        buf.__dlpack__()

    # ... while still reporting this fixed device tuple.
    device = buf.__dlpack_device__()
    assert device == (1, None)

    # TODO: the dtype checks are disabled for now.  It seems that `bitwidth`
    # is handled differently for `int` and `category`, so comparing
    # ``dtype[1]`` with ``itemsize * 8`` and ``dtype[2]`` with ``dtype.str``
    # of the pandas column (or of ``pdcol.values.codes`` for categoricals)
    # does not hold for every column yet.
786
+
787
+
788
def assert_column_equal(col: _PandasColumn, pdcol: pd.Series):
    """Compare an interchange column object against its pandas source.

    Parameters
    ----------
    col : _PandasColumn
        Column object under test.
    pdcol : pd.Series
        The pandas column ``col`` is expected to describe.
    """
    # Size, offset, null count and chunking must agree with the Series.
    assert col.size == pdcol.size
    assert col.offset == 0
    expected_nulls = pdcol.isnull().sum()
    assert col.null_count == expected_nulls
    assert col.num_chunks() == 1

    # Every kind except STRING is expected to raise when asked for a
    # validity buffer.
    is_string = col.dtype[0] == _DtypeKind.STRING
    if not is_string:
        with pytest.raises(RuntimeError):
            col._get_validity_buffer()

    data_buffer = col._get_data_buffer()
    assert_buffer_equal(data_buffer, pdcol)
796
+
797
def assert_dataframe_equal(dfo: DataFrameObject, df: pd.DataFrame):
    """Compare an interchange dataframe object against its pandas source.

    Parameters
    ----------
    dfo : DataFrameObject
        Dataframe object under test.
    df : pd.DataFrame
        The pandas frame ``dfo`` is expected to describe.
    """
    labels = list(df.columns)

    # Shape, chunking and column labels first ...
    assert dfo.num_columns() == len(labels)
    assert dfo.num_rows() == len(df)
    assert dfo.num_chunks() == 1
    assert dfo.column_names() == labels

    # ... then every column individually.
    for label in labels:
        assert_column_equal(dfo.get_column_by_name(label), df[label])
804
+
772
805
def test_float_only():
    """Round-trip a frame that holds only float columns via ``from_dataframe``."""
    frame = pd.DataFrame({"a": [1.5, 2.5, 3.5], "b": [9.2, 10.5, 11.8]})
    roundtripped = from_dataframe(frame)
    assert_dataframe_equal(frame.__dataframe__(), frame)
    tm.assert_frame_equal(frame, roundtripped)
776
810
777
811
778
812
def test_mixed_intfloat():
    """Round-trip a frame mixing int and float columns via ``from_dataframe``."""
    columns = {
        "a": [1, 2, 3],
        "b": [3, 4, 5],
        "c": [1.5, 2.5, 3.5],
        "d": [9, 10, 11],
    }
    frame = pd.DataFrame(columns)
    roundtripped = from_dataframe(frame)
    assert_dataframe_equal(frame.__dataframe__(), frame)
    tm.assert_frame_equal(frame, roundtripped)
783
818
784
819
@@ -787,6 +822,7 @@ def test_noncontiguous_columns():
787
822
df = pd .DataFrame (arr , columns = ['a' , 'b' , 'c' ])
788
823
assert df ['a' ].to_numpy ().strides == (24 ,)
789
824
df2 = from_dataframe (df ) # uses default of allow_copy=True
825
+ assert_dataframe_equal (df .__dataframe__ (), df )
790
826
tm .assert_frame_equal (df , df2 )
791
827
792
828
with pytest .raises (RuntimeError ):
@@ -807,6 +843,7 @@ def test_categorical_dtype():
807
843
assert col .describe_categorical == (False , True , {0 : 1 , 1 : 2 , 2 : 5 })
808
844
809
845
df2 = from_dataframe (df )
846
+ assert_dataframe_equal (df .__dataframe__ (), df )
810
847
tm .assert_frame_equal (df , df2 )
811
848
812
849
@@ -822,6 +859,8 @@ def test_string_dtype():
822
859
assert col .describe_null == (4 , 0 )
823
860
assert col .num_chunks () == 1
824
861
862
+ assert_dataframe_equal (df .__dataframe__ (), df )
863
+
825
864
def test_metadata ():
826
865
df = pd .DataFrame ({'A' : [1 , 2 , 3 , 4 ],'B' : [1 , 2 , 3 , 4 ]})
827
866
@@ -838,6 +877,7 @@ def test_metadata():
838
877
assert col_metadata [key ] == expected [key ]
839
878
840
879
df2 = from_dataframe (df )
880
+ assert_dataframe_equal (df .__dataframe__ (), df )
841
881
tm .assert_frame_equal (df , df2 )
842
882
843
883
0 commit comments