Skip to content

Commit 8ba9082

Browse files
committed
delegate dict handling to _get_compression_method, type annotations
1 parent 60ea58c commit 8ba9082

File tree

2 files changed

+69
-37
lines changed

2 files changed

+69
-37
lines changed

pandas/io/common.py

+54-20
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
uses_relative)
1616
from urllib.request import pathname2url, urlopen
1717
import zipfile
18+
from typing import Dict
1819

1920
import pandas.compat as compat
2021
from pandas.errors import ( # noqa
@@ -233,6 +234,39 @@ def file_path_to_url(path):
233234
}
234235

235236

237+
def _get_compression_method(compression: (str, Dict)):
238+
"""
239+
Simplifies a compression argument to a compression method string and
240+
a dict containing additional arguments.
241+
242+
Parameters
243+
----------
244+
compression : str or dict
245+
If string, specifies the compression method. If dict, value at key
246+
'method' specifies compression method.
247+
248+
Returns
249+
-------
250+
tuple of ({compression method}, str
251+
{compression arguments}, dict)
252+
253+
Raises
254+
------
255+
ValueError on dict missing 'method' key
256+
"""
257+
compression_args = {}
258+
# Handle dict
259+
if isinstance(compression, dict):
260+
compression_args = compression.copy()
261+
try:
262+
compression = compression['method']
263+
compression_args.pop('method')
264+
except KeyError:
265+
raise ValueError("If dict, compression "
266+
"must have key 'method'")
267+
return compression, compression_args
268+
269+
236270
def _infer_compression(filepath_or_buffer, compression):
237271
"""
238272
Get the compression method for filepath_or_buffer. If compression mode is
@@ -266,13 +300,8 @@ def _infer_compression(filepath_or_buffer, compression):
266300
ValueError on invalid compression specified
267301
"""
268302

269-
# Handle compression method as dict
270-
if isinstance(compression, dict):
271-
try:
272-
compression = compression['method']
273-
except KeyError:
274-
raise ValueError("Compression dict must have key "
275-
"'method'")
303+
# Handle compression as dict
304+
compression, _ = _get_compression_method(compression)
276305

277306
# No compression has been explicitly specified
278307
if compression is None:
@@ -355,31 +384,31 @@ def _get_handle(path_or_buf, mode, encoding=None, compression=None,
355384
path_or_buf = _stringify_path(path_or_buf)
356385
is_path = isinstance(path_or_buf, str)
357386

358-
compression_method = None
387+
compression, compression_args = _get_compression_method(compression)
359388
if is_path:
360-
compression_method = _infer_compression(path_or_buf, compression)
389+
compression = _infer_compression(path_or_buf, compression)
361390

362-
if compression_method:
391+
if compression:
363392

364393
# GZ Compression
365-
if compression_method == 'gzip':
394+
if compression == 'gzip':
366395
if is_path:
367396
f = gzip.open(path_or_buf, mode)
368397
else:
369398
f = gzip.GzipFile(fileobj=path_or_buf)
370399

371400
# BZ Compression
372-
elif compression_method == 'bz2':
401+
elif compression == 'bz2':
373402
if is_path:
374403
f = bz2.BZ2File(path_or_buf, mode)
375404
else:
376405
f = bz2.BZ2File(path_or_buf)
377406

378407
# ZIP Compression
379-
elif compression_method == 'zip':
408+
elif compression == 'zip':
380409
arcname = None
381-
if isinstance(compression, dict) and 'arcname' in compression:
382-
arcname = compression['arcname']
410+
if 'arcname' in compression_args:
411+
arcname = compression_args['arcname']
383412
zf = BytesZipFile(path_or_buf, mode, arcname=arcname)
384413
# Ensure the container is closed as well.
385414
handles.append(zf)
@@ -398,9 +427,14 @@ def _get_handle(path_or_buf, mode, encoding=None, compression=None,
398427
.format(zip_names))
399428

400429
# XZ Compression
401-
elif compression_method == 'xz':
430+
elif compression == 'xz':
402431
f = lzma.LZMAFile(path_or_buf, mode)
403432

433+
# Unrecognized Compression
434+
else:
435+
msg = 'Unrecognized compression type: {}'.format(compression)
436+
raise ValueError(msg)
437+
404438
handles.append(f)
405439

406440
elif is_path:
@@ -416,7 +450,7 @@ def _get_handle(path_or_buf, mode, encoding=None, compression=None,
416450
handles.append(f)
417451

418452
# Convert BytesIO or file objects passed with an encoding
419-
if is_text and (compression_method or isinstance(f, need_text_wrapping)):
453+
if is_text and (compression or isinstance(f, need_text_wrapping)):
420454
from io import TextIOWrapper
421455
f = TextIOWrapper(f, encoding=encoding, newline='')
422456
handles.append(f)
@@ -446,15 +480,15 @@ class BytesZipFile(zipfile.ZipFile, BytesIO): # type: ignore
446480
"""
447481
# GH 17778
448482
def __init__(self, file, mode, compression=zipfile.ZIP_DEFLATED,
449-
arcname=None, **kwargs):
483+
arcname: (str, zipfile.ZipInfo) = None, **kwargs):
450484
if mode in ['wb', 'rb']:
451485
mode = mode.replace('b', '')
452486
self.arcname = arcname
453-
super(BytesZipFile, self).__init__(file, mode, compression, **kwargs)
487+
super().__init__(file, mode, compression, **kwargs)
454488

455489
def write(self, data):
456490
arcname = self.filename if self.arcname is None else self.arcname
457-
super(BytesZipFile, self).writestr(arcname, data)
491+
super().writestr(arcname, data)
458492

459493
@property
460494
def closed(self):

pandas/io/formats/csvs.py

+15-17
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import os
99
import warnings
1010
from zipfile import ZipFile
11+
from typing import Dict
1112

1213
import numpy as np
1314

@@ -18,38 +19,32 @@
1819
from pandas.core.dtypes.missing import notna
1920

2021
from pandas.io.common import (
21-
UnicodeWriter, _get_handle, _infer_compression, get_filepath_or_buffer)
22+
UnicodeWriter, _get_handle, _infer_compression, get_filepath_or_buffer,
23+
_get_compression_method)
2224

2325

2426
class CSVFormatter(object):
2527

2628
def __init__(self, obj, path_or_buf=None, sep=",", na_rep='',
2729
float_format=None, cols=None, header=True, index=True,
2830
index_label=None, mode='w', nanRep=None, encoding=None,
29-
compression='infer', quoting=None, line_terminator='\n',
30-
chunksize=None, tupleize_cols=False, quotechar='"',
31-
date_format=None, doublequote=True, escapechar=None,
32-
decimal='.'):
31+
compression: (str, Dict) = 'infer', quoting=None,
32+
line_terminator='\n', chunksize=None, tupleize_cols=False,
33+
quotechar='"', date_format=None, doublequote=True,
34+
escapechar=None, decimal='.'):
3335

3436
self.obj = obj
3537

3638
if path_or_buf is None:
3739
path_or_buf = StringIO()
3840

39-
self._compression_arg = compression
40-
compression_mode = compression
41-
4241
# Extract compression mode as given, if dict
43-
if isinstance(compression, dict):
44-
try:
45-
compression_mode = compression['method']
46-
except KeyError:
47-
raise ValueError("If dict, compression must have key "
48-
"'method'")
42+
compression, self.compression_args \
43+
= _get_compression_method(compression)
4944

5045
self.path_or_buf, _, _, _ = get_filepath_or_buffer(
5146
path_or_buf, encoding=encoding,
52-
compression=compression_mode, mode=mode
47+
compression=compression, mode=mode
5348
)
5449
self.sep = sep
5550
self.na_rep = na_rep
@@ -162,7 +157,8 @@ def save(self):
162157
else:
163158
f, handles = _get_handle(self.path_or_buf, self.mode,
164159
encoding=self.encoding,
165-
compression=self._compression_arg)
160+
compression=dict(self.compression_args,
161+
method=self.compression))
166162
close = True
167163

168164
try:
@@ -186,9 +182,11 @@ def save(self):
186182
if hasattr(self.path_or_buf, 'write'):
187183
self.path_or_buf.write(buf)
188184
else:
185+
compression = dict(self.compression_args,
186+
method=self.compression)
189187
f, handles = _get_handle(self.path_or_buf, self.mode,
190188
encoding=self.encoding,
191-
compression=self._compression_arg)
189+
compression=compression)
192190
f.write(buf)
193191
close = True
194192
if close:

0 commit comments

Comments
 (0)