15
15
uses_relative )
16
16
from urllib .request import pathname2url , urlopen
17
17
import zipfile
18
+ from typing import Dict
18
19
19
20
import pandas .compat as compat
20
21
from pandas .errors import ( # noqa
@@ -233,6 +234,39 @@ def file_path_to_url(path):
233
234
}
234
235
235
236
237
+ def _get_compression_method (compression : (str , Dict )):
238
+ """
239
+ Simplifies a compression argument to a compression method string and
240
+ a dict containing additional arguments.
241
+
242
+ Parameters
243
+ ----------
244
+ compression : str or dict
245
+ If string, specifies the compression method. If dict, value at key
246
+ 'method' specifies compression method.
247
+
248
+ Returns
249
+ -------
250
+ tuple of ({compression method}, str
251
+ {compression arguments}, dict)
252
+
253
+ Raises
254
+ ------
255
+ ValueError on dict missing 'method' key
256
+ """
257
+ compression_args = {}
258
+ # Handle dict
259
+ if isinstance (compression , dict ):
260
+ compression_args = compression .copy ()
261
+ try :
262
+ compression = compression ['method' ]
263
+ compression_args .pop ('method' )
264
+ except KeyError :
265
+ raise ValueError ("If dict, compression "
266
+ "must have key 'method'" )
267
+ return compression , compression_args
268
+
269
+
236
270
def _infer_compression (filepath_or_buffer , compression ):
237
271
"""
238
272
Get the compression method for filepath_or_buffer. If compression mode is
@@ -266,13 +300,8 @@ def _infer_compression(filepath_or_buffer, compression):
266
300
ValueError on invalid compression specified
267
301
"""
268
302
269
- # Handle compression method as dict
270
- if isinstance (compression , dict ):
271
- try :
272
- compression = compression ['method' ]
273
- except KeyError :
274
- raise ValueError ("Compression dict must have key "
275
- "'method'" )
303
+ # Handle compression as dict
304
+ compression , _ = _get_compression_method (compression )
276
305
277
306
# No compression has been explicitly specified
278
307
if compression is None :
@@ -355,31 +384,31 @@ def _get_handle(path_or_buf, mode, encoding=None, compression=None,
355
384
path_or_buf = _stringify_path (path_or_buf )
356
385
is_path = isinstance (path_or_buf , str )
357
386
358
- compression_method = None
387
+ compression , compression_args = _get_compression_method ( compression )
359
388
if is_path :
360
- compression_method = _infer_compression (path_or_buf , compression )
389
+ compression = _infer_compression (path_or_buf , compression )
361
390
362
- if compression_method :
391
+ if compression :
363
392
364
393
# GZ Compression
365
- if compression_method == 'gzip' :
394
+ if compression == 'gzip' :
366
395
if is_path :
367
396
f = gzip .open (path_or_buf , mode )
368
397
else :
369
398
f = gzip .GzipFile (fileobj = path_or_buf )
370
399
371
400
# BZ Compression
372
- elif compression_method == 'bz2' :
401
+ elif compression == 'bz2' :
373
402
if is_path :
374
403
f = bz2 .BZ2File (path_or_buf , mode )
375
404
else :
376
405
f = bz2 .BZ2File (path_or_buf )
377
406
378
407
# ZIP Compression
379
- elif compression_method == 'zip' :
408
+ elif compression == 'zip' :
380
409
arcname = None
381
- if isinstance ( compression , dict ) and 'arcname' in compression :
382
- arcname = compression ['arcname' ]
410
+ if 'arcname' in compression_args :
411
+ arcname = compression_args ['arcname' ]
383
412
zf = BytesZipFile (path_or_buf , mode , arcname = arcname )
384
413
# Ensure the container is closed as well.
385
414
handles .append (zf )
@@ -398,9 +427,14 @@ def _get_handle(path_or_buf, mode, encoding=None, compression=None,
398
427
.format (zip_names ))
399
428
400
429
# XZ Compression
401
- elif compression_method == 'xz' :
430
+ elif compression == 'xz' :
402
431
f = lzma .LZMAFile (path_or_buf , mode )
403
432
433
+ # Unrecognized Compression
434
+ else :
435
+ msg = 'Unrecognized compression type: {}' .format (compression )
436
+ raise ValueError (msg )
437
+
404
438
handles .append (f )
405
439
406
440
elif is_path :
@@ -416,7 +450,7 @@ def _get_handle(path_or_buf, mode, encoding=None, compression=None,
416
450
handles .append (f )
417
451
418
452
# Convert BytesIO or file objects passed with an encoding
419
- if is_text and (compression_method or isinstance (f , need_text_wrapping )):
453
+ if is_text and (compression or isinstance (f , need_text_wrapping )):
420
454
from io import TextIOWrapper
421
455
f = TextIOWrapper (f , encoding = encoding , newline = '' )
422
456
handles .append (f )
@@ -446,15 +480,15 @@ class BytesZipFile(zipfile.ZipFile, BytesIO): # type: ignore
446
480
"""
447
481
# GH 17778
448
482
def __init__ (self , file , mode , compression = zipfile .ZIP_DEFLATED ,
449
- arcname = None , ** kwargs ):
483
+ arcname : ( str , zipfile . ZipInfo ) = None , ** kwargs ):
450
484
if mode in ['wb' , 'rb' ]:
451
485
mode = mode .replace ('b' , '' )
452
486
self .arcname = arcname
453
- super (BytesZipFile , self ).__init__ (file , mode , compression , ** kwargs )
487
+ super ().__init__ (file , mode , compression , ** kwargs )
454
488
455
489
def write (self , data ):
456
490
arcname = self .filename if self .arcname is None else self .arcname
457
- super (BytesZipFile , self ).writestr (arcname , data )
491
+ super ().writestr (arcname , data )
458
492
459
493
@property
460
494
def closed (self ):
0 commit comments