16
16
from pathlib import Path
17
17
import struct
18
18
import sys
19
- from typing import Any , AnyStr , BinaryIO , Dict , List , Optional , Sequence , Tuple , Union
19
+ from typing import (
20
+ Any ,
21
+ AnyStr ,
22
+ BinaryIO ,
23
+ Dict ,
24
+ List ,
25
+ Optional ,
26
+ Sequence ,
27
+ Tuple ,
28
+ Union ,
29
+ Mapping ,
30
+ )
20
31
import warnings
21
32
22
33
from dateutil .relativedelta import relativedelta
47
58
from pandas .core .indexes .base import Index
48
59
from pandas .core .series import Series
49
60
50
- from pandas .io .common import get_filepath_or_buffer , stringify_path
61
+ from pandas .io .common import (
62
+ get_compression_method ,
63
+ get_filepath_or_buffer ,
64
+ stringify_path ,
65
+ get_handle ,
66
+ infer_compression ,
67
+ )
51
68
52
69
_version_error = (
53
70
"Version of given Stata file is {version}. pandas supports importing "
@@ -1854,7 +1871,9 @@ def read_stata(
1854
1871
return data
1855
1872
1856
1873
1857
- def _open_file_binary_write (fname : FilePathOrBuffer ) -> Tuple [BinaryIO , bool ]:
1874
+ def _open_file_binary_write (
1875
+ fname : FilePathOrBuffer , compression : Optional [str ]
1876
+ ) -> Tuple [BinaryIO , bool , Optional [Union [str , Mapping [str , str ]]]]:
1858
1877
"""
1859
1878
Open a binary file or no-op if file-like.
1860
1879
@@ -1871,9 +1890,15 @@ def _open_file_binary_write(fname: FilePathOrBuffer) -> Tuple[BinaryIO, bool]:
1871
1890
"""
1872
1891
if hasattr (fname , "write" ):
1873
1892
# See https://github.com/python/mypy/issues/1424 for hasattr challenges
1874
- return fname , False # type: ignore
1893
+ return fname , False , None # type: ignore
1875
1894
elif isinstance (fname , (str , Path )):
1876
- return open (fname , "wb" ), True
1895
+ # Extract compression mode as given, if dict
1896
+ compression = infer_compression (fname , compression )
1897
+ path_or_buf , _ , compression , _ = get_filepath_or_buffer (
1898
+ fname , compression = compression
1899
+ )
1900
+ f , _ = get_handle (path_or_buf , "wb" , compression = compression , is_text = False )
1901
+ return f , True , compression
1877
1902
else :
1878
1903
raise TypeError ("fname must be a binary file, buffer or path-like." )
1879
1904
@@ -2050,6 +2075,13 @@ class StataWriter(StataParser):
2050
2075
variable_labels : dict
2051
2076
Dictionary containing columns as keys and variable labels as values.
2052
2077
Each label must be 80 characters or smaller.
2078
+ compression : {'infer', 'gzip', 'bz2', 'zip', 'xz', None}, default 'infer'
2079
+ For on-the-fly compression of the output dta. If 'infer', then use
2080
+ gzip, bz2, zip or xz if path_or_buf is a string ending in
2081
+ '.gz', '.bz2', '.zip', or 'xz', respectively, and no compression
2082
+ otherwise.
2083
+
2084
+ .. versionadded:: 1.2.0
2053
2085
2054
2086
Returns
2055
2087
-------
@@ -2094,6 +2126,7 @@ def __init__(
2094
2126
time_stamp : Optional [datetime .datetime ] = None ,
2095
2127
data_label : Optional [str ] = None ,
2096
2128
variable_labels : Optional [Dict [Label , str ]] = None ,
2129
+ compression : Optional [str ] = "infer" ,
2097
2130
):
2098
2131
super ().__init__ ()
2099
2132
self ._convert_dates = {} if convert_dates is None else convert_dates
@@ -2102,6 +2135,8 @@ def __init__(
2102
2135
self ._data_label = data_label
2103
2136
self ._variable_labels = variable_labels
2104
2137
self ._own_file = True
2138
+ self ._compression = compression
2139
+ self ._output_file : Optional [BinaryIO ] = None
2105
2140
# attach nobs, nvars, data, varlist, typlist
2106
2141
self ._prepare_pandas (data )
2107
2142
@@ -2389,7 +2424,12 @@ def _encode_strings(self) -> None:
2389
2424
self .data [col ] = encoded
2390
2425
2391
2426
def write_file (self ) -> None :
2392
- self ._file , self ._own_file = _open_file_binary_write (self ._fname )
2427
+ self ._file , self ._own_file , compression = _open_file_binary_write (
2428
+ self ._fname , self ._compression
2429
+ )
2430
+ if compression is not None :
2431
+ self ._output_file = self ._file
2432
+ self ._file = BytesIO ()
2393
2433
try :
2394
2434
self ._write_header (data_label = self ._data_label , time_stamp = self ._time_stamp )
2395
2435
self ._write_map ()
@@ -2434,6 +2474,12 @@ def _close(self) -> None:
2434
2474
"""
2435
2475
# Some file-like objects might not support flush
2436
2476
assert self ._file is not None
2477
+ if self ._output_file is not None :
2478
+ assert isinstance (self ._file , BytesIO )
2479
+ bio = self ._file
2480
+ bio .seek (0 )
2481
+ self ._file = self ._output_file
2482
+ self ._file .write (bio .read ())
2437
2483
try :
2438
2484
self ._file .flush ()
2439
2485
except AttributeError :
@@ -2898,6 +2944,13 @@ class StataWriter117(StataWriter):
2898
2944
Smaller columns can be converted by including the column name. Using
2899
2945
StrLs can reduce output file size when strings are longer than 8
2900
2946
characters, and either frequently repeated or sparse.
2947
+ compression : {'infer', 'gzip', 'bz2', 'zip', 'xz', None}, default 'infer'
2948
+ For on-the-fly compression of the output dta. If 'infer', then use
2949
+ gzip, bz2, zip or xz if path_or_buf is a string ending in
2950
+ '.gz', '.bz2', '.zip', or 'xz', respectively, and no compression
2951
+ otherwise.
2952
+
2953
+ .. versionadded:: 1.2.0
2901
2954
2902
2955
Returns
2903
2956
-------
@@ -2946,6 +2999,7 @@ def __init__(
2946
2999
data_label : Optional [str ] = None ,
2947
3000
variable_labels : Optional [Dict [Label , str ]] = None ,
2948
3001
convert_strl : Optional [Sequence [Label ]] = None ,
3002
+ compression : Optional [str ] = "infer" ,
2949
3003
):
2950
3004
# Copy to new list since convert_strl might be modified later
2951
3005
self ._convert_strl : List [Label ] = []
@@ -2961,6 +3015,7 @@ def __init__(
2961
3015
time_stamp = time_stamp ,
2962
3016
data_label = data_label ,
2963
3017
variable_labels = variable_labels ,
3018
+ compression = compression ,
2964
3019
)
2965
3020
self ._map : Dict [str , int ] = {}
2966
3021
self ._strl_blob = b""
@@ -3281,6 +3336,13 @@ class StataWriterUTF8(StataWriter117):
3281
3336
The dta version to use. By default, uses the size of data to determine
3282
3337
the version. 118 is used if data.shape[1] <= 32767, and 119 is used
3283
3338
for storing larger DataFrames.
3339
+ compression : {'infer', 'gzip', 'bz2', 'zip', 'xz', None}, default 'infer'
3340
+ For on-the-fly compression of the output dta. If 'infer', then use
3341
+ gzip, bz2, zip or xz if path_or_buf is a string ending in
3342
+ '.gz', '.bz2', '.zip', or 'xz', respectively, and no compression
3343
+ otherwise.
3344
+
3345
+ .. versionadded:: 1.2.0
3284
3346
3285
3347
Returns
3286
3348
-------
@@ -3331,6 +3393,7 @@ def __init__(
3331
3393
variable_labels : Optional [Dict [Label , str ]] = None ,
3332
3394
convert_strl : Optional [Sequence [Label ]] = None ,
3333
3395
version : Optional [int ] = None ,
3396
+ compression : Optional [str ] = "infer" ,
3334
3397
):
3335
3398
if version is None :
3336
3399
version = 118 if data .shape [1 ] <= 32767 else 119
@@ -3352,6 +3415,7 @@ def __init__(
3352
3415
data_label = data_label ,
3353
3416
variable_labels = variable_labels ,
3354
3417
convert_strl = convert_strl ,
3418
+ compression = compression ,
3355
3419
)
3356
3420
# Override version set in StataWriter117 init
3357
3421
self ._dta_version = version
0 commit comments