Skip to content

Commit 52ff113

Browse files
committed
add private impl classes
pass thru kwargs to reader/writer
1 parent 05f5cfe commit 52ff113

File tree

2 files changed

+73
-56
lines changed

2 files changed

+73
-56
lines changed

pandas/core/frame.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1520,7 +1520,8 @@ def to_feather(self, fname):
15201520
from pandas.io.feather_format import to_feather
15211521
to_feather(self, fname)
15221522

1523-
def to_parquet(self, fname, engine, compression=None):
1523+
def to_parquet(self, fname, engine, compression=None,
1524+
**kwargs):
15241525
"""
15251526
write out the binary parquet for DataFrames
15261527
@@ -1534,10 +1535,12 @@ def to_parquet(self, fname, engine, compression=None):
15341535
supported are {'pyarrow', 'fastparquet'}
15351536
compression : str, optional
15361537
compression method, includes {'gzip', 'snappy', 'brotli'}
1538+
kwargs passed to the engine
15371539
15381540
"""
15391541
from pandas.io.parquet import to_parquet
1540-
to_parquet(self, fname, engine, compression=compression)
1542+
to_parquet(self, fname, engine,
1543+
compression=compression, **kwargs)
15411544

15421545
@Substitution(header='Write out column names. If a list of string is given, \
15431546
it is assumed to be aliases for the column names')

pandas/io/parquet.py

+68-54
Original file line numberDiff line numberDiff line change
@@ -5,44 +5,78 @@
55
from pandas.compat import range
66

77

8-
def _try_import_pyarrow():
9-
# since pandas is a dependency of pyarrow
10-
# we need to import on first use
8+
def get_engine(engine):
    """ return our implementation """

    # Dispatch on the engine name; anything unrecognized is rejected
    # with the same ValueError the callers expect.
    if engine == 'pyarrow':
        return PyArrowImpl()
    if engine == 'fastparquet':
        return FastParquetImpl()
    raise ValueError("engine must be one of 'pyarrow', 'fastparquet'")
18+
19+
20+
class PyArrowImpl(object):
    """ parquet reader/writer implementation backed by pyarrow """

    def __init__(self):
        # since pandas is a dependency of pyarrow
        # we need to import on first use
        try:
            import pyarrow  # noqa
        except ImportError:
            raise ImportError("pyarrow is required for parquet support\n\n"
                              "you can install via conda\n"
                              "conda install pyarrow -c conda-forge\n"
                              "\nor via pip\n"
                              "pip install pyarrow\n")

    def write(self, df, path, compression=None, **kwargs):
        """ write df to path; extra kwargs are passed to
        pyarrow.parquet.write_table """
        import pyarrow
        from pyarrow import parquet as pq

        table = pyarrow.Table.from_pandas(df)
        pq.write_table(table, path,
                       compression=compression, **kwargs)

    def read(self, path, **kwargs):
        """ read a parquet file from path into a DataFrame; extra kwargs
        are passed to pyarrow.parquet.read_table so reader options
        pass through as well """
        import pyarrow
        return pyarrow.parquet.read_table(path, **kwargs).to_pandas()
2346

24-
def _try_import_fastparquet():
25-
# since pandas is a dependency of fastparquet
26-
# we need to import on first use
2747

28-
try:
29-
import fastparquet
30-
except ImportError:
31-
raise ImportError("fastparquet is required for parquet support\n\n"
32-
"you can install via conda\n"
33-
"conda install fastparquet -c conda-forge\n"
34-
"\nor via pip\n"
35-
"pip install fastparquet")
48+
class FastParquetImpl(object):
    """ parquet reader/writer implementation backed by fastparquet """

    def __init__(self):
        # since pandas is a dependency of fastparquet
        # we need to import on first use
        try:
            import fastparquet  # noqa
        except ImportError:
            raise ImportError("fastparquet is required for parquet support\n\n"
                              "you can install via conda\n"
                              "conda install fastparquet -c conda-forge\n"
                              "\nor via pip\n"
                              "pip install fastparquet")

    def write(self, df, path, compression=None, **kwargs):
        """ write df to path; extra kwargs are passed to
        fastparquet.write """
        import fastparquet

        # thriftpy/protocol/compact.py:339:
        # DeprecationWarning: tostring() is deprecated.
        # Use tobytes() instead.
        with catch_warnings(record=True):
            fastparquet.write(path, df,
                              compression=compression, **kwargs)

    def read(self, path, **kwargs):
        """ read a parquet file from path into a DataFrame; extra kwargs
        are passed to ParquetFile.to_pandas so reader options
        pass through as well """
        import fastparquet
        pf = fastparquet.ParquetFile(path)
        return pf.to_pandas(**kwargs)
77+
78+
79+
def to_parquet(df, path, engine, compression=None, **kwargs):
4680
"""
4781
Write a DataFrame to the pyarrow
4882
@@ -55,9 +89,10 @@ def to_parquet(df, path, engine, compression=None):
5589
supported are {'pyarrow', 'fastparquet'}
5690
compression : str, optional
5791
compression method, includes {'gzip', 'snappy', 'brotli'}
92+
kwargs are passed to the engine
5893
"""
5994

60-
_validate_engine(engine)
95+
impl = get_engine(engine)
6196

6297
if not isinstance(df, DataFrame):
6398
raise ValueError("to_parquet only support IO with DataFrames")
@@ -92,24 +127,10 @@ def to_parquet(df, path, engine, compression=None):
92127
if df.columns.inferred_type not in valid_types:
93128
raise ValueError("parquet must have string column names")
94129

95-
if engine == 'pyarrow':
96-
pyarrow = _try_import_pyarrow()
97-
from pyarrow import parquet as pq
98-
99-
table = pyarrow.Table.from_pandas(df)
100-
pq.write_table(table, path, compression=compression)
101-
102-
elif engine == 'fastparquet':
103-
fastparquet = _try_import_fastparquet()
104-
105-
# thriftpy/protocol/compact.py:339:
106-
# DeprecationWarning: tostring() is deprecated.
107-
# Use tobytes() instead.
108-
with catch_warnings(record=True):
109-
fastparquet.write(path, df, compression=compression)
130+
return impl.write(df, path, compression=compression)
110131

111132

112-
def read_parquet(path, engine):
133+
def read_parquet(path, engine, **kwargs):
    """
    Load a parquet object from the file path

    Parameters
    ----------
    path : string
        File path
    engine : parquet engine
        supported are {'pyarrow', 'fastparquet'}
    kwargs are passed to the engine

    Returns
    -------
    type of object stored in file

    """

    impl = get_engine(engine)
    # forward kwargs to the engine reader; previously they were
    # accepted here but silently dropped
    return impl.read(path, **kwargs)

0 commit comments

Comments
 (0)