
Commit efa6ec5

ENH: Add compression to stata writers
Add compression
1 parent 6be51cb commit efa6ec5

File tree

4 files changed: +189 -9 lines

doc/source/whatsnew/v1.1.0.rst

+6 -1
@@ -226,7 +226,12 @@ Other enhancements
 - :meth:`~pandas.core.resample.Resampler.interpolate` now supports SciPy interpolation method :class:`scipy.interpolate.CubicSpline` as method ``cubicspline`` (:issue:`33670`)
 - The ``ExtensionArray`` class has now an :meth:`~pandas.arrays.ExtensionArray.equals`
   method, similarly to :meth:`Series.equals` (:issue:`27081`).
--
+- :meth:`~pandas.core.frame.DataFrame.to_stata` supports compression using the ``compression``
+  keyword argument. Compression can either be inferred or explicitly set using a string or a
+  dictionary containing both the method and any additional arguments that are passed to the
+  compression library. Compression was also added to the low-level Stata-file writers
+  :class:`~pandas.io.stata.StataWriter`, :class:`~pandas.io.stata.StataWriter117`,
+  and :class:`~pandas.io.stata.StataWriterUTF8` (:issue:`26599`).

 .. ---------------------------------------------------------------------------
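
For reference, a minimal usage sketch of the keyword described above (file names are illustrative, not part of the commit):

    import pandas as pd

    df = pd.DataFrame({"a": range(3)})

    # Compression inferred from the ".gz" extension.
    df.to_stata("data.dta.gz")

    # Explicit dict form; extra keys are forwarded to the compression library.
    df.to_stata("data.zip", compression={"method": "zip", "archive_name": "data.dta"})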

pandas/core/frame.py

+16
@@ -25,6 +25,7 @@
     Iterable,
     Iterator,
     List,
+    Mapping,
     Optional,
     Sequence,
     Set,
@@ -1970,6 +1971,7 @@ def to_stata(
         variable_labels: Optional[Dict[Label, str]] = None,
         version: Optional[int] = 114,
         convert_strl: Optional[Sequence[Label]] = None,
+        compression: Union[str, Mapping[str, str], None] = "infer",
     ) -> None:
         """
         Export DataFrame object to Stata dta format.
@@ -2033,6 +2035,19 @@ def to_stata(

             .. versionadded:: 0.23.0

+        compression : str or dict, default 'infer'
+            For on-the-fly compression of the output dta. If string, specifies
+            compression mode. If dict, value at key 'method' specifies
+            compression mode. Compression mode must be one of {'infer', 'gzip',
+            'bz2', 'zip', 'xz', None}. If compression mode is 'infer' and
+            `fname` is path-like, then detect compression from the following
+            extensions: '.gz', '.bz2', '.zip', or '.xz' (otherwise no
+            compression). If dict and compression mode is one of {'zip',
+            'gzip', 'bz2'}, or inferred as one of the above, other entries
+            passed as additional compression options.
+
+            .. versionadded:: 1.1.0
+
         Raises
         ------
         NotImplementedError
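
The 'infer' rule in the docstring above maps file extensions to compression methods; a minimal sketch of that mapping (illustrative only, not the pandas implementation):

    # Extension-to-method mapping implied by the docstring (sketch only).
    _EXT_TO_METHOD = {".gz": "gzip", ".bz2": "bz2", ".zip": "zip", ".xz": "xz"}

    def infer_method(fname: str):
        for ext, method in _EXT_TO_METHOD.items():
            if fname.endswith(ext):
                return method
        return None  # no compression

    assert infer_method("data.dta.gz") == "gzip"
    assert infer_method("data.dta") is None
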
@@ -2088,6 +2103,7 @@ def to_stata(
             data_label=data_label,
             write_index=write_index,
             variable_labels=variable_labels,
+            compression=compression,
             **kwargs,
         )
         writer.write_file()

pandas/io/stata.py

+107 -8
@@ -16,7 +16,18 @@
 from pathlib import Path
 import struct
 import sys
-from typing import Any, AnyStr, BinaryIO, Dict, List, Optional, Sequence, Tuple, Union
+from typing import (
+    Any,
+    AnyStr,
+    BinaryIO,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 import warnings

 from dateutil.relativedelta import relativedelta
@@ -47,7 +58,13 @@
 from pandas.core.indexes.base import Index
 from pandas.core.series import Series

-from pandas.io.common import get_filepath_or_buffer, stringify_path
+from pandas.io.common import (
+    get_compression_method,
+    get_filepath_or_buffer,
+    get_handle,
+    infer_compression,
+    stringify_path,
+)

 _version_error = (
     "Version of given Stata file is {version}. pandas supports importing "
@@ -1854,13 +1871,18 @@ def read_stata(
     return data


-def _open_file_binary_write(fname: FilePathOrBuffer) -> Tuple[BinaryIO, bool]:
+def _open_file_binary_write(
+    fname: FilePathOrBuffer, compression: Union[str, Mapping[str, str], None],
+) -> Tuple[BinaryIO, bool, Optional[Union[str, Mapping[str, str]]]]:
     """
     Open a binary file or no-op if file-like.

     Parameters
     ----------
     fname : string path, path object or buffer
+        The file name or buffer.
+    compression : {str, dict, None}
+        The compression method to use.

     Returns
     -------
@@ -1871,9 +1893,21 @@ def _open_file_binary_write(fname: FilePathOrBuffer) -> Tuple[BinaryIO, bool]:
     """
     if hasattr(fname, "write"):
         # See https://github.com/python/mypy/issues/1424 for hasattr challenges
-        return fname, False  # type: ignore
+        return fname, False, None  # type: ignore
     elif isinstance(fname, (str, Path)):
-        return open(fname, "wb"), True
+        # Extract compression mode as given, if dict
+        compression_typ, compression_args = get_compression_method(compression)
+        compression_typ = infer_compression(fname, compression_typ)
+        path_or_buf, _, compression_typ, _ = get_filepath_or_buffer(
+            fname, compression=compression_typ
+        )
+        if compression_typ is not None:
+            compression = compression_args
+            compression["method"] = compression_typ
+        else:
+            compression = None
+        f, _ = get_handle(path_or_buf, "wb", compression=compression, is_text=False)
+        return f, True, compression
     else:
         raise TypeError("fname must be a binary file, buffer or path-like.")
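
A short sketch of how the two pandas.io.common helpers used above resolve a user-supplied compression argument (internal helpers; values shown are illustrative):

    from pandas.io.common import get_compression_method, infer_compression

    # Split a dict argument into the method plus any extra options.
    method, extra_args = get_compression_method({"method": "infer", "archive_name": "data.dta"})
    # method == "infer", extra_args == {"archive_name": "data.dta"}

    # Resolve "infer" against the target path's extension.
    method = infer_compression("data.dta.zip", method)
    # method == "zip"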

@@ -2050,6 +2084,17 @@ class StataWriter(StataParser):
     variable_labels : dict
         Dictionary containing columns as keys and variable labels as values.
         Each label must be 80 characters or smaller.
+    compression : str or dict, default 'infer'
+        For on-the-fly compression of the output dta. If string, specifies
+        compression mode. If dict, value at key 'method' specifies compression
+        mode. Compression mode must be one of {'infer', 'gzip', 'bz2', 'zip',
+        'xz', None}. If compression mode is 'infer' and `fname` is path-like,
+        then detect compression from the following extensions: '.gz', '.bz2',
+        '.zip', or '.xz' (otherwise no compression). If dict and compression
+        mode is one of {'zip', 'gzip', 'bz2'}, or inferred as one of the above,
+        other entries passed as additional compression options.
+
+        .. versionadded:: 1.1.0

     Returns
     -------
@@ -2074,7 +2119,12 @@ class StataWriter(StataParser):
     >>> writer = StataWriter('./data_file.dta', data)
     >>> writer.write_file()

-    Or with dates
+    Directly write a zip file
+    >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
+    >>> writer = StataWriter('./data_file.zip', data, compression=compression)
+    >>> writer.write_file()
+
+    Save a DataFrame with dates
     >>> from datetime import datetime
     >>> data = pd.DataFrame([[datetime(2000,1,1)]], columns=['date'])
     >>> writer = StataWriter('./date_data_file.dta', data, {'date' : 'tw'})
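
A further usage sketch (not part of the diff) showing the string form of the same keyword with the low-level writer; the path is illustrative:

    import pandas as pd
    from pandas.io.stata import StataWriter

    data = pd.DataFrame([[1.0, 1]], columns=["a", "b"])

    # A plain string selects the compression method directly.
    writer = StataWriter("./data_file.dta.gz", data, compression="gzip")
    writer.write_file()
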
@@ -2094,6 +2144,7 @@ def __init__(
         time_stamp: Optional[datetime.datetime] = None,
         data_label: Optional[str] = None,
         variable_labels: Optional[Dict[Label, str]] = None,
+        compression: Union[str, Mapping[str, str], None] = "infer",
     ):
         super().__init__()
         self._convert_dates = {} if convert_dates is None else convert_dates
@@ -2102,6 +2153,8 @@ def __init__(
         self._data_label = data_label
         self._variable_labels = variable_labels
         self._own_file = True
+        self._compression = compression
+        self._output_file: Optional[BinaryIO] = None
         # attach nobs, nvars, data, varlist, typlist
         self._prepare_pandas(data)

@@ -2389,7 +2442,12 @@ def _encode_strings(self) -> None:
             self.data[col] = encoded

     def write_file(self) -> None:
-        self._file, self._own_file = _open_file_binary_write(self._fname)
+        self._file, self._own_file, compression = _open_file_binary_write(
+            self._fname, self._compression
+        )
+        if compression is not None:
+            self._output_file = self._file
+            self._file = BytesIO()
         try:
             self._write_header(data_label=self._data_label, time_stamp=self._time_stamp)
             self._write_map()
@@ -2434,6 +2492,12 @@ def _close(self) -> None:
         """
         # Some file-like objects might not support flush
        assert self._file is not None
+        if self._output_file is not None:
+            assert isinstance(self._file, BytesIO)
+            bio = self._file
+            bio.seek(0)
+            self._file = self._output_file
+            self._file.write(bio.read())
         try:
             self._file.flush()
         except AttributeError:
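
The write_file and _close hunks above implement a buffer-then-copy pattern; a standalone sketch of the same idea, independent of the pandas internals:

    import gzip
    import io

    # Sketch only: when compression is active, write the payload to an
    # in-memory buffer first, then copy it into the compressed handle at close.
    buffer = io.BytesIO()
    buffer.write(b"fake dta bytes")  # stands in for the Stata writer output

    buffer.seek(0)
    with gzip.open("data.dta.gz", "wb") as handle:  # illustrative path
        handle.write(buffer.read())
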
@@ -2898,6 +2962,17 @@ class StataWriter117(StataWriter):
         Smaller columns can be converted by including the column name. Using
         StrLs can reduce output file size when strings are longer than 8
         characters, and either frequently repeated or sparse.
+    compression : str or dict, default 'infer'
+        For on-the-fly compression of the output dta. If string, specifies
+        compression mode. If dict, value at key 'method' specifies compression
+        mode. Compression mode must be one of {'infer', 'gzip', 'bz2', 'zip',
+        'xz', None}. If compression mode is 'infer' and `fname` is path-like,
+        then detect compression from the following extensions: '.gz', '.bz2',
+        '.zip', or '.xz' (otherwise no compression). If dict and compression
+        mode is one of {'zip', 'gzip', 'bz2'}, or inferred as one of the above,
+        other entries passed as additional compression options.
+
+        .. versionadded:: 1.1.0

     Returns
     -------
@@ -2923,8 +2998,12 @@ class StataWriter117(StataWriter):
     >>> writer = StataWriter117('./data_file.dta', data)
     >>> writer.write_file()

-    Or with long strings stored in strl format
+    Directly write a zip file
+    >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
+    >>> writer = StataWriter117('./data_file.zip', data, compression=compression)
+    >>> writer.write_file()

+    Or with long strings stored in strl format
     >>> data = pd.DataFrame([['A relatively long string'], [''], ['']],
     ...                     columns=['strls'])
     >>> writer = StataWriter117('./data_file_with_long_strings.dta', data,
@@ -2946,6 +3025,7 @@ def __init__(
         data_label: Optional[str] = None,
         variable_labels: Optional[Dict[Label, str]] = None,
         convert_strl: Optional[Sequence[Label]] = None,
+        compression: Union[str, Mapping[str, str], None] = "infer",
     ):
         # Copy to new list since convert_strl might be modified later
         self._convert_strl: List[Label] = []
@@ -2961,6 +3041,7 @@ def __init__(
             time_stamp=time_stamp,
             data_label=data_label,
             variable_labels=variable_labels,
+            compression=compression,
         )
         self._map: Dict[str, int] = {}
         self._strl_blob = b""
@@ -3281,6 +3362,17 @@ class StataWriterUTF8(StataWriter117):
         The dta version to use. By default, uses the size of data to determine
         the version. 118 is used if data.shape[1] <= 32767, and 119 is used
         for storing larger DataFrames.
+    compression : str or dict, default 'infer'
+        For on-the-fly compression of the output dta. If string, specifies
+        compression mode. If dict, value at key 'method' specifies compression
+        mode. Compression mode must be one of {'infer', 'gzip', 'bz2', 'zip',
+        'xz', None}. If compression mode is 'infer' and `fname` is path-like,
+        then detect compression from the following extensions: '.gz', '.bz2',
+        '.zip', or '.xz' (otherwise no compression). If dict and compression
+        mode is one of {'zip', 'gzip', 'bz2'}, or inferred as one of the above,
+        other entries passed as additional compression options.
+
+        .. versionadded:: 1.1.0

     Returns
     -------
@@ -3308,6 +3400,11 @@ class StataWriterUTF8(StataWriter117):
     >>> writer = StataWriterUTF8('./data_file.dta', data)
     >>> writer.write_file()

+    Directly write a zip file
+    >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
+    >>> writer = StataWriterUTF8('./data_file.zip', data, compression=compression)
+    >>> writer.write_file()
+
     Or with long strings stored in strl format

     >>> data = pd.DataFrame([['ᴀ relatively long ŝtring'], [''], ['']],
33313428
variable_labels: Optional[Dict[Label, str]] = None,
33323429
convert_strl: Optional[Sequence[Label]] = None,
33333430
version: Optional[int] = None,
3431+
compression: Union[str, Mapping[str, str], None] = "infer",
33343432
):
33353433
if version is None:
33363434
version = 118 if data.shape[1] <= 32767 else 119
@@ -3352,6 +3450,7 @@ def __init__(
33523450
data_label=data_label,
33533451
variable_labels=variable_labels,
33543452
convert_strl=convert_strl,
3453+
compression=compression,
33553454
)
33563455
# Override version set in StataWriter117 init
33573456
self._dta_version = version

pandas/tests/io/test_stata.py

+60
@@ -1,10 +1,13 @@
+import bz2
 import datetime as dt
 from datetime import datetime
 import gzip
 import io
+import lzma
 import os
 import struct
 import warnings
+import zipfile

 import numpy as np
 import pytest
@@ -1853,3 +1856,60 @@ def test_writer_118_exceptions(self):
         with tm.ensure_clean() as path:
             with pytest.raises(ValueError, match="You must use version 119"):
                 StataWriterUTF8(path, df, version=118)
+
+
+@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
+@pytest.mark.parametrize("use_dict", [True, False])
+@pytest.mark.parametrize("infer", [True, False])
+def test_compression(compression, version, use_dict, infer):
+    file_name = "dta_inferred_compression.dta"
+    if compression:
+        file_ext = "gz" if compression == "gzip" and not use_dict else compression
+        file_name += f".{file_ext}"
+    compression_arg = compression
+    if infer:
+        compression_arg = "infer"
+    if use_dict:
+        compression_arg = {"method": compression}
+
+    df = DataFrame(np.random.randn(10, 2), columns=list("AB"))
+    df.index.name = "index"
+    with tm.ensure_clean(file_name) as path:
+        df.to_stata(path, version=version, compression=compression_arg)
+        if compression == "gzip":
+            with gzip.open(path, "rb") as comp:
+                fp = io.BytesIO(comp.read())
+        elif compression == "zip":
+            with zipfile.ZipFile(path, "r") as comp:
+                fp = io.BytesIO(comp.read(comp.filelist[0]))
+        elif compression == "bz2":
+            with bz2.open(path, "rb") as comp:
+                fp = io.BytesIO(comp.read())
+        elif compression == "xz":
+            with lzma.open(path, "rb") as comp:
+                fp = io.BytesIO(comp.read())
+        elif compression is None:
+            fp = path
+        reread = read_stata(fp, index_col="index")
+        tm.assert_frame_equal(reread, df)
+
+
+@pytest.mark.parametrize("method", ["zip", "infer"])
+@pytest.mark.parametrize("file_ext", [None, "dta", "zip"])
+def test_compression_dict(method, file_ext):
+    file_name = f"test.{file_ext}"
+    archive_name = "test.dta"
+    df = DataFrame(np.random.randn(10, 2), columns=list("AB"))
+    df.index.name = "index"
+    with tm.ensure_clean(file_name) as path:
+        compression = {"method": method, "archive_name": archive_name}
+        df.to_stata(path, compression=compression)
+        if method == "zip" or file_ext == "zip":
+            zp = zipfile.ZipFile(path, "r")
+            assert len(zp.filelist) == 1
+            assert zp.filelist[0].filename == archive_name
+            fp = io.BytesIO(zp.read(zp.filelist[0]))
+        else:
+            fp = path
+        reread = read_stata(fp, index_col="index")
+        tm.assert_frame_equal(reread, df)
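
To run just the new tests from a pandas checkout (a hedged sketch of the usual pytest invocation):

    import pytest

    # The -k expression matches both test_compression and test_compression_dict.
    pytest.main(["pandas/tests/io/test_stata.py", "-k", "compression", "-q"])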
