@@ -110,48 +110,79 @@ def df_full():
                       pd.Timestamp('20130103')]})
 
 
-def test_invalid_engine(df_compat):
+def check_round_trip(df, engine=None, path=None,
+                     write_kwargs=None, read_kwargs=None,
+                     expected=None, check_names=True,
+                     repeat=2):
+    """Verify parquet serializer and deserializer produce the same results.
+
+    Performs a pandas to disk and disk to pandas round trip,
+    then compares the two resulting DataFrames to verify equality.
+
+    Parameters
+    ----------
+    df: DataFrame
+    engine: str, optional
+        'pyarrow' or 'fastparquet'
+    path: str, optional
+    write_kwargs: dict of str:str, optional
+    read_kwargs: dict of str:str, optional
+    expected: DataFrame, optional
+        Expected deserialization result, otherwise will be equal to `df`
+    check_names: bool, optional
+        Whether index names are compared, default True
+    repeat: int, optional
+        How many times to repeat the test
+    """
+
+    write_kwargs = write_kwargs or {'compression': None}
+    read_kwargs = read_kwargs or {}
+
+    if expected is None:
+        expected = df
+
+    if engine:
+        write_kwargs['engine'] = engine
+        read_kwargs['engine'] = engine
+
+    def compare(repeat):
+        for _ in range(repeat):
+            df.to_parquet(path, **write_kwargs)
+            actual = read_parquet(path, **read_kwargs)
+            tm.assert_frame_equal(expected, actual,
+                                  check_names=check_names)
+
+    if path is None:
+        with tm.ensure_clean() as path:
+            compare(repeat)
+    else:
+        compare(repeat)
 
+
+def test_invalid_engine(df_compat):
     with pytest.raises(ValueError):
-        df_compat.to_parquet('foo', 'bar')
+        check_round_trip(df_compat, 'foo', 'bar')
 
 
 def test_options_py(df_compat, pa):
     # use the set option
 
-    df = df_compat
-    with tm.ensure_clean() as path:
-
-        with pd.option_context('io.parquet.engine', 'pyarrow'):
-            df.to_parquet(path)
-
-            result = read_parquet(path)
-            tm.assert_frame_equal(result, df)
+    with pd.option_context('io.parquet.engine', 'pyarrow'):
+        check_round_trip(df_compat)
 
 
 def test_options_fp(df_compat, fp):
     # use the set option
 
-    df = df_compat
-    with tm.ensure_clean() as path:
-
-        with pd.option_context('io.parquet.engine', 'fastparquet'):
-            df.to_parquet(path, compression=None)
-
-            result = read_parquet(path)
-            tm.assert_frame_equal(result, df)
+    with pd.option_context('io.parquet.engine', 'fastparquet'):
+        check_round_trip(df_compat)
 
 
 def test_options_auto(df_compat, fp, pa):
+    # use the set option
 
-    df = df_compat
-    with tm.ensure_clean() as path:
-
-        with pd.option_context('io.parquet.engine', 'auto'):
-            df.to_parquet(path)
-
-            result = read_parquet(path)
-            tm.assert_frame_equal(result, df)
+    with pd.option_context('io.parquet.engine', 'auto'):
+        check_round_trip(df_compat)
 
 
 def test_options_get_engine(fp, pa):
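
For reference, a minimal usage sketch (not part of the diff) of how the new module-level `check_round_trip` helper can be driven directly. It assumes the names this test module already has in scope (`pd`, `tm`, `read_parquet`, and the helper itself); the sample frame is a hypothetical stand-in for the `df_compat` fixture.

```python
# Hypothetical stand-in for the df_compat fixture.
df = pd.DataFrame({'a': [1, 2, 3], 'b': list('xyz')})

# Default call: writes to a tm.ensure_clean() temp file, reads it back,
# and asserts frame equality; the loop runs twice because repeat=2.
check_round_trip(df, engine='pyarrow')

# Column-pruned read: `expected` must match what the pruned read returns.
check_round_trip(df, engine='pyarrow',
                 read_kwargs={'columns': ['a']},
                 expected=df[['a']])
```

Note how a single `engine` argument is folded into both `write_kwargs` and `read_kwargs`, so the writer and reader are always exercised with the same backend.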
@@ -228,53 +259,23 @@ def check_error_on_write(self, df, engine, exc):
         with tm.ensure_clean() as path:
             to_parquet(df, path, engine, compression=None)
 
-    def check_round_trip(self, df, engine, expected=None, path=None,
-                         write_kwargs=None, read_kwargs=None,
-                         check_names=True):
-
-        if write_kwargs is None:
-            write_kwargs = {'compression': None}
-
-        if read_kwargs is None:
-            read_kwargs = {}
-
-        if expected is None:
-            expected = df
-
-        if path is None:
-            with tm.ensure_clean() as path:
-                check_round_trip_equals(df, path, engine,
-                                        write_kwargs=write_kwargs,
-                                        read_kwargs=read_kwargs,
-                                        expected=expected,
-                                        check_names=check_names)
-        else:
-            check_round_trip_equals(df, path, engine,
-                                    write_kwargs=write_kwargs,
-                                    read_kwargs=read_kwargs,
-                                    expected=expected,
-                                    check_names=check_names)
-
 
 class TestBasic(Base):
 
     def test_error(self, engine):
-
         for obj in [pd.Series([1, 2, 3]), 1, 'foo', pd.Timestamp('20130101'),
                     np.array([1, 2, 3])]:
             self.check_error_on_write(obj, engine, ValueError)
 
     def test_columns_dtypes(self, engine):
-
         df = pd.DataFrame({'string': list('abc'),
                            'int': list(range(1, 4))})
 
         # unicode
         df.columns = [u'foo', u'bar']
-        self.check_round_trip(df, engine)
+        check_round_trip(df, engine)
 
     def test_columns_dtypes_invalid(self, engine):
-
         df = pd.DataFrame({'string': list('abc'),
                            'int': list(range(1, 4))})
@@ -302,17 +303,16 @@ def test_compression(self, engine, compression):
             pytest.importorskip('brotli')
 
         df = pd.DataFrame({'A': [1, 2, 3]})
-        self.check_round_trip(df, engine,
-                              write_kwargs={'compression': compression})
+        check_round_trip(df, engine, write_kwargs={'compression': compression})
 
     def test_read_columns(self, engine):
         # GH18154
         df = pd.DataFrame({'string': list('abc'),
                            'int': list(range(1, 4))})
 
         expected = pd.DataFrame({'string': list('abc')})
-        self.check_round_trip(df, engine, expected=expected,
-                              read_kwargs={'columns': ['string']})
+        check_round_trip(df, engine, expected=expected,
+                         read_kwargs={'columns': ['string']})
 
     def test_write_index(self, engine):
         check_names = engine != 'fastparquet'
@@ -323,7 +323,7 @@ def test_write_index(self, engine):
             pytest.skip("pyarrow is < 0.7.0")
 
         df = pd.DataFrame({'A': [1, 2, 3]})
-        self.check_round_trip(df, engine)
+        check_round_trip(df, engine)
 
         indexes = [
             [2, 3, 4],
@@ -334,12 +334,12 @@ def test_write_index(self, engine):
         # non-default index
         for index in indexes:
             df.index = index
-            self.check_round_trip(df, engine, check_names=check_names)
+            check_round_trip(df, engine, check_names=check_names)
 
         # index with meta-data
         df.index = [0, 1, 2]
         df.index.name = 'foo'
-        self.check_round_trip(df, engine)
+        check_round_trip(df, engine)
 
     def test_write_multiindex(self, pa_ge_070):
         # Not supported in fastparquet as of 0.1.3 or older pyarrow version
@@ -348,7 +348,7 @@ def test_write_multiindex(self, pa_ge_070):
         df = pd.DataFrame({'A': [1, 2, 3]})
         index = pd.MultiIndex.from_tuples([('a', 1), ('a', 2), ('b', 1)])
         df.index = index
-        self.check_round_trip(df, engine)
+        check_round_trip(df, engine)
 
     def test_write_column_multiindex(self, engine):
         # column multi-index
@@ -357,7 +357,6 @@ def test_write_column_multiindex(self, engine):
         self.check_error_on_write(df, engine, ValueError)
 
     def test_multiindex_with_columns(self, pa_ge_070):
-
         engine = pa_ge_070
         dates = pd.date_range('01-Jan-2018', '01-Dec-2018', freq='MS')
         df = pd.DataFrame(np.random.randn(2 * len(dates), 3),
@@ -368,14 +367,10 @@ def test_multiindex_with_columns(self, pa_ge_070):
         index2 = index1.copy(names=None)
         for index in [index1, index2]:
             df.index = index
-            with tm.ensure_clean() as path:
-                df.to_parquet(path, engine)
-                result = read_parquet(path, engine)
-                expected = df
-                tm.assert_frame_equal(result, expected)
-                result = read_parquet(path, engine, columns=['A', 'B'])
-                expected = df[['A', 'B']]
-                tm.assert_frame_equal(result, expected)
+
+            check_round_trip(df, engine)
+            check_round_trip(df, engine, read_kwargs={'columns': ['A', 'B']},
+                             expected=df[['A', 'B']])
 
 
 class TestParquetPyArrow(Base):
@@ -391,7 +386,7 @@ def test_basic(self, pa, df_full):
                                            tz='Europe/Brussels')
         df['bool_with_none'] = [True, None, True]
 
-        self.check_round_trip(df, pa)
+        check_round_trip(df, pa)
 
     @pytest.mark.xfail(reason="pyarrow fails on this (ARROW-1883)")
     def test_basic_subset_columns(self, pa, df_full):
@@ -402,8 +397,8 @@ def test_basic_subset_columns(self, pa, df_full):
         df['datetime_tz'] = pd.date_range('20130101', periods=3,
                                           tz='Europe/Brussels')
 
-        self.check_round_trip(df, pa, expected=df[['string', 'int']],
-                              read_kwargs={'columns': ['string', 'int']})
+        check_round_trip(df, pa, expected=df[['string', 'int']],
+                         read_kwargs={'columns': ['string', 'int']})
 
     def test_duplicate_columns(self, pa):
         # not currently able to handle duplicate columns
@@ -433,7 +428,7 @@ def test_categorical(self, pa_ge_070):
 
         # de-serialized as object
         expected = df.assign(a=df.a.astype(object))
-        self.check_round_trip(df, pa, expected)
+        check_round_trip(df, pa, expected=expected)
 
     def test_categorical_unsupported(self, pa_lt_070):
         pa = pa_lt_070
@@ -444,20 +439,19 @@ def test_categorical_unsupported(self, pa_lt_070):
 
     def test_s3_roundtrip(self, df_compat, s3_resource, pa):
         # GH #19134
-        self.check_round_trip(df_compat, pa,
-                              path='s3://pandas-test/pyarrow.parquet')
+        check_round_trip(df_compat, pa,
+                         path='s3://pandas-test/pyarrow.parquet')
 
 
 class TestParquetFastParquet(Base):
 
     def test_basic(self, fp, df_full):
-
         df = df_full
 
         # additional supported types for fastparquet
         df['timedelta'] = pd.timedelta_range('1 day', periods=3)
 
-        self.check_round_trip(df, fp)
+        check_round_trip(df, fp)
 
     @pytest.mark.skip(reason="not supported")
     def test_duplicate_columns(self, fp):
@@ -470,7 +464,7 @@ def test_duplicate_columns(self, fp):
     def test_bool_with_none(self, fp):
         df = pd.DataFrame({'a': [True, None, False]})
         expected = pd.DataFrame({'a': [1.0, np.nan, 0.0]}, dtype='float16')
-        self.check_round_trip(df, fp, expected=expected)
+        check_round_trip(df, fp, expected=expected)
 
     def test_unsupported(self, fp):
 
@@ -486,7 +480,7 @@ def test_categorical(self, fp):
         if LooseVersion(fastparquet.__version__) < LooseVersion("0.1.3"):
             pytest.skip("CategoricalDtype not supported for older fp")
         df = pd.DataFrame({'a': pd.Categorical(list('abc'))})
-        self.check_round_trip(df, fp)
+        check_round_trip(df, fp)
 
     def test_datetime_tz(self, fp):
         # doesn't preserve tz
@@ -495,7 +489,7 @@ def test_datetime_tz(self, fp):
 
         # warns on the coercion
         with catch_warnings(record=True):
-            self.check_round_trip(df, fp, df.astype('datetime64[ns]'))
+            check_round_trip(df, fp, expected=df.astype('datetime64[ns]'))
 
     def test_filter_row_groups(self, fp):
         d = {'a': list(range(0, 3))}
@@ -508,5 +502,5 @@ def test_filter_row_groups(self, fp):
 
     def test_s3_roundtrip(self, df_compat, s3_resource, fp):
         # GH #19134
-        self.check_round_trip(df_compat, fp,
-                              path='s3://pandas-test/fastparquet.parquet')
+        check_round_trip(df_compat, fp,
+                         path='s3://pandas-test/fastparquet.parquet')
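
A closing design note: when `path` is supplied explicitly, as in the two s3 round-trip tests, the helper skips the `tm.ensure_clean()` branch and writes straight to the given location. A hedged sketch of that call shape, assuming the s3fs-backed `pandas-test` bucket named in the diff:

```python
# Sketch: an explicit path bypasses the temp-file branch, so the same
# helper exercises remote targets; pandas resolves s3:// URLs only when
# s3fs is installed, and df_compat here is the fixture from this module.
check_round_trip(df_compat, 'fastparquet',
                 path='s3://pandas-test/fastparquet.parquet')
```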