
Commit ef4e30b

Cornelius Riemenschneider authored and jorisvandenbossche committed
ENH: Pass kwargs from read_parquet() to the underlying engines. (pandas-dev#18216)
This allows, for example, specifying filters for predicate pushdown to fastparquet.
1 parent 77f10f0 commit ef4e30b
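
As a concrete illustration of what the pass-through enables, here is a minimal sketch (the file name is hypothetical and fastparquet is assumed to be installed; it mirrors the new test_filter_row_groups test added below). Writing with one-row row groups lets fastparquet's filters exclude whole row groups by their min/max statistics:

import pandas as pd

# Write one row per row group so each row group's statistics
# pin down a single value of 'a'.
df = pd.DataFrame({'a': [0, 1, 2]})
df.to_parquet('example.parquet', engine='fastparquet',
              compression=None, row_group_offsets=1)

# With this commit, extra kwargs such as fastparquet's filters
# are forwarded from read_parquet() to the engine, enabling
# row-group-level predicate pushdown.
result = pd.read_parquet('example.parquet', engine='fastparquet',
                         filters=[('a', '==', 0)])
assert len(result) == 1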

File tree: 3 files changed, +43 −25 lines

  doc/source/whatsnew/v0.21.1.txt (+1)
  pandas/io/parquet.py (+7 −6)
  pandas/tests/io/test_parquet.py (+35 −19)

doc/source/whatsnew/v0.21.1.txt

+1
@@ -86,6 +86,7 @@ I/O
 - Bug in :func:`read_csv` for handling null values in index columns when specifying ``na_filter=False`` (:issue:`5239`)
 - Bug in :meth:`DataFrame.to_csv` when the table had ``MultiIndex`` columns, and a list of strings was passed in for ``header`` (:issue:`5539`)
 - :func:`read_parquet` now allows to specify the columns to read from a parquet file (:issue:`18154`)
+- :func:`read_parquet` now allows to specify kwargs which are passed to the respective engine (:issue:`18216`)

 Plotting
 ^^^^^^^^

pandas/io/parquet.py

+7 −6
@@ -76,9 +76,10 @@ def write(self, df, path, compression='snappy',
             table, path, compression=compression,
             coerce_timestamps=coerce_timestamps, **kwargs)

-    def read(self, path, columns=None):
+    def read(self, path, columns=None, **kwargs):
         path, _, _ = get_filepath_or_buffer(path)
-        return self.api.parquet.read_table(path, columns=columns).to_pandas()
+        return self.api.parquet.read_table(path, columns=columns,
+                                           **kwargs).to_pandas()


 class FastParquetImpl(object):
@@ -115,9 +116,9 @@ def write(self, df, path, compression='snappy', **kwargs):
         self.api.write(path, df,
                        compression=compression, **kwargs)

-    def read(self, path, columns=None):
+    def read(self, path, columns=None, **kwargs):
         path, _, _ = get_filepath_or_buffer(path)
-        return self.api.ParquetFile(path).to_pandas(columns=columns)
+        return self.api.ParquetFile(path).to_pandas(columns=columns, **kwargs)


 def to_parquet(df, path, engine='auto', compression='snappy', **kwargs):
@@ -175,7 +176,7 @@ def to_parquet(df, path, engine='auto', compression='snappy', **kwargs):
     if df.columns.inferred_type not in valid_types:
         raise ValueError("parquet must have string column names")

-    return impl.write(df, path, compression=compression)
+    return impl.write(df, path, compression=compression, **kwargs)


 def read_parquet(path, engine='auto', columns=None, **kwargs):
@@ -205,4 +206,4 @@ def read_parquet(path, engine='auto', columns=None, **kwargs):
     """

     impl = get_engine(engine)
-    return impl.read(path, columns=columns)
+    return impl.read(path, columns=columns, **kwargs)
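
Stripped of the engine-specific details, the change above reduces to a simple forwarding pattern; the following self-contained sketch uses a hypothetical ToyEngine in place of PyArrowImpl/FastParquetImpl:

# Condensed sketch of the kwargs pass-through in pandas/io/parquet.py.
class ToyEngine(object):
    def read(self, path, columns=None, **kwargs):
        # Engine-specific options (e.g. fastparquet's filters)
        # now arrive here untouched via **kwargs.
        return {'path': path, 'columns': columns, 'extra': kwargs}

def read_parquet(path, engine='auto', columns=None, **kwargs):
    impl = ToyEngine()  # stands in for get_engine(engine)
    # Anything read_parquet() does not consume itself is forwarded.
    return impl.read(path, columns=columns, **kwargs)

print(read_parquet('f.parquet', columns=['a'], filters=[('a', '==', 0)]))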

pandas/tests/io/test_parquet.py

+35 −19
@@ -105,7 +105,7 @@ def test_options_py(df_compat, pa):
         with pd.option_context('io.parquet.engine', 'pyarrow'):
             df.to_parquet(path)

-            result = read_parquet(path, compression=None)
+            result = read_parquet(path)
             tm.assert_frame_equal(result, df)


@@ -118,7 +118,7 @@ def test_options_fp(df_compat, fp):
         with pd.option_context('io.parquet.engine', 'fastparquet'):
            df.to_parquet(path, compression=None)

-            result = read_parquet(path, compression=None)
+            result = read_parquet(path)
            tm.assert_frame_equal(result, df)


@@ -130,7 +130,7 @@ def test_options_auto(df_compat, fp, pa):
         with pd.option_context('io.parquet.engine', 'auto'):
             df.to_parquet(path)

-            result = read_parquet(path, compression=None)
+            result = read_parquet(path)
             tm.assert_frame_equal(result, df)


@@ -162,7 +162,7 @@ def test_cross_engine_pa_fp(df_cross_compat, pa, fp):
     with tm.ensure_clean() as path:
         df.to_parquet(path, engine=pa, compression=None)

-        result = read_parquet(path, engine=fp, compression=None)
+        result = read_parquet(path, engine=fp)
         tm.assert_frame_equal(result, df)


@@ -174,7 +174,7 @@ def test_cross_engine_fp_pa(df_cross_compat, pa, fp):
     with tm.ensure_clean() as path:
         df.to_parquet(path, engine=fp, compression=None)

-        result = read_parquet(path, engine=pa, compression=None)
+        result = read_parquet(path, engine=pa)
         tm.assert_frame_equal(result, df)


@@ -188,19 +188,23 @@ def check_error_on_write(self, df, engine, exc):
         with tm.ensure_clean() as path:
             to_parquet(df, path, engine, compression=None)

-    def check_round_trip(self, df, engine, expected=None, **kwargs):
-
+    def check_round_trip(self, df, engine, expected=None,
+                         write_kwargs=None, read_kwargs=None):
+        if write_kwargs is None:
+            write_kwargs = {}
+        if read_kwargs is None:
+            read_kwargs = {}
         with tm.ensure_clean() as path:
-            df.to_parquet(path, engine, **kwargs)
-            result = read_parquet(path, engine, **kwargs)
+            df.to_parquet(path, engine, **write_kwargs)
+            result = read_parquet(path, engine, **read_kwargs)

             if expected is None:
                 expected = df
             tm.assert_frame_equal(result, expected)

             # repeat
-            to_parquet(df, path, engine, **kwargs)
-            result = pd.read_parquet(path, engine, **kwargs)
+            to_parquet(df, path, engine, **write_kwargs)
+            result = pd.read_parquet(path, engine, **read_kwargs)

             if expected is None:
                 expected = df
@@ -222,7 +226,7 @@ def test_columns_dtypes(self, engine):

         # unicode
         df.columns = [u'foo', u'bar']
-        self.check_round_trip(df, engine, compression=None)
+        self.check_round_trip(df, engine, write_kwargs={'compression': None})

     def test_columns_dtypes_invalid(self, engine):

@@ -246,7 +250,7 @@ def test_columns_dtypes_invalid(self, engine):
     def test_write_with_index(self, engine):

         df = pd.DataFrame({'A': [1, 2, 3]})
-        self.check_round_trip(df, engine, compression=None)
+        self.check_round_trip(df, engine, write_kwargs={'compression': None})

         # non-default index
         for index in [[2, 3, 4],
@@ -280,7 +284,8 @@ def test_compression(self, engine, compression):
             pytest.importorskip('brotli')

         df = pd.DataFrame({'A': [1, 2, 3]})
-        self.check_round_trip(df, engine, compression=compression)
+        self.check_round_trip(df, engine,
+                              write_kwargs={'compression': compression})

     def test_read_columns(self, engine):
         # GH18154
@@ -289,7 +294,8 @@ def test_read_columns(self, engine):

         expected = pd.DataFrame({'string': list('abc')})
         self.check_round_trip(df, engine, expected=expected,
-                              compression=None, columns=["string"])
+                              write_kwargs={'compression': None},
+                              read_kwargs={'columns': ['string']})


 class TestParquetPyArrow(Base):
@@ -377,7 +383,7 @@ def test_basic(self, fp):
             'timedelta': pd.timedelta_range('1 day', periods=3),
         })

-        self.check_round_trip(df, fp, compression=None)
+        self.check_round_trip(df, fp, write_kwargs={'compression': None})

     @pytest.mark.skip(reason="not supported")
     def test_duplicate_columns(self, fp):
@@ -390,7 +396,8 @@ def test_duplicate_columns(self, fp):
     def test_bool_with_none(self, fp):
         df = pd.DataFrame({'a': [True, None, False]})
         expected = pd.DataFrame({'a': [1.0, np.nan, 0.0]}, dtype='float16')
-        self.check_round_trip(df, fp, expected=expected, compression=None)
+        self.check_round_trip(df, fp, expected=expected,
+                              write_kwargs={'compression': None})

     def test_unsupported(self, fp):

@@ -406,7 +413,7 @@ def test_categorical(self, fp):
         if LooseVersion(fastparquet.__version__) < LooseVersion("0.1.3"):
             pytest.skip("CategoricalDtype not supported for older fp")
         df = pd.DataFrame({'a': pd.Categorical(list('abc'))})
-        self.check_round_trip(df, fp, compression=None)
+        self.check_round_trip(df, fp, write_kwargs={'compression': None})

     def test_datetime_tz(self, fp):
         # doesn't preserve tz
@@ -416,4 +423,13 @@ def test_datetime_tz(self, fp):
         # warns on the coercion
         with catch_warnings(record=True):
             self.check_round_trip(df, fp, df.astype('datetime64[ns]'),
-                                  compression=None)
+                                  write_kwargs={'compression': None})
+
+    def test_filter_row_groups(self, fp):
+        d = {'a': list(range(0, 3))}
+        df = pd.DataFrame(d)
+        with tm.ensure_clean() as path:
+            df.to_parquet(path, fp, compression=None,
+                          row_group_offsets=1)
+            result = read_parquet(path, fp, filters=[('a', '==', 0)])
+            assert len(result) == 1
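
A note on why check_round_trip grew separate write_kwargs/read_kwargs: before this commit read_parquet() accepted but ignored unknown keyword arguments, so the helper could forward one shared **kwargs (including the write-only compression=None) to both calls. Now that unknown kwargs reach the engine's read method, a stray compression= would raise a TypeError there. A standalone sketch of the split (round_trip is a hypothetical analogue of the test helper, assuming fastparquet is installed):

import pandas as pd
import pandas.util.testing as tm

def round_trip(df, engine, write_kwargs=None, read_kwargs=None):
    # Write-time options (compression, row_group_offsets, ...) and
    # read-time options (columns, filters, ...) are valid only for
    # their own direction, so they are kept apart.
    write_kwargs = write_kwargs or {}
    read_kwargs = read_kwargs or {}
    with tm.ensure_clean() as path:
        df.to_parquet(path, engine, **write_kwargs)
        return pd.read_parquet(path, engine, **read_kwargs)

df = pd.DataFrame({'string': list('abc'), 'int': [1, 2, 3]})
result = round_trip(df, 'fastparquet',
                    write_kwargs={'compression': None},
                    read_kwargs={'columns': ['string']})
assert list(result.columns) == ['string']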
