@@ -17,6 +17,7 @@
from pandas.compat.pyarrow import (
    pa_version_under1p0,
    pa_version_under2p0,
+    pa_version_under5p0,
)
import pandas.util._test_decorators as td

@@ -222,6 +223,29 @@ def compare(repeat):
        compare(repeat)


+def check_partition_names(path, expected):
+    """Check partitions of a parquet file are as expected.
+
+    Parameters
+    ----------
+    path: str
+        Path of the dataset.
+    expected: iterable of str
+        Expected partition names.
+    """
+    if pa_version_under5p0:
+        import pyarrow.parquet as pq
+
+        dataset = pq.ParquetDataset(path, validate_schema=False)
+        assert len(dataset.partitions.partition_names) == len(expected)
+        assert dataset.partitions.partition_names == set(expected)
+    else:
+        import pyarrow.dataset as ds
+
+        dataset = ds.dataset(path, partitioning="hive")
+        assert dataset.partitioning.schema.names == expected
+
+
def test_invalid_engine(df_compat):
    msg = "engine must be one of 'pyarrow', 'fastparquet'"
    with pytest.raises(ValueError, match=msg):
@@ -743,11 +767,7 @@ def test_partition_cols_supported(self, pa, df_full):
        df = df_full
        with tm.ensure_clean_dir() as path:
            df.to_parquet(path, partition_cols=partition_cols, compression=None)
-            import pyarrow.parquet as pq
-
-            dataset = pq.ParquetDataset(path, validate_schema=False)
-            assert len(dataset.partitions.partition_names) == 2
-            assert dataset.partitions.partition_names == set(partition_cols)
+            check_partition_names(path, partition_cols)
            assert read_parquet(path).shape == df.shape

    def test_partition_cols_string(self, pa, df_full):
@@ -757,11 +777,7 @@ def test_partition_cols_string(self, pa, df_full):
        df = df_full
        with tm.ensure_clean_dir() as path:
            df.to_parquet(path, partition_cols=partition_cols, compression=None)
-            import pyarrow.parquet as pq
-
-            dataset = pq.ParquetDataset(path, validate_schema=False)
-            assert len(dataset.partitions.partition_names) == 1
-            assert dataset.partitions.partition_names == set(partition_cols_list)
+            check_partition_names(path, partition_cols_list)
            assert read_parquet(path).shape == df.shape

    @pytest.mark.parametrize("path_type", [str, pathlib.Path])
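
For reference, a minimal usage sketch of the new helper (not part of the diff; the example DataFrame and the "partitioned_dataset" output directory are hypothetical, and pyarrow is assumed to be installed):

import pandas as pd

# Write a dataset hive-partitioned on one column; to_parquet creates one
# subdirectory per partition value (part=a/, part=b/).
df = pd.DataFrame({"part": ["a", "b"], "value": [1, 2]})
df.to_parquet("partitioned_dataset", partition_cols=["part"], compression=None)

# Both branches of check_partition_names accept this: the legacy
# ParquetDataset branch (pyarrow < 5.0) compares partition names as a set,
# while the pyarrow.dataset branch reads the hive partitioning schema and
# compares the names as a list.
check_partition_names("partitioned_dataset", ["part"])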