16
16
from pathlib import Path
17
17
import struct
18
18
import sys
19
- from typing import Any , AnyStr , BinaryIO , Dict , List , Optional , Sequence , Tuple , Union
19
+ from typing import (
20
+ Any ,
21
+ AnyStr ,
22
+ BinaryIO ,
23
+ Dict ,
24
+ List ,
25
+ Optional ,
26
+ Sequence ,
27
+ Tuple ,
28
+ Union ,
29
+ Mapping ,
30
+ )
20
31
import warnings
21
32
22
33
from dateutil .relativedelta import relativedelta
47
58
from pandas .core .indexes .base import Index
48
59
from pandas .core .series import Series
49
60
50
- from pandas .io .common import get_filepath_or_buffer , stringify_path
61
+ from pandas .io .common import (
62
+ get_compression_method ,
63
+ get_filepath_or_buffer ,
64
+ stringify_path ,
65
+ get_handle ,
66
+ infer_compression ,
67
+ )
51
68
52
69
_version_error = (
53
70
"Version of given Stata file is {version}. pandas supports importing "
@@ -1854,7 +1871,9 @@ def read_stata(
1854
1871
return data
1855
1872
1856
1873
1857
- def _open_file_binary_write (fname : FilePathOrBuffer ) -> Tuple [BinaryIO , bool ]:
1874
+ def _open_file_binary_write (
1875
+ fname : FilePathOrBuffer , compression : Optional [str ]
1876
+ ) -> Tuple [BinaryIO , bool , Optional [Union [str , Mapping [str , str ]]]]:
1858
1877
"""
1859
1878
Open a binary file or no-op if file-like.
1860
1879
@@ -1871,9 +1890,15 @@ def _open_file_binary_write(fname: FilePathOrBuffer) -> Tuple[BinaryIO, bool]:
1871
1890
"""
1872
1891
if hasattr (fname , "write" ):
1873
1892
# See https://github.com/python/mypy/issues/1424 for hasattr challenges
1874
- return fname , False # type: ignore
1893
+ return fname , False , None # type: ignore
1875
1894
elif isinstance (fname , (str , Path )):
1876
- return open (fname , "wb" ), True
1895
+ # Extract compression mode as given, if dict
1896
+ compression = infer_compression (fname , compression )
1897
+ path_or_buf , _ , compression , _ = get_filepath_or_buffer (
1898
+ fname , compression = compression
1899
+ )
1900
+ f , _ = get_handle (path_or_buf , "wb" , compression = compression , is_text = False )
1901
+ return f , True , compression
1877
1902
else :
1878
1903
raise TypeError ("fname must be a binary file, buffer or path-like." )
1879
1904
@@ -2050,6 +2075,13 @@ class StataWriter(StataParser):
2050
2075
variable_labels : dict
2051
2076
Dictionary containing columns as keys and variable labels as values.
2052
2077
Each label must be 80 characters or smaller.
2078
+ compression : {'infer', 'gzip', 'bz2', 'zip', 'xz', None}, default 'infer'
2079
+ For on-the-fly compression of the output dta. If 'infer', then use
2080
+ gzip, bz2, zip or xz if path_or_buf is a string ending in
2081
+ '.gz', '.bz2', '.zip', or 'xz', respectively, and no compression
2082
+ otherwise.
2083
+
2084
+ .. versionadded:: 1.2.0
2053
2085
2054
2086
Returns
2055
2087
-------
@@ -2094,6 +2126,7 @@ def __init__(
2094
2126
time_stamp : Optional [datetime .datetime ] = None ,
2095
2127
data_label : Optional [str ] = None ,
2096
2128
variable_labels : Optional [Dict [Label , str ]] = None ,
2129
+ compression : Optional [str ] = "infer" ,
2097
2130
):
2098
2131
super ().__init__ ()
2099
2132
self ._convert_dates = {} if convert_dates is None else convert_dates
@@ -2102,6 +2135,8 @@ def __init__(
2102
2135
self ._data_label = data_label
2103
2136
self ._variable_labels = variable_labels
2104
2137
self ._own_file = True
2138
+ self ._compression = compression
2139
+ self ._output_file : Optional [BinaryIO ] = None
2105
2140
# attach nobs, nvars, data, varlist, typlist
2106
2141
self ._prepare_pandas (data )
2107
2142
@@ -2389,7 +2424,12 @@ def _encode_strings(self) -> None:
2389
2424
self .data [col ] = encoded
2390
2425
2391
2426
def write_file (self ) -> None :
2392
- self ._file , self ._own_file = _open_file_binary_write (self ._fname )
2427
+ self ._file , self ._own_file , compression = _open_file_binary_write (
2428
+ self ._fname , self ._compression
2429
+ )
2430
+ if compression is not None :
2431
+ self ._output_file = self ._file
2432
+ self ._file = BytesIO ()
2393
2433
try :
2394
2434
self ._write_header (data_label = self ._data_label , time_stamp = self ._time_stamp )
2395
2435
self ._write_map ()
@@ -2434,6 +2474,12 @@ def _close(self) -> None:
2434
2474
"""
2435
2475
# Some file-like objects might not support flush
2436
2476
assert self ._file is not None
2477
+ if self ._output_file is not None :
2478
+ assert isinstance (self ._file , BytesIO )
2479
+ bio = self ._file
2480
+ bio .seek (0 )
2481
+ self ._file = self ._output_file
2482
+ self ._file .write (bio .read ())
2437
2483
try :
2438
2484
self ._file .flush ()
2439
2485
except AttributeError :
@@ -2946,6 +2992,7 @@ def __init__(
2946
2992
data_label : Optional [str ] = None ,
2947
2993
variable_labels : Optional [Dict [Label , str ]] = None ,
2948
2994
convert_strl : Optional [Sequence [Label ]] = None ,
2995
+ compression : Optional [str ] = "infer" ,
2949
2996
):
2950
2997
# Copy to new list since convert_strl might be modified later
2951
2998
self ._convert_strl : List [Label ] = []
@@ -2961,6 +3008,7 @@ def __init__(
2961
3008
time_stamp = time_stamp ,
2962
3009
data_label = data_label ,
2963
3010
variable_labels = variable_labels ,
3011
+ compression = compression ,
2964
3012
)
2965
3013
self ._map : Dict [str , int ] = {}
2966
3014
self ._strl_blob = b""
@@ -3331,6 +3379,7 @@ def __init__(
3331
3379
variable_labels : Optional [Dict [Label , str ]] = None ,
3332
3380
convert_strl : Optional [Sequence [Label ]] = None ,
3333
3381
version : Optional [int ] = None ,
3382
+ compression : Optional [str ] = "infer" ,
3334
3383
):
3335
3384
if version is None :
3336
3385
version = 118 if data .shape [1 ] <= 32767 else 119
@@ -3352,6 +3401,7 @@ def __init__(
3352
3401
data_label = data_label ,
3353
3402
variable_labels = variable_labels ,
3354
3403
convert_strl = convert_strl ,
3404
+ compression = compression ,
3355
3405
)
3356
3406
# Override version set in StataWriter117 init
3357
3407
self ._dta_version = version
0 commit comments