Skip to content

Commit 0f7b526

Browse files
fix test for older versions of pyarrow
1 parent 1acb00c commit 0f7b526

File tree

2 files changed

+44
-33
lines changed

2 files changed

+44
-33
lines changed

pandas/io/parquet.py

+21-12
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,12 @@ def __init__(self):
8686
"\nor via pip\n"
8787
"pip install -U pyarrow\n"
8888
)
89+
90+
self._pyarrow_lt_060 = (
91+
LooseVersion(pyarrow.__version__) < LooseVersion('0.6.0'))
8992
self._pyarrow_lt_070 = (
90-
LooseVersion(pyarrow.__version__) < LooseVersion('0.7.0')
91-
)
93+
LooseVersion(pyarrow.__version__) < LooseVersion('0.7.0'))
94+
9295
self.api = pyarrow
9396

9497
def write(self, df, path, compression='snappy',
@@ -99,17 +102,23 @@ def write(self, df, path, compression='snappy',
99102
df, path, compression, coerce_timestamps, **kwargs
100103
)
101104
path, _, _ = get_filepath_or_buffer(path)
102-
table = self.api.Table.from_pandas(df)
103-
self.api.parquet.write_table(
104-
table, path, compression=compression,
105-
coerce_timestamps=coerce_timestamps, **kwargs)
105+
106+
if self._pyarrow_lt_060:
107+
table = self.api.Table.from_pandas(df, timestamps_to_ms=True)
108+
self.api.parquet.write_table(
109+
table, path, compression=compression, **kwargs)
110+
111+
else:
112+
table = self.api.Table.from_pandas(df)
113+
self.api.parquet.write_table(
114+
table, path, compression=compression,
115+
coerce_timestamps=coerce_timestamps, **kwargs)
106116

107117
def read(self, path, columns=None, **kwargs):
108118
path, _, _ = get_filepath_or_buffer(path)
109119
parquet_file = self.api.parquet.ParquetFile(path)
110120
if self._pyarrow_lt_070:
111-
parquet_file.path = path
112-
return self._read_lt_070(parquet_file, columns, **kwargs)
121+
return self._read_lt_070(path, parquet_file, columns, **kwargs)
113122
kwargs['use_pandas_metadata'] = True
114123
return parquet_file.read(columns=columns, **kwargs).to_pandas()
115124

@@ -143,17 +152,17 @@ def _validate_write_lt_070(self, df, path, compression='snappy',
143152
"on a default index"
144153
)
145154

146-
def _read_lt_070(self, parquet_file, columns, **kwargs):
155+
def _read_lt_070(self, path, parquet_file, columns, **kwargs):
147156
# Compatibility shim for pyarrow < 0.7.0
148157
# TODO: Remove in pandas 0.22.0
149158
from itertools import chain
150159
import json
151160
if columns is not None:
152-
metadata = json.loads(parquet_file.metadata.metadata[b'pandas'])
161+
metadata = json.loads(
162+
parquet_file.metadata.metadata[b'pandas'].decode('utf-8'))
153163
columns = set(chain(columns, metadata['index_columns']))
154164
kwargs['columns'] = columns
155-
kwargs['path'] = parquet_file.path
156-
return self.api.parquet.read_table(**kwargs).to_pandas()
165+
return self.api.parquet.read_table(path, **kwargs).to_pandas()
157166

158167

159168
class FastParquetImpl(BaseImpl):

pandas/tests/io/test_parquet.py

+23-21
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,14 @@ def test_read_columns(self, engine):
296296
write_kwargs={'compression': None},
297297
read_kwargs={'columns': ['string']})
298298

299-
def test_write_with_index(self, engine):
299+
def test_write_index(self, engine):
300300
check_names = engine != 'fastparquet'
301301

302+
if engine == 'pyarrow':
303+
import pyarrow
304+
if LooseVersion(pyarrow.__version__) < LooseVersion('0.7.0'):
305+
pytest.skip("pyarrow is < 0.7.0")
306+
302307
df = pd.DataFrame({'A': [1, 2, 3]})
303308
self.check_round_trip(df, engine, write_kwargs={'compression': None})
304309

@@ -314,34 +319,31 @@ def test_write_with_index(self, engine):
314319
self.check_round_trip(
315320
df, engine,
316321
write_kwargs={'compression': None},
317-
check_names=check_names,
318-
)
319-
if engine != 'fastparquet':
320-
# Not suppoprted in fastparquet as of 0.1.3
321-
index = pd.MultiIndex.from_tuples([('a', 1), ('a', 2), ('b', 1)])
322-
df.index = index
323-
self.check_round_trip(
324-
df, engine,
325-
write_kwargs={'compression': None},
326-
)
322+
check_names=check_names)
323+
327324
# index with meta-data
328325
df.index = [0, 1, 2]
329326
df.index.name = 'foo'
330-
self.check_round_trip(
331-
df, engine,
332-
write_kwargs={'compression': None}
333-
)
327+
self.check_round_trip( df, engine, write_kwargs={'compression': None})
328+
329+
def test_write_multiindex(self, pa_ge_070):
330+
# Not suppoprted in fastparquet as of 0.1.3 or older pyarrow version
331+
engine = pa_ge_070
332+
333+
df = pd.DataFrame({'A': [1, 2, 3]})
334+
index = pd.MultiIndex.from_tuples([('a', 1), ('a', 2), ('b', 1)])
335+
df.index = index
336+
self.check_round_trip(df, engine, write_kwargs={'compression': None})
334337

338+
def test_write_column_multiindex(self, engine):
335339
# column multi-index
336-
df.index = [0, 1, 2]
337-
df.columns = pd.MultiIndex.from_tuples([('a', 1), ('a', 2), ('b', 1)]),
340+
mi_columns = pd.MultiIndex.from_tuples([('a', 1), ('a', 2), ('b', 1)])
341+
df = pd.DataFrame(np.random.randn(4, 3), columns=mi_columns)
338342
self.check_error_on_write(df, engine, ValueError)
339343

340-
def test_multiindex_with_columns(self, engine):
341-
if engine == 'fastparquet':
342-
msg = "fastparquet doesn't support mulit-indexes as of 0.1.3"
343-
pytest.xfail(msg)
344+
def test_multiindex_with_columns(self, pa_ge_070):
344345

346+
engine = pa_ge_070
345347
dates = pd.date_range('01-Jan-2018', '01-Dec-2018', freq='MS')
346348
df = pd.DataFrame(randn(2 * len(dates), 3), columns=list('ABC'))
347349
index1 = pd.MultiIndex.from_product(

0 commit comments

Comments
 (0)