Skip to content

Commit 5fdb9c0

Browse files
maximveksler authored and jreback committed
Refactor test_parquet.py to use check_round_trip at module level (#19332)
1 parent 9fdac02 commit 5fdb9c0

File tree

1 file changed

+81
-87
lines changed

1 file changed

+81
-87
lines changed

pandas/tests/io/test_parquet.py

+81-87
Original file line number | Diff line number | Diff line change
@@ -110,48 +110,79 @@ def df_full():
110110
pd.Timestamp('20130103')]})
111111

112112

113-
def test_invalid_engine(df_compat):
113+
def check_round_trip(df, engine=None, path=None,
114+
write_kwargs=None, read_kwargs=None,
115+
expected=None, check_names=True,
116+
repeat=2):
117+
"""Verify parquet serializer and deserializer produce the same results.
118+
119+
Performs a pandas to disk and disk to pandas round trip,
120+
then compares the 2 resulting DataFrames to verify equality.
121+
122+
Parameters
123+
----------
124+
df: Dataframe
125+
engine: str, optional
126+
'pyarrow' or 'fastparquet'
127+
path: str, optional
128+
write_kwargs: dict of str:str, optional
129+
read_kwargs: dict of str:str, optional
130+
expected: DataFrame, optional
131+
Expected deserialization result, otherwise will be equal to `df`
132+
check_names: list of str, optional
133+
Closed set of column names to be compared
134+
repeat: int, optional
135+
How many times to repeat the test
136+
"""
137+
138+
write_kwargs = write_kwargs or {'compression': None}
139+
read_kwargs = read_kwargs or {}
140+
141+
if expected is None:
142+
expected = df
143+
144+
if engine:
145+
write_kwargs['engine'] = engine
146+
read_kwargs['engine'] = engine
147+
148+
def compare(repeat):
149+
for _ in range(repeat):
150+
df.to_parquet(path, **write_kwargs)
151+
actual = read_parquet(path, **read_kwargs)
152+
tm.assert_frame_equal(expected, actual,
153+
check_names=check_names)
154+
155+
if path is None:
156+
with tm.ensure_clean() as path:
157+
compare(repeat)
158+
else:
159+
compare(repeat)
114160

161+
162+
def test_invalid_engine(df_compat):
115163
with pytest.raises(ValueError):
116-
df_compat.to_parquet('foo', 'bar')
164+
check_round_trip(df_compat, 'foo', 'bar')
117165

118166

119167
def test_options_py(df_compat, pa):
120168
# use the set option
121169

122-
df = df_compat
123-
with tm.ensure_clean() as path:
124-
125-
with pd.option_context('io.parquet.engine', 'pyarrow'):
126-
df.to_parquet(path)
127-
128-
result = read_parquet(path)
129-
tm.assert_frame_equal(result, df)
170+
with pd.option_context('io.parquet.engine', 'pyarrow'):
171+
check_round_trip(df_compat)
130172

131173

132174
def test_options_fp(df_compat, fp):
133175
# use the set option
134176

135-
df = df_compat
136-
with tm.ensure_clean() as path:
137-
138-
with pd.option_context('io.parquet.engine', 'fastparquet'):
139-
df.to_parquet(path, compression=None)
140-
141-
result = read_parquet(path)
142-
tm.assert_frame_equal(result, df)
177+
with pd.option_context('io.parquet.engine', 'fastparquet'):
178+
check_round_trip(df_compat)
143179

144180

145181
def test_options_auto(df_compat, fp, pa):
182+
# use the set option
146183

147-
df = df_compat
148-
with tm.ensure_clean() as path:
149-
150-
with pd.option_context('io.parquet.engine', 'auto'):
151-
df.to_parquet(path)
152-
153-
result = read_parquet(path)
154-
tm.assert_frame_equal(result, df)
184+
with pd.option_context('io.parquet.engine', 'auto'):
185+
check_round_trip(df_compat)
155186

156187

157188
def test_options_get_engine(fp, pa):
@@ -228,53 +259,23 @@ def check_error_on_write(self, df, engine, exc):
228259
with tm.ensure_clean() as path:
229260
to_parquet(df, path, engine, compression=None)
230261

231-
def check_round_trip(self, df, engine, expected=None, path=None,
232-
write_kwargs=None, read_kwargs=None,
233-
check_names=True):
234-
235-
if write_kwargs is None:
236-
write_kwargs = {'compression': None}
237-
238-
if read_kwargs is None:
239-
read_kwargs = {}
240-
241-
if expected is None:
242-
expected = df
243-
244-
if path is None:
245-
with tm.ensure_clean() as path:
246-
check_round_trip_equals(df, path, engine,
247-
write_kwargs=write_kwargs,
248-
read_kwargs=read_kwargs,
249-
expected=expected,
250-
check_names=check_names)
251-
else:
252-
check_round_trip_equals(df, path, engine,
253-
write_kwargs=write_kwargs,
254-
read_kwargs=read_kwargs,
255-
expected=expected,
256-
check_names=check_names)
257-
258262

259263
class TestBasic(Base):
260264

261265
def test_error(self, engine):
262-
263266
for obj in [pd.Series([1, 2, 3]), 1, 'foo', pd.Timestamp('20130101'),
264267
np.array([1, 2, 3])]:
265268
self.check_error_on_write(obj, engine, ValueError)
266269

267270
def test_columns_dtypes(self, engine):
268-
269271
df = pd.DataFrame({'string': list('abc'),
270272
'int': list(range(1, 4))})
271273

272274
# unicode
273275
df.columns = [u'foo', u'bar']
274-
self.check_round_trip(df, engine)
276+
check_round_trip(df, engine)
275277

276278
def test_columns_dtypes_invalid(self, engine):
277-
278279
df = pd.DataFrame({'string': list('abc'),
279280
'int': list(range(1, 4))})
280281

@@ -302,17 +303,16 @@ def test_compression(self, engine, compression):
302303
pytest.importorskip('brotli')
303304

304305
df = pd.DataFrame({'A': [1, 2, 3]})
305-
self.check_round_trip(df, engine,
306-
write_kwargs={'compression': compression})
306+
check_round_trip(df, engine, write_kwargs={'compression': compression})
307307

308308
def test_read_columns(self, engine):
309309
# GH18154
310310
df = pd.DataFrame({'string': list('abc'),
311311
'int': list(range(1, 4))})
312312

313313
expected = pd.DataFrame({'string': list('abc')})
314-
self.check_round_trip(df, engine, expected=expected,
315-
read_kwargs={'columns': ['string']})
314+
check_round_trip(df, engine, expected=expected,
315+
read_kwargs={'columns': ['string']})
316316

317317
def test_write_index(self, engine):
318318
check_names = engine != 'fastparquet'
@@ -323,7 +323,7 @@ def test_write_index(self, engine):
323323
pytest.skip("pyarrow is < 0.7.0")
324324

325325
df = pd.DataFrame({'A': [1, 2, 3]})
326-
self.check_round_trip(df, engine)
326+
check_round_trip(df, engine)
327327

328328
indexes = [
329329
[2, 3, 4],
@@ -334,12 +334,12 @@ def test_write_index(self, engine):
334334
# non-default index
335335
for index in indexes:
336336
df.index = index
337-
self.check_round_trip(df, engine, check_names=check_names)
337+
check_round_trip(df, engine, check_names=check_names)
338338

339339
# index with meta-data
340340
df.index = [0, 1, 2]
341341
df.index.name = 'foo'
342-
self.check_round_trip(df, engine)
342+
check_round_trip(df, engine)
343343

344344
def test_write_multiindex(self, pa_ge_070):
345345
# Not supported in fastparquet as of 0.1.3 or older pyarrow version
@@ -348,7 +348,7 @@ def test_write_multiindex(self, pa_ge_070):
348348
df = pd.DataFrame({'A': [1, 2, 3]})
349349
index = pd.MultiIndex.from_tuples([('a', 1), ('a', 2), ('b', 1)])
350350
df.index = index
351-
self.check_round_trip(df, engine)
351+
check_round_trip(df, engine)
352352

353353
def test_write_column_multiindex(self, engine):
354354
# column multi-index
@@ -357,7 +357,6 @@ def test_write_column_multiindex(self, engine):
357357
self.check_error_on_write(df, engine, ValueError)
358358

359359
def test_multiindex_with_columns(self, pa_ge_070):
360-
361360
engine = pa_ge_070
362361
dates = pd.date_range('01-Jan-2018', '01-Dec-2018', freq='MS')
363362
df = pd.DataFrame(np.random.randn(2 * len(dates), 3),
@@ -368,14 +367,10 @@ def test_multiindex_with_columns(self, pa_ge_070):
368367
index2 = index1.copy(names=None)
369368
for index in [index1, index2]:
370369
df.index = index
371-
with tm.ensure_clean() as path:
372-
df.to_parquet(path, engine)
373-
result = read_parquet(path, engine)
374-
expected = df
375-
tm.assert_frame_equal(result, expected)
376-
result = read_parquet(path, engine, columns=['A', 'B'])
377-
expected = df[['A', 'B']]
378-
tm.assert_frame_equal(result, expected)
370+
371+
check_round_trip(df, engine)
372+
check_round_trip(df, engine, read_kwargs={'columns': ['A', 'B']},
373+
expected=df[['A', 'B']])
379374

380375

381376
class TestParquetPyArrow(Base):
@@ -391,7 +386,7 @@ def test_basic(self, pa, df_full):
391386
tz='Europe/Brussels')
392387
df['bool_with_none'] = [True, None, True]
393388

394-
self.check_round_trip(df, pa)
389+
check_round_trip(df, pa)
395390

396391
@pytest.mark.xfail(reason="pyarrow fails on this (ARROW-1883)")
397392
def test_basic_subset_columns(self, pa, df_full):
@@ -402,8 +397,8 @@ def test_basic_subset_columns(self, pa, df_full):
402397
df['datetime_tz'] = pd.date_range('20130101', periods=3,
403398
tz='Europe/Brussels')
404399

405-
self.check_round_trip(df, pa, expected=df[['string', 'int']],
406-
read_kwargs={'columns': ['string', 'int']})
400+
check_round_trip(df, pa, expected=df[['string', 'int']],
401+
read_kwargs={'columns': ['string', 'int']})
407402

408403
def test_duplicate_columns(self, pa):
409404
# not currently able to handle duplicate columns
@@ -433,7 +428,7 @@ def test_categorical(self, pa_ge_070):
433428

434429
# de-serialized as object
435430
expected = df.assign(a=df.a.astype(object))
436-
self.check_round_trip(df, pa, expected)
431+
check_round_trip(df, pa, expected=expected)
437432

438433
def test_categorical_unsupported(self, pa_lt_070):
439434
pa = pa_lt_070
@@ -444,20 +439,19 @@ def test_categorical_unsupported(self, pa_lt_070):
444439

445440
def test_s3_roundtrip(self, df_compat, s3_resource, pa):
446441
# GH #19134
447-
self.check_round_trip(df_compat, pa,
448-
path='s3://pandas-test/pyarrow.parquet')
442+
check_round_trip(df_compat, pa,
443+
path='s3://pandas-test/pyarrow.parquet')
449444

450445

451446
class TestParquetFastParquet(Base):
452447

453448
def test_basic(self, fp, df_full):
454-
455449
df = df_full
456450

457451
# additional supported types for fastparquet
458452
df['timedelta'] = pd.timedelta_range('1 day', periods=3)
459453

460-
self.check_round_trip(df, fp)
454+
check_round_trip(df, fp)
461455

462456
@pytest.mark.skip(reason="not supported")
463457
def test_duplicate_columns(self, fp):
@@ -470,7 +464,7 @@ def test_duplicate_columns(self, fp):
470464
def test_bool_with_none(self, fp):
471465
df = pd.DataFrame({'a': [True, None, False]})
472466
expected = pd.DataFrame({'a': [1.0, np.nan, 0.0]}, dtype='float16')
473-
self.check_round_trip(df, fp, expected=expected)
467+
check_round_trip(df, fp, expected=expected)
474468

475469
def test_unsupported(self, fp):
476470

@@ -486,7 +480,7 @@ def test_categorical(self, fp):
486480
if LooseVersion(fastparquet.__version__) < LooseVersion("0.1.3"):
487481
pytest.skip("CategoricalDtype not supported for older fp")
488482
df = pd.DataFrame({'a': pd.Categorical(list('abc'))})
489-
self.check_round_trip(df, fp)
483+
check_round_trip(df, fp)
490484

491485
def test_datetime_tz(self, fp):
492486
# doesn't preserve tz
@@ -495,7 +489,7 @@ def test_datetime_tz(self, fp):
495489

496490
# warns on the coercion
497491
with catch_warnings(record=True):
498-
self.check_round_trip(df, fp, df.astype('datetime64[ns]'))
492+
check_round_trip(df, fp, expected=df.astype('datetime64[ns]'))
499493

500494
def test_filter_row_groups(self, fp):
501495
d = {'a': list(range(0, 3))}
@@ -508,5 +502,5 @@ def test_filter_row_groups(self, fp):
508502

509503
def test_s3_roundtrip(self, df_compat, s3_resource, fp):
510504
# GH #19134
511-
self.check_round_trip(df_compat, fp,
512-
path='s3://pandas-test/fastparquet.parquet')
505+
check_round_trip(df_compat, fp,
506+
path='s3://pandas-test/fastparquet.parquet')

0 commit comments

Comments
 (0)