16
16
from pathlib import Path
17
17
import struct
18
18
import sys
19
- from typing import Any , AnyStr , BinaryIO , Dict , List , Optional , Sequence , Tuple , Union
19
+ from typing import (
20
+ Any ,
21
+ AnyStr ,
22
+ BinaryIO ,
23
+ Dict ,
24
+ List ,
25
+ Mapping ,
26
+ Optional ,
27
+ Sequence ,
28
+ Tuple ,
29
+ Union ,
30
+ )
20
31
import warnings
21
32
22
33
from dateutil .relativedelta import relativedelta
47
58
from pandas .core .indexes .base import Index
48
59
from pandas .core .series import Series
49
60
50
- from pandas .io .common import get_filepath_or_buffer , stringify_path
61
+ from pandas .io .common import (
62
+ get_compression_method ,
63
+ get_filepath_or_buffer ,
64
+ get_handle ,
65
+ infer_compression ,
66
+ stringify_path ,
67
+ )
51
68
52
69
_version_error = (
53
70
"Version of given Stata file is {version}. pandas supports importing "
@@ -1854,13 +1871,18 @@ def read_stata(
1854
1871
return data
1855
1872
1856
1873
1857
- def _open_file_binary_write (fname : FilePathOrBuffer ) -> Tuple [BinaryIO , bool ]:
1874
+ def _open_file_binary_write (
1875
+ fname : FilePathOrBuffer , compression : Union [str , Mapping [str , str ], None ],
1876
+ ) -> Tuple [BinaryIO , bool , Optional [Union [str , Mapping [str , str ]]]]:
1858
1877
"""
1859
1878
Open a binary file or no-op if file-like.
1860
1879
1861
1880
Parameters
1862
1881
----------
1863
1882
fname : string path, path object or buffer
1883
+ The file name or buffer.
1884
+ compression : {str, dict, None}
1885
+ The compression method to use.
1864
1886
1865
1887
Returns
1866
1888
-------
@@ -1871,9 +1893,21 @@ def _open_file_binary_write(fname: FilePathOrBuffer) -> Tuple[BinaryIO, bool]:
1871
1893
"""
1872
1894
if hasattr (fname , "write" ):
1873
1895
# See https://github.com/python/mypy/issues/1424 for hasattr challenges
1874
- return fname , False # type: ignore
1896
+ return fname , False , None # type: ignore
1875
1897
elif isinstance (fname , (str , Path )):
1876
- return open (fname , "wb" ), True
1898
+ # Extract compression mode as given, if dict
1899
+ compression_typ , compression_args = get_compression_method (compression )
1900
+ compression_typ = infer_compression (fname , compression_typ )
1901
+ path_or_buf , _ , compression_typ , _ = get_filepath_or_buffer (
1902
+ fname , compression = compression_typ
1903
+ )
1904
+ if compression_typ is not None :
1905
+ compression = compression_args
1906
+ compression ["method" ] = compression_typ
1907
+ else :
1908
+ compression = None
1909
+ f , _ = get_handle (path_or_buf , "wb" , compression = compression , is_text = False )
1910
+ return f , True , compression
1877
1911
else :
1878
1912
raise TypeError ("fname must be a binary file, buffer or path-like." )
1879
1913
@@ -2050,6 +2084,17 @@ class StataWriter(StataParser):
2050
2084
variable_labels : dict
2051
2085
Dictionary containing columns as keys and variable labels as values.
2052
2086
Each label must be 80 characters or smaller.
2087
+ compression : str or dict, default 'infer'
2088
+ For on-the-fly compression of the output dta. If string, specifies
2089
+ compression mode. If dict, value at key 'method' specifies compression
2090
+ mode. Compression mode must be one of {'infer', 'gzip', 'bz2', 'zip',
2091
+ 'xz', None}. If compression mode is 'infer' and `fname` is path-like,
2092
+ then detect compression from the following extensions: '.gz', '.bz2',
2093
+ '.zip', or '.xz' (otherwise no compression). If dict and compression
2094
+ mode is one of {'zip', 'gzip', 'bz2'}, or inferred as one of the above,
2095
+ other entries passed as additional compression options.
2096
+
2097
+ .. versionadded:: 1.1.0
2053
2098
2054
2099
Returns
2055
2100
-------
@@ -2074,7 +2119,12 @@ class StataWriter(StataParser):
2074
2119
>>> writer = StataWriter('./data_file.dta', data)
2075
2120
>>> writer.write_file()
2076
2121
2077
- Or with dates
2122
+ Directly write a zip file
2123
+ >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
2124
+ >>> writer = StataWriter('./data_file.zip', data, compression=compression)
2125
+ >>> writer.write_file()
2126
+
2127
+ Save a DataFrame with dates
2078
2128
>>> from datetime import datetime
2079
2129
>>> data = pd.DataFrame([[datetime(2000,1,1)]], columns=['date'])
2080
2130
>>> writer = StataWriter('./date_data_file.dta', data, {'date' : 'tw'})
@@ -2094,6 +2144,7 @@ def __init__(
2094
2144
time_stamp : Optional [datetime .datetime ] = None ,
2095
2145
data_label : Optional [str ] = None ,
2096
2146
variable_labels : Optional [Dict [Label , str ]] = None ,
2147
+ compression : Union [str , Mapping [str , str ], None ] = "infer" ,
2097
2148
):
2098
2149
super ().__init__ ()
2099
2150
self ._convert_dates = {} if convert_dates is None else convert_dates
@@ -2102,6 +2153,8 @@ def __init__(
2102
2153
self ._data_label = data_label
2103
2154
self ._variable_labels = variable_labels
2104
2155
self ._own_file = True
2156
+ self ._compression = compression
2157
+ self ._output_file : Optional [BinaryIO ] = None
2105
2158
# attach nobs, nvars, data, varlist, typlist
2106
2159
self ._prepare_pandas (data )
2107
2160
@@ -2389,7 +2442,12 @@ def _encode_strings(self) -> None:
2389
2442
self .data [col ] = encoded
2390
2443
2391
2444
def write_file (self ) -> None :
2392
- self ._file , self ._own_file = _open_file_binary_write (self ._fname )
2445
+ self ._file , self ._own_file , compression = _open_file_binary_write (
2446
+ self ._fname , self ._compression
2447
+ )
2448
+ if compression is not None :
2449
+ self ._output_file = self ._file
2450
+ self ._file = BytesIO ()
2393
2451
try :
2394
2452
self ._write_header (data_label = self ._data_label , time_stamp = self ._time_stamp )
2395
2453
self ._write_map ()
@@ -2434,6 +2492,12 @@ def _close(self) -> None:
2434
2492
"""
2435
2493
# Some file-like objects might not support flush
2436
2494
assert self ._file is not None
2495
+ if self ._output_file is not None :
2496
+ assert isinstance (self ._file , BytesIO )
2497
+ bio = self ._file
2498
+ bio .seek (0 )
2499
+ self ._file = self ._output_file
2500
+ self ._file .write (bio .read ())
2437
2501
try :
2438
2502
self ._file .flush ()
2439
2503
except AttributeError :
@@ -2898,6 +2962,17 @@ class StataWriter117(StataWriter):
2898
2962
Smaller columns can be converted by including the column name. Using
2899
2963
StrLs can reduce output file size when strings are longer than 8
2900
2964
characters, and either frequently repeated or sparse.
2965
+ compression : str or dict, default 'infer'
2966
+ For on-the-fly compression of the output dta. If string, specifies
2967
+ compression mode. If dict, value at key 'method' specifies compression
2968
+ mode. Compression mode must be one of {'infer', 'gzip', 'bz2', 'zip',
2969
+ 'xz', None}. If compression mode is 'infer' and `fname` is path-like,
2970
+ then detect compression from the following extensions: '.gz', '.bz2',
2971
+ '.zip', or '.xz' (otherwise no compression). If dict and compression
2972
+ mode is one of {'zip', 'gzip', 'bz2'}, or inferred as one of the above,
2973
+ other entries passed as additional compression options.
2974
+
2975
+ .. versionadded:: 1.1.0
2901
2976
2902
2977
Returns
2903
2978
-------
@@ -2923,8 +2998,12 @@ class StataWriter117(StataWriter):
2923
2998
>>> writer = StataWriter117('./data_file.dta', data)
2924
2999
>>> writer.write_file()
2925
3000
2926
- Or with long strings stored in strl format
3001
+ Directly write a zip file
3002
+ >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
3003
+ >>> writer = StataWriter117('./data_file.zip', data, compression=compression)
3004
+ >>> writer.write_file()
2927
3005
3006
+ Or with long strings stored in strl format
2928
3007
>>> data = pd.DataFrame([['A relatively long string'], [''], ['']],
2929
3008
... columns=['strls'])
2930
3009
>>> writer = StataWriter117('./data_file_with_long_strings.dta', data,
@@ -2946,6 +3025,7 @@ def __init__(
2946
3025
data_label : Optional [str ] = None ,
2947
3026
variable_labels : Optional [Dict [Label , str ]] = None ,
2948
3027
convert_strl : Optional [Sequence [Label ]] = None ,
3028
+ compression : Union [str , Mapping [str , str ], None ] = "infer" ,
2949
3029
):
2950
3030
# Copy to new list since convert_strl might be modified later
2951
3031
self ._convert_strl : List [Label ] = []
@@ -2961,6 +3041,7 @@ def __init__(
2961
3041
time_stamp = time_stamp ,
2962
3042
data_label = data_label ,
2963
3043
variable_labels = variable_labels ,
3044
+ compression = compression ,
2964
3045
)
2965
3046
self ._map : Dict [str , int ] = {}
2966
3047
self ._strl_blob = b""
@@ -3281,6 +3362,17 @@ class StataWriterUTF8(StataWriter117):
3281
3362
The dta version to use. By default, uses the size of data to determine
3282
3363
the version. 118 is used if data.shape[1] <= 32767, and 119 is used
3283
3364
for storing larger DataFrames.
3365
+ compression : str or dict, default 'infer'
3366
+ For on-the-fly compression of the output dta. If string, specifies
3367
+ compression mode. If dict, value at key 'method' specifies compression
3368
+ mode. Compression mode must be one of {'infer', 'gzip', 'bz2', 'zip',
3369
+ 'xz', None}. If compression mode is 'infer' and `fname` is path-like,
3370
+ then detect compression from the following extensions: '.gz', '.bz2',
3371
+ '.zip', or '.xz' (otherwise no compression). If dict and compression
3372
+ mode is one of {'zip', 'gzip', 'bz2'}, or inferred as one of the above,
3373
+ other entries passed as additional compression options.
3374
+
3375
+ .. versionadded:: 1.1.0
3284
3376
3285
3377
Returns
3286
3378
-------
@@ -3308,6 +3400,11 @@ class StataWriterUTF8(StataWriter117):
3308
3400
>>> writer = StataWriterUTF8('./data_file.dta', data)
3309
3401
>>> writer.write_file()
3310
3402
3403
+ Directly write a zip file
3404
+ >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
3405
+ >>> writer = StataWriterUTF8('./data_file.zip', data, compression=compression)
3406
+ >>> writer.write_file()
3407
+
3311
3408
Or with long strings stored in strl format
3312
3409
3313
3410
>>> data = pd.DataFrame([['ᴀ relatively long ŝtring'], [''], ['']],
@@ -3331,6 +3428,7 @@ def __init__(
3331
3428
variable_labels : Optional [Dict [Label , str ]] = None ,
3332
3429
convert_strl : Optional [Sequence [Label ]] = None ,
3333
3430
version : Optional [int ] = None ,
3431
+ compression : Union [str , Mapping [str , str ], None ] = "infer" ,
3334
3432
):
3335
3433
if version is None :
3336
3434
version = 118 if data .shape [1 ] <= 32767 else 119
@@ -3352,6 +3450,7 @@ def __init__(
3352
3450
data_label = data_label ,
3353
3451
variable_labels = variable_labels ,
3354
3452
convert_strl = convert_strl ,
3453
+ compression = compression ,
3355
3454
)
3356
3455
# Override version set in StataWriter117 init
3357
3456
self ._dta_version = version
0 commit comments