Commit 6e0927e

maximveksler authored and jreback committed

BUG: read_parquet, to_parquet for s3 destinations (#19135)

1 parent 51d71cd · commit 6e0927e

File tree

6 files changed: +102 -53 lines

doc/source/whatsnew/v0.23.0.txt (+1)
@@ -466,6 +466,7 @@ I/O
 - Bug in :func:`read_sas` where a file with 0 variables gave an ``AttributeError`` incorrectly. Now it gives an ``EmptyDataError`` (:issue:`18184`)
 - Bug in :func:`DataFrame.to_latex()` where pairs of braces meant to serve as invisible placeholders were escaped (:issue:`18667`)
 - Bug in :func:`read_json` where large numeric values were causing an ``OverflowError`` (:issue:`18842`)
+- Bug in :func:`DataFrame.to_parquet` where an exception was raised if the write destination is S3 (:issue:`19134`)
 -

 Plotting
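For context, this is the round trip the fix enables; a minimal sketch, assuming s3fs plus a parquet engine (pyarrow or fastparquet) are installed, with a hypothetical bucket name:

    import pandas as pd

    df = pd.DataFrame({'a': [1, 2, 3]})

    # Before this fix the S3 object was always opened in read mode,
    # so the write raised; now to_parquet opens it with mode='wb'.
    df.to_parquet('s3://my-bucket/df.parquet')

    result = pd.read_parquet('s3://my-bucket/df.parquet')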

pandas/io/common.py (+14 -12)
@@ -91,14 +91,6 @@ def _is_url(url):
         return False


-def _is_s3_url(url):
-    """Check for an s3, s3n, or s3a url"""
-    try:
-        return parse_url(url).scheme in ['s3', 's3n', 's3a']
-    except:
-        return False
-
-
 def _expand_user(filepath_or_buffer):
     """Return the argument with an initial component of ~ or ~user
     replaced by that user's home directory.

@@ -168,8 +160,16 @@ def _stringify_path(filepath_or_buffer):
     return filepath_or_buffer


+def is_s3_url(url):
+    """Check for an s3, s3n, or s3a url"""
+    try:
+        return parse_url(url).scheme in ['s3', 's3n', 's3a']
+    except:  # noqa
+        return False
+
+
 def get_filepath_or_buffer(filepath_or_buffer, encoding=None,
-                           compression=None):
+                           compression=None, mode=None):
     """
     If the filepath_or_buffer is a url, translate and return the buffer.
     Otherwise passthrough.

@@ -179,10 +179,11 @@ def get_filepath_or_buffer(filepath_or_buffer, encoding=None,
     filepath_or_buffer : a url, filepath (str, py.path.local or pathlib.Path),
                          or buffer
     encoding : the encoding to use to decode py3 bytes, default is 'utf-8'
+    mode : str, optional

     Returns
     -------
-    a filepath_or_buffer, the encoding, the compression
+    a filepath_or_buffer or S3File instance, the encoding, the compression
     """
     filepath_or_buffer = _stringify_path(filepath_or_buffer)

@@ -195,11 +196,12 @@ def get_filepath_or_buffer(filepath_or_buffer, encoding=None,
         reader = BytesIO(req.read())
         return reader, encoding, compression

-    if _is_s3_url(filepath_or_buffer):
+    if is_s3_url(filepath_or_buffer):
         from pandas.io import s3
         return s3.get_filepath_or_buffer(filepath_or_buffer,
                                          encoding=encoding,
-                                         compression=compression)
+                                         compression=compression,
+                                         mode=mode)

     if isinstance(filepath_or_buffer, (compat.string_types,
                                        compat.binary_type,
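The rename from _is_s3_url to is_s3_url makes the helper part of the module's importable surface (pandas.io.parquet needs it below), and the new mode parameter is simply threaded through to the S3 layer. A quick sketch of the resulting behavior; the bucket is hypothetical, and the write call needs s3fs plus AWS credentials:

    from pandas.io.common import get_filepath_or_buffer, is_s3_url

    is_s3_url('s3://bucket/key.parquet')    # True
    is_s3_url('s3a://bucket/key.parquet')   # True; s3n/s3a schemes count too
    is_s3_url('https://example.com/x.csv')  # False

    # mode is forwarded to s3fs; 'wb' yields a writable S3File
    f, _, _ = get_filepath_or_buffer('s3://my-bucket/out.parquet', mode='wb')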

pandas/io/parquet.py (+23 -5)
@@ -5,7 +5,7 @@
 from pandas import DataFrame, RangeIndex, Int64Index, get_option
 from pandas.compat import string_types
 from pandas.core.common import AbstractMethodError
-from pandas.io.common import get_filepath_or_buffer
+from pandas.io.common import get_filepath_or_buffer, is_s3_url


 def get_engine(engine):

@@ -107,7 +107,7 @@ def write(self, df, path, compression='snappy',
         self.validate_dataframe(df)
         if self._pyarrow_lt_070:
             self._validate_write_lt_070(df)
-        path, _, _ = get_filepath_or_buffer(path)
+        path, _, _ = get_filepath_or_buffer(path, mode='wb')

         if self._pyarrow_lt_060:
             table = self.api.Table.from_pandas(df, timestamps_to_ms=True)

@@ -194,14 +194,32 @@ def write(self, df, path, compression='snappy', **kwargs):
         # thriftpy/protocol/compact.py:339:
         # DeprecationWarning: tostring() is deprecated.
         # Use tobytes() instead.
-        path, _, _ = get_filepath_or_buffer(path)
+
+        if is_s3_url(path):
+            # path is s3:// so we need to open the s3file in 'wb' mode.
+            # TODO: Support 'ab'
+
+            path, _, _ = get_filepath_or_buffer(path, mode='wb')
+            # And pass the opened s3file to the fastparquet internal impl.
+            kwargs['open_with'] = lambda path, _: path
+        else:
+            path, _, _ = get_filepath_or_buffer(path)
+
         with catch_warnings(record=True):
             self.api.write(path, df,
                            compression=compression, **kwargs)

     def read(self, path, columns=None, **kwargs):
-        path, _, _ = get_filepath_or_buffer(path)
-        parquet_file = self.api.ParquetFile(path)
+        if is_s3_url(path):
+            # When path is s3:// an S3File is returned.
+            # We need to retain the original path (str) while also
+            # passing the S3File().open function to the fastparquet impl.
+            s3, _, _ = get_filepath_or_buffer(path)
+            parquet_file = self.api.ParquetFile(path, open_with=s3.s3.open)
+        else:
+            path, _, _ = get_filepath_or_buffer(path)
+            parquet_file = self.api.ParquetFile(path)
+
         return parquet_file.to_pandas(columns=columns, **kwargs)
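Two details in the fastparquet path are easy to miss: on write, path is already an open, writable S3File, so open_with is a no-op lambda that hands it straight through; on read, the original string path is kept so fastparquet can locate the file, while s3.s3 (the S3FileSystem attribute of the returned S3File) supplies the actual open. A rough standalone sketch of the same pattern, assuming fastparquet and s3fs are installed and the bucket is hypothetical:

    import fastparquet
    import pandas as pd
    import s3fs

    df = pd.DataFrame({'a': [1, 2, 3]})
    fs = s3fs.S3FileSystem()  # needs AWS credentials

    # Write: open the object in 'wb' mode ourselves, then make open_with
    # a no-op that returns the already-open file object.
    with fs.open('my-bucket/data.parquet', 'wb') as f:
        fastparquet.write(f, df, open_with=lambda path, _: path)

    # Read: keep the str path, route all opens through the filesystem.
    pf = fastparquet.ParquetFile('my-bucket/data.parquet', open_with=fs.open)
    result = pf.to_pandas()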
pandas/io/s3.py (+7 -3)
@@ -19,10 +19,14 @@ def _strip_schema(url):


 def get_filepath_or_buffer(filepath_or_buffer, encoding=None,
-                           compression=None):
+                           compression=None, mode=None):
+
+    if mode is None:
+        mode = 'rb'
+
     fs = s3fs.S3FileSystem(anon=False)
     try:
-        filepath_or_buffer = fs.open(_strip_schema(filepath_or_buffer))
+        filepath_or_buffer = fs.open(_strip_schema(filepath_or_buffer), mode)
     except (OSError, NoCredentialsError):
         # boto3 has troubles when trying to access a public file
         # when credentialed...

@@ -31,5 +35,5 @@ def get_filepath_or_buffer(filepath_or_buffer, encoding=None,
         # A NoCredentialsError is raised if you don't have creds
         # for that bucket.
         fs = s3fs.S3FileSystem(anon=True)
-        filepath_or_buffer = fs.open(_strip_schema(filepath_or_buffer))
+        filepath_or_buffer = fs.open(_strip_schema(filepath_or_buffer), mode)
     return filepath_or_buffer, None, compression
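Defaulting mode to 'rb' when the caller passes nothing keeps every existing read path unchanged; only the parquet writers opt in to 'wb'. The credentialed-then-anonymous retry is untouched, it just forwards the mode. Sketched in isolation (hypothetical key; requires s3fs and botocore):

    import s3fs
    from botocore.exceptions import NoCredentialsError

    def open_s3(path, mode='rb'):
        # try credentialed access first; fall back to anonymous for
        # public buckets or when no credentials are configured
        try:
            return s3fs.S3FileSystem(anon=False).open(path, mode)
        except (OSError, NoCredentialsError):
            return s3fs.S3FileSystem(anon=True).open(path, mode)

    f = open_s3('my-public-bucket/data.parquet')  # mode defaults to 'rb'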

pandas/tests/io/test_parquet.py (+54 -30)
@@ -204,6 +204,22 @@ def test_cross_engine_fp_pa(df_cross_compat, pa, fp):
     tm.assert_frame_equal(result, df[['a', 'd']])


+def check_round_trip_equals(df, path, engine,
+                            write_kwargs, read_kwargs,
+                            expected, check_names):
+
+    df.to_parquet(path, engine, **write_kwargs)
+    actual = read_parquet(path, engine, **read_kwargs)
+    tm.assert_frame_equal(expected, actual,
+                          check_names=check_names)
+
+    # repeat
+    df.to_parquet(path, engine, **write_kwargs)
+    actual = read_parquet(path, engine, **read_kwargs)
+    tm.assert_frame_equal(expected, actual,
+                          check_names=check_names)
+
+
 class Base(object):

     def check_error_on_write(self, df, engine, exc):

@@ -212,28 +228,32 @@ def check_error_on_write(self, df, engine, exc):
         with tm.ensure_clean() as path:
             to_parquet(df, path, engine, compression=None)

-    def check_round_trip(self, df, engine, expected=None,
+    def check_round_trip(self, df, engine, expected=None, path=None,
                          write_kwargs=None, read_kwargs=None,
                          check_names=True):
+
         if write_kwargs is None:
-            write_kwargs = {}
+            write_kwargs = {'compression': None}
+
         if read_kwargs is None:
             read_kwargs = {}
-        with tm.ensure_clean() as path:
-            df.to_parquet(path, engine, **write_kwargs)
-            result = read_parquet(path, engine, **read_kwargs)

-            if expected is None:
-                expected = df
-            tm.assert_frame_equal(result, expected, check_names=check_names)
-
-            # repeat
-            to_parquet(df, path, engine, **write_kwargs)
-            result = pd.read_parquet(path, engine, **read_kwargs)
+        if expected is None:
+            expected = df

-            if expected is None:
-                expected = df
-            tm.assert_frame_equal(result, expected, check_names=check_names)
+        if path is None:
+            with tm.ensure_clean() as path:
+                check_round_trip_equals(df, path, engine,
+                                        write_kwargs=write_kwargs,
+                                        read_kwargs=read_kwargs,
+                                        expected=expected,
+                                        check_names=check_names)
+        else:
+            check_round_trip_equals(df, path, engine,
+                                    write_kwargs=write_kwargs,
+                                    read_kwargs=read_kwargs,
+                                    expected=expected,
+                                    check_names=check_names)


 class TestBasic(Base):

@@ -251,7 +271,7 @@ def test_columns_dtypes(self, engine):

         # unicode
         df.columns = [u'foo', u'bar']
-        self.check_round_trip(df, engine, write_kwargs={'compression': None})
+        self.check_round_trip(df, engine)

     def test_columns_dtypes_invalid(self, engine):

@@ -292,7 +312,6 @@ def test_read_columns(self, engine):

         expected = pd.DataFrame({'string': list('abc')})
         self.check_round_trip(df, engine, expected=expected,
-                              write_kwargs={'compression': None},
                               read_kwargs={'columns': ['string']})

     def test_write_index(self, engine):

@@ -304,7 +323,7 @@ def test_write_index(self, engine):
             pytest.skip("pyarrow is < 0.7.0")

         df = pd.DataFrame({'A': [1, 2, 3]})
-        self.check_round_trip(df, engine, write_kwargs={'compression': None})
+        self.check_round_trip(df, engine)

         indexes = [
             [2, 3, 4],

@@ -315,15 +334,12 @@ def test_write_index(self, engine):
         # non-default index
         for index in indexes:
             df.index = index
-            self.check_round_trip(
-                df, engine,
-                write_kwargs={'compression': None},
-                check_names=check_names)
+            self.check_round_trip(df, engine, check_names=check_names)

         # index with meta-data
         df.index = [0, 1, 2]
         df.index.name = 'foo'
-        self.check_round_trip(df, engine, write_kwargs={'compression': None})
+        self.check_round_trip(df, engine)

     def test_write_multiindex(self, pa_ge_070):
         # Not supported in fastparquet as of 0.1.3 or older pyarrow version

@@ -332,7 +348,7 @@ def test_write_multiindex(self, pa_ge_070):
         df = pd.DataFrame({'A': [1, 2, 3]})
         index = pd.MultiIndex.from_tuples([('a', 1), ('a', 2), ('b', 1)])
         df.index = index
-        self.check_round_trip(df, engine, write_kwargs={'compression': None})
+        self.check_round_trip(df, engine)

     def test_write_column_multiindex(self, engine):
         # column multi-index

@@ -426,6 +442,11 @@ def test_categorical_unsupported(self, pa_lt_070):
         df = pd.DataFrame({'a': pd.Categorical(list('abc'))})
         self.check_error_on_write(df, pa, NotImplementedError)

+    def test_s3_roundtrip(self, df_compat, s3_resource, pa):
+        # GH #19134
+        self.check_round_trip(df_compat, pa,
+                              path='s3://pandas-test/pyarrow.parquet')
+

 class TestParquetFastParquet(Base):

@@ -436,7 +457,7 @@ def test_basic(self, fp, df_full):
         # additional supported types for fastparquet
         df['timedelta'] = pd.timedelta_range('1 day', periods=3)

-        self.check_round_trip(df, fp, write_kwargs={'compression': None})
+        self.check_round_trip(df, fp)

     @pytest.mark.skip(reason="not supported")
     def test_duplicate_columns(self, fp):

@@ -449,8 +470,7 @@ def test_duplicate_columns(self, fp):
     def test_bool_with_none(self, fp):
         df = pd.DataFrame({'a': [True, None, False]})
         expected = pd.DataFrame({'a': [1.0, np.nan, 0.0]}, dtype='float16')
-        self.check_round_trip(df, fp, expected=expected,
-                              write_kwargs={'compression': None})
+        self.check_round_trip(df, fp, expected=expected)

     def test_unsupported(self, fp):

@@ -466,7 +486,7 @@ def test_categorical(self, fp):
         if LooseVersion(fastparquet.__version__) < LooseVersion("0.1.3"):
             pytest.skip("CategoricalDtype not supported for older fp")
         df = pd.DataFrame({'a': pd.Categorical(list('abc'))})
-        self.check_round_trip(df, fp, write_kwargs={'compression': None})
+        self.check_round_trip(df, fp)

     def test_datetime_tz(self, fp):
         # doesn't preserve tz

@@ -475,8 +495,7 @@ def test_datetime_tz(self, fp):

         # warns on the coercion
         with catch_warnings(record=True):
-            self.check_round_trip(df, fp, df.astype('datetime64[ns]'),
-                                  write_kwargs={'compression': None})
+            self.check_round_trip(df, fp, df.astype('datetime64[ns]'))

     def test_filter_row_groups(self, fp):
         d = {'a': list(range(0, 3))}

@@ -486,3 +505,8 @@ def test_filter_row_groups(self, fp):
                                 row_group_offsets=1)
         result = read_parquet(path, fp, filters=[('a', '==', 0)])
         assert len(result) == 1
+
+    def test_s3_roundtrip(self, df_compat, s3_resource, fp):
+        # GH #19134
+        self.check_round_trip(df_compat, fp,
+                              path='s3://pandas-test/fastparquet.parquet')
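The new test_s3_roundtrip tests lean on the path= parameter added to check_round_trip: when a path is given, the helper skips tm.ensure_clean() and writes straight to the bucket. In pandas' suite the s3_resource fixture provides a moto-mocked 'pandas-test' bucket, so no real AWS access is needed. A rough standalone equivalent, assuming moto, boto3, and s3fs are installed:

    import boto3
    import moto
    import pandas as pd

    # mock_s3 intercepts all S3 calls made through botocore,
    # including those issued by s3fs under the hood
    with moto.mock_s3():
        boto3.resource('s3').create_bucket(Bucket='pandas-test')
        df = pd.DataFrame({'a': [1, 2, 3]})
        df.to_parquet('s3://pandas-test/test.parquet', engine='fastparquet')
        result = pd.read_parquet('s3://pandas-test/test.parquet')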

pandas/tests/io/test_s3.py (+3 -3)
@@ -1,8 +1,8 @@
-from pandas.io.common import _is_s3_url
+from pandas.io.common import is_s3_url


 class TestS3URL(object):

     def test_is_s3_url(self):
-        assert _is_s3_url("s3://pandas/somethingelse.com")
-        assert not _is_s3_url("s4://pandas/somethingelse.com")
+        assert is_s3_url("s3://pandas/somethingelse.com")
+        assert not is_s3_url("s4://pandas/somethingelse.com")
