Skip to content

Commit f1c87bf

Browse files
committed
ENH: Add compression to stata exporters
Add standard compression options to stata exporters; closes #26599
1 parent ebb727e commit f1c87bf

File tree

3 files changed

+110
-6
lines changed

3 files changed

+110
-6
lines changed

pandas/core/frame.py

+11
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Iterable,
2626
Iterator,
2727
List,
28+
Mapping,
2829
Optional,
2930
Sequence,
3031
Set,
@@ -1975,6 +1976,7 @@ def to_stata(
19751976
variable_labels: Optional[Dict[Label, str]] = None,
19761977
version: Optional[int] = 114,
19771978
convert_strl: Optional[Sequence[Label]] = None,
1979+
compression: Optional[str] = "infer",
19781980
) -> None:
19791981
"""
19801982
Export DataFrame object to Stata dta format.
@@ -2038,6 +2040,14 @@ def to_stata(
20382040
20392041
.. versionadded:: 0.23.0
20402042
2043+
compression : {'infer', 'gzip', 'bz2', 'zip', 'xz', None}, default 'infer'
2044+
For on-the-fly compression of the output dta. If 'infer', then use
2045+
gzip, bz2, zip or xz if path_or_buf is a string ending in
2046+
'.gz', '.bz2', '.zip', or '.xz', respectively, and no compression
2047+
otherwise.
2048+
2049+
.. versionadded:: 1.2.0
2050+
20412051
Raises
20422052
------
20432053
NotImplementedError
@@ -2093,6 +2103,7 @@ def to_stata(
20932103
data_label=data_label,
20942104
write_index=write_index,
20952105
variable_labels=variable_labels,
2106+
compression=compression,
20962107
**kwargs,
20972108
)
20982109
writer.write_file()

pandas/io/stata.py

+70-6
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,18 @@
1616
from pathlib import Path
1717
import struct
1818
import sys
19-
from typing import Any, AnyStr, BinaryIO, Dict, List, Optional, Sequence, Tuple, Union
19+
from typing import (
20+
Any,
21+
AnyStr,
22+
BinaryIO,
23+
Dict,
24+
List,
25+
Optional,
26+
Sequence,
27+
Tuple,
28+
Union,
29+
Mapping,
30+
)
2031
import warnings
2132

2233
from dateutil.relativedelta import relativedelta
@@ -47,7 +58,13 @@
4758
from pandas.core.indexes.base import Index
4859
from pandas.core.series import Series
4960

50-
from pandas.io.common import get_filepath_or_buffer, stringify_path
61+
from pandas.io.common import (
62+
get_compression_method,
63+
get_filepath_or_buffer,
64+
stringify_path,
65+
get_handle,
66+
infer_compression,
67+
)
5168

5269
_version_error = (
5370
"Version of given Stata file is {version}. pandas supports importing "
@@ -1854,7 +1871,9 @@ def read_stata(
18541871
return data
18551872

18561873

1857-
def _open_file_binary_write(fname: FilePathOrBuffer) -> Tuple[BinaryIO, bool]:
1874+
def _open_file_binary_write(
1875+
fname: FilePathOrBuffer, compression: Optional[str]
1876+
) -> Tuple[BinaryIO, bool, Optional[Union[str, Mapping[str, str]]]]:
18581877
"""
18591878
Open a binary file or no-op if file-like.
18601879
@@ -1871,9 +1890,15 @@ def _open_file_binary_write(fname: FilePathOrBuffer) -> Tuple[BinaryIO, bool]:
18711890
"""
18721891
if hasattr(fname, "write"):
18731892
# See https://github.com/python/mypy/issues/1424 for hasattr challenges
1874-
return fname, False # type: ignore
1893+
return fname, False, None # type: ignore
18751894
elif isinstance(fname, (str, Path)):
1876-
return open(fname, "wb"), True
1895+
# Resolve the requested compression (e.g. "infer") against the file name
1896+
compression = infer_compression(fname, compression)
1897+
path_or_buf, _, compression, _ = get_filepath_or_buffer(
1898+
fname, compression=compression
1899+
)
1900+
f, _ = get_handle(path_or_buf, "wb", compression=compression, is_text=False)
1901+
return f, True, compression
18771902
else:
18781903
raise TypeError("fname must be a binary file, buffer or path-like.")
18791904

@@ -2050,6 +2075,13 @@ class StataWriter(StataParser):
20502075
variable_labels : dict
20512076
Dictionary containing columns as keys and variable labels as values.
20522077
Each label must be 80 characters or smaller.
2078+
compression : {'infer', 'gzip', 'bz2', 'zip', 'xz', None}, default 'infer'
2079+
For on-the-fly compression of the output dta. If 'infer', then use
2080+
gzip, bz2, zip or xz if path_or_buf is a string ending in
2081+
'.gz', '.bz2', '.zip', or '.xz', respectively, and no compression
2082+
otherwise.
2083+
2084+
.. versionadded:: 1.2.0
20532085
20542086
Returns
20552087
-------
@@ -2094,6 +2126,7 @@ def __init__(
20942126
time_stamp: Optional[datetime.datetime] = None,
20952127
data_label: Optional[str] = None,
20962128
variable_labels: Optional[Dict[Label, str]] = None,
2129+
compression: Optional[str] = "infer",
20972130
):
20982131
super().__init__()
20992132
self._convert_dates = {} if convert_dates is None else convert_dates
@@ -2102,6 +2135,8 @@ def __init__(
21022135
self._data_label = data_label
21032136
self._variable_labels = variable_labels
21042137
self._own_file = True
2138+
self._compression = compression
2139+
self._output_file: Optional[BinaryIO] = None
21052140
# attach nobs, nvars, data, varlist, typlist
21062141
self._prepare_pandas(data)
21072142

@@ -2389,7 +2424,12 @@ def _encode_strings(self) -> None:
23892424
self.data[col] = encoded
23902425

23912426
def write_file(self) -> None:
2392-
self._file, self._own_file = _open_file_binary_write(self._fname)
2427+
self._file, self._own_file, compression = _open_file_binary_write(
2428+
self._fname, self._compression
2429+
)
2430+
if compression is not None:
2431+
self._output_file = self._file
2432+
self._file = BytesIO()
23932433
try:
23942434
self._write_header(data_label=self._data_label, time_stamp=self._time_stamp)
23952435
self._write_map()
@@ -2434,6 +2474,12 @@ def _close(self) -> None:
24342474
"""
24352475
# Some file-like objects might not support flush
24362476
assert self._file is not None
2477+
if self._output_file is not None:
2478+
assert isinstance(self._file, BytesIO)
2479+
bio = self._file
2480+
bio.seek(0)
2481+
self._file = self._output_file
2482+
self._file.write(bio.read())
24372483
try:
24382484
self._file.flush()
24392485
except AttributeError:
@@ -2898,6 +2944,13 @@ class StataWriter117(StataWriter):
28982944
Smaller columns can be converted by including the column name. Using
28992945
StrLs can reduce output file size when strings are longer than 8
29002946
characters, and either frequently repeated or sparse.
2947+
compression : {'infer', 'gzip', 'bz2', 'zip', 'xz', None}, default 'infer'
2948+
For on-the-fly compression of the output dta. If 'infer', then use
2949+
gzip, bz2, zip or xz if path_or_buf is a string ending in
2950+
'.gz', '.bz2', '.zip', or '.xz', respectively, and no compression
2951+
otherwise.
2952+
2953+
.. versionadded:: 1.2.0
29012954
29022955
Returns
29032956
-------
@@ -2946,6 +2999,7 @@ def __init__(
29462999
data_label: Optional[str] = None,
29473000
variable_labels: Optional[Dict[Label, str]] = None,
29483001
convert_strl: Optional[Sequence[Label]] = None,
3002+
compression: Optional[str] = "infer",
29493003
):
29503004
# Copy to new list since convert_strl might be modified later
29513005
self._convert_strl: List[Label] = []
@@ -2961,6 +3015,7 @@ def __init__(
29613015
time_stamp=time_stamp,
29623016
data_label=data_label,
29633017
variable_labels=variable_labels,
3018+
compression=compression,
29643019
)
29653020
self._map: Dict[str, int] = {}
29663021
self._strl_blob = b""
@@ -3281,6 +3336,13 @@ class StataWriterUTF8(StataWriter117):
32813336
The dta version to use. By default, uses the size of data to determine
32823337
the version. 118 is used if data.shape[1] <= 32767, and 119 is used
32833338
for storing larger DataFrames.
3339+
compression : {'infer', 'gzip', 'bz2', 'zip', 'xz', None}, default 'infer'
3340+
For on-the-fly compression of the output dta. If 'infer', then use
3341+
gzip, bz2, zip or xz if path_or_buf is a string ending in
3342+
'.gz', '.bz2', '.zip', or '.xz', respectively, and no compression
3343+
otherwise.
3344+
3345+
.. versionadded:: 1.2.0
32843346
32853347
Returns
32863348
-------
@@ -3331,6 +3393,7 @@ def __init__(
33313393
variable_labels: Optional[Dict[Label, str]] = None,
33323394
convert_strl: Optional[Sequence[Label]] = None,
33333395
version: Optional[int] = None,
3396+
compression: Optional[str] = "infer",
33343397
):
33353398
if version is None:
33363399
version = 118 if data.shape[1] <= 32767 else 119
@@ -3352,6 +3415,7 @@ def __init__(
33523415
data_label=data_label,
33533416
variable_labels=variable_labels,
33543417
convert_strl=convert_strl,
3418+
compression=compression,
33553419
)
33563420
# Override version set in StataWriter117 init
33573421
self._dta_version = version

pandas/tests/io/test_stata.py

+29
Original file line numberDiff line numberDiff line change
@@ -1853,3 +1853,32 @@ def test_writer_118_exceptions(self):
18531853
with tm.ensure_clean() as path:
18541854
with pytest.raises(ValueError, match="You must use version 119"):
18551855
StataWriterUTF8(path, df, version=118)
1856+
1857+
1858+
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
@pytest.mark.parametrize("file_type", ["", "zip", "gz"])
def test_infer_compression(file_type, version):
    """
    Round-trip a DataFrame through ``to_stata`` with inferred compression.

    The output file name ends in ``.dta``, ``.dta.zip`` or ``.dta.gz``; no
    ``compression`` argument is passed, so the writer has to pick the
    method from the extension. Compressed outputs are decompressed here with
    the standard library (not by pandas) before being handed to
    ``read_stata``, so the test checks the bytes on disk really are in the
    expected container format.
    """
    file_name = "dta_inferred_compression.dta"
    if file_type:
        file_name += f".{file_type}"

    df = DataFrame(np.random.randn(10, 2), columns=list("AB"))
    df.index.name = "index"

    with tm.ensure_clean(file_name) as path:
        df.to_stata(path, version=version)

        if file_type == "gz":
            import gzip

            with gzip.open(path, "rb") as comp:
                reread = read_stata(comp, index_col="index")
        elif file_type == "zip":
            import zipfile

            # Use the archive as a context manager so the handle is always
            # closed, even if reading or parsing fails.
            with zipfile.ZipFile(path, "r") as zf:
                names = zf.namelist()
                # Exactly one member is expected; failing here is clearer
                # than silently keeping the last member or hitting an
                # unbound ``reread`` on an empty archive.
                assert len(names) == 1
                payload = zf.read(names[0])
            reread = read_stata(io.BytesIO(payload), index_col="index")
        else:
            # No compression
            reread = read_stata(path, index_col="index")

        tm.assert_frame_equal(df, reread)

0 commit comments

Comments
 (0)