
Commit 42a8f82

twoertwein authored and Kevin D Smith committed
BUG/ENH: to_pickle/read_pickle support compression for file objects (pandas-dev#35736)
1 parent 0f28ef7 commit 42a8f82

File tree

10 files changed: +61 -80 lines changed


doc/source/whatsnew/v1.2.0.rst (+1)

@@ -295,6 +295,7 @@ I/O
 - :meth:`to_csv` passes compression arguments for `'gzip'` always to `gzip.GzipFile` (:issue:`28103`)
 - :meth:`to_csv` did not support zip compression for binary file object not having a filename (:issue: `35058`)
 - :meth:`to_csv` and :meth:`read_csv` did not honor `compression` and `encoding` for path-like objects that are internally converted to file-like objects (:issue:`35677`, :issue:`26124`, and :issue:`32392`)
+- :meth:`to_pickle` and :meth:`read_pickle` did not support compression for file objects (:issue:`26237`, :issue:`29054`, and :issue:`29570`)

 Plotting
 ^^^^^^^^

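The whatsnew entry above is the user-visible effect of this commit. A minimal sketch of the fixed behavior, assuming pandas >= 1.2: an explicit compression is now honored when the target is a binary file object rather than a filesystem path.

import io

import pandas as pd

df = pd.DataFrame({"a": [1, 2, 3]})

buffer = io.BytesIO()
df.to_pickle(buffer, compression="gzip")  # previously ignored for file objects
buffer.seek(0)

roundtripped = pd.read_pickle(buffer, compression="gzip")
assert df.equals(roundtripped)
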
pandas/_typing.py (+2 -2)

@@ -116,7 +116,7 @@


 # compression keywords and compression
-CompressionDict = Mapping[str, Optional[Union[str, int, bool]]]
+CompressionDict = Dict[str, Any]
 CompressionOptions = Optional[Union[str, CompressionDict]]


@@ -138,6 +138,6 @@ class IOargs(Generic[ModeVar, EncodingVar]):

     filepath_or_buffer: FileOrBuffer
     encoding: EncodingVar
-    compression: CompressionOptions
+    compression: CompressionDict
     should_close: bool
     mode: Union[ModeVar, str]

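With CompressionDict widened to Dict[str, Any], the public CompressionOptions alias covers both ways callers spell compression. An illustrative sketch of the two accepted shapes (the output file names are placeholders): either a plain method string, or a dict whose "method" key names the compressor and whose remaining keys are forwarded to it.

import pandas as pd

df = pd.DataFrame({"a": [1, 2, 3]})

# string form: just name the method
df.to_csv("out.csv.gz", compression="gzip")

# dict form: method plus extra arguments for the compressor
df.to_csv("out.csv.zip", compression={"method": "zip", "archive_name": "out.csv"})
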
pandas/core/frame.py (+2 -2)

@@ -27,7 +27,6 @@
     Iterable,
     Iterator,
     List,
-    Mapping,
     Optional,
     Sequence,
     Set,
@@ -49,6 +48,7 @@
     ArrayLike,
     Axes,
     Axis,
+    CompressionOptions,
     Dtype,
     FilePathOrBuffer,
     FrameOrSeriesUnion,
@@ -2062,7 +2062,7 @@ def to_stata(
         variable_labels: Optional[Dict[Label, str]] = None,
         version: Optional[int] = 114,
         convert_strl: Optional[Sequence[Label]] = None,
-        compression: Union[str, Mapping[str, str], None] = "infer",
+        compression: CompressionOptions = "infer",
         storage_options: StorageOptions = None,
     ) -> None:
         """

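DataFrame.to_stata now reuses the shared CompressionOptions alias instead of spelling the union inline, so the accepted values are unchanged. A brief sketch (the output path is a placeholder; compression support in to_stata assumes pandas >= 1.1):

import pandas as pd

df = pd.DataFrame({"x": [1.0, 2.0, 3.0]})

df.to_stata("out.dta.gz", compression="gzip")
df.to_stata("out.dta.gz", compression={"method": "gzip"})
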
pandas/io/common.py (+9 -15)

@@ -205,11 +205,13 @@ def get_filepath_or_buffer(
     """
     filepath_or_buffer = stringify_path(filepath_or_buffer)

+    # handle compression dict
+    compression_method, compression = get_compression_method(compression)
+    compression_method = infer_compression(filepath_or_buffer, compression_method)
+    compression = dict(compression, method=compression_method)
+
     # bz2 and xz do not write the byte order mark for utf-16 and utf-32
     # print a warning when writing such files
-    compression_method = infer_compression(
-        filepath_or_buffer, get_compression_method(compression)[0]
-    )
     if (
         mode
         and "w" in mode
@@ -238,7 +240,7 @@ def get_filepath_or_buffer(
         content_encoding = req.headers.get("Content-Encoding", None)
         if content_encoding == "gzip":
             # Override compression based on Content-Encoding header
-            compression = "gzip"
+            compression = {"method": "gzip"}
         reader = BytesIO(req.read())
         req.close()
         return IOargs(
@@ -374,11 +376,7 @@ def get_compression_method(
     if isinstance(compression, Mapping):
         compression_args = dict(compression)
         try:
-            # error: Incompatible types in assignment (expression has type
-            # "Union[str, int, None]", variable has type "Optional[str]")
-            compression_method = compression_args.pop(  # type: ignore[assignment]
-                "method"
-            )
+            compression_method = compression_args.pop("method")
         except KeyError as err:
             raise ValueError("If mapping, compression must have key 'method'") from err
     else:
@@ -652,12 +650,8 @@ def __init__(
         super().__init__(file, mode, **kwargs_zip)  # type: ignore[arg-type]

     def write(self, data):
-        archive_name = self.filename
-        if self.archive_name is not None:
-            archive_name = self.archive_name
-        if archive_name is None:
-            # ZipFile needs a non-empty string
-            archive_name = "zip"
+        # ZipFile needs a non-empty string
+        archive_name = self.archive_name or self.filename or "zip"
         super().writestr(archive_name, data)

     @property

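The key refactor is in get_filepath_or_buffer: every compression argument is normalized up front into a dict with a resolved "method" key, so downstream writers (csv, json, pickle, stata) no longer repeat that logic. Below is a self-contained, hypothetical mirror of the idea rather than the pandas internals themselves; the function and variable names are illustrative, and only the ".gz" suffix is inferred here for brevity.

from typing import Any, Dict, Optional, Union


def normalize_compression(
    compression: Optional[Union[str, Dict[str, Any]]], path: Any
) -> Dict[str, Any]:
    # split a dict argument into its method and the remaining compressor kwargs
    if isinstance(compression, dict):
        args = dict(compression)
        try:
            method = args.pop("method")
        except KeyError as err:
            raise ValueError("If mapping, compression must have key 'method'") from err
    else:
        method, args = compression, {}
    # resolve "infer" from the path; file objects carry no extension to inspect
    if method == "infer":
        method = "gzip" if isinstance(path, str) and path.endswith(".gz") else None
    return dict(args, method=method)


print(normalize_compression("infer", "data.csv.gz"))           # {'method': 'gzip'}
print(normalize_compression({"method": "zip"}, "archive.zip"))  # {'method': 'zip'}
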
pandas/io/formats/csvs.py (+4 -11)

@@ -21,12 +21,7 @@
 )
 from pandas.core.dtypes.missing import notna

-from pandas.io.common import (
-    get_compression_method,
-    get_filepath_or_buffer,
-    get_handle,
-    infer_compression,
-)
+from pandas.io.common import get_filepath_or_buffer, get_handle


 class CSVFormatter:
@@ -60,17 +55,15 @@ def __init__(
         if path_or_buf is None:
             path_or_buf = StringIO()

-        # Extract compression mode as given, if dict
-        compression, self.compression_args = get_compression_method(compression)
-        self.compression = infer_compression(path_or_buf, compression)
-
         ioargs = get_filepath_or_buffer(
             path_or_buf,
             encoding=encoding,
-            compression=self.compression,
+            compression=compression,
             mode=mode,
             storage_options=storage_options,
         )
+        self.compression = ioargs.compression.pop("method")
+        self.compression_args = ioargs.compression
         self.path_or_buf = ioargs.filepath_or_buffer
         self.should_close = ioargs.should_close
         self.mode = ioargs.mode

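Because IOargs.compression is now always the normalized dict, CSVFormatter can split it back into the method string it stores plus the leftover keyword arguments for the compressor. A tiny illustrative sketch of that split (the dict literal is hypothetical):

# what IOargs.compression might hold after normalization
compression = {"method": "gzip", "mtime": 0}

method = compression.pop("method")  # "gzip", kept as self.compression
compression_args = compression      # {"mtime": 0}, kept as self.compression_args

assert method == "gzip" and compression_args == {"mtime": 0}
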
pandas/io/json/_json.py (+2 -9)

@@ -19,12 +19,7 @@
 from pandas.core.construction import create_series_with_explicit_dtype
 from pandas.core.reshape.concat import concat

-from pandas.io.common import (
-    get_compression_method,
-    get_filepath_or_buffer,
-    get_handle,
-    infer_compression,
-)
+from pandas.io.common import get_compression_method, get_filepath_or_buffer, get_handle
 from pandas.io.json._normalize import convert_to_line_delimits
 from pandas.io.json._table_schema import build_table_schema, parse_table_schema
 from pandas.io.parsers import _validate_integer
@@ -66,6 +61,7 @@ def to_json(
     )
     path_or_buf = ioargs.filepath_or_buffer
     should_close = ioargs.should_close
+    compression = ioargs.compression

     if lines and orient != "records":
         raise ValueError("'lines' keyword only valid when 'orient' is records")
@@ -616,9 +612,6 @@ def read_json(
     if encoding is None:
         encoding = "utf-8"

-    compression_method, compression = get_compression_method(compression)
-    compression_method = infer_compression(path_or_buf, compression_method)
-    compression = dict(compression, method=compression_method)
     ioargs = get_filepath_or_buffer(
         path_or_buf,
         encoding=encoding,

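On the JSON side the change is purely plumbing: to_json now takes the normalized dict back out of IOargs, and read_json no longer normalizes compression by hand. User-facing behavior stays as in this sketch (the file name is a placeholder):

import pandas as pd

df = pd.DataFrame({"a": [1, 2]})

df.to_json("out.json.gz", compression="gzip")
assert pd.read_json("out.json.gz", compression="infer").equals(df)
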
pandas/io/parsers.py (+5 -8)

@@ -63,12 +63,7 @@
 from pandas.core.series import Series
 from pandas.core.tools import datetimes as tools

-from pandas.io.common import (
-    get_filepath_or_buffer,
-    get_handle,
-    infer_compression,
-    validate_header_arg,
-)
+from pandas.io.common import get_filepath_or_buffer, get_handle, validate_header_arg
 from pandas.io.date_converters import generic_parser

 # BOM character (byte order mark)
@@ -424,9 +419,7 @@ def _read(filepath_or_buffer: FilePathOrBuffer, kwds):
     if encoding is not None:
         encoding = re.sub("_", "-", encoding).lower()
         kwds["encoding"] = encoding
-
     compression = kwds.get("compression", "infer")
-    compression = infer_compression(filepath_or_buffer, compression)

     # TODO: get_filepath_or_buffer could return
     # Union[FilePathOrBuffer, s3fs.S3File, gcsfs.GCSFile]
@@ -1976,6 +1969,10 @@ def __init__(self, src, **kwds):

         encoding = kwds.get("encoding")

+        # parsers.TextReader doesn't support compression dicts
+        if isinstance(kwds.get("compression"), dict):
+            kwds["compression"] = kwds["compression"]["method"]
+
         if kwds.get("compression") is None and encoding:
             if isinstance(src, str):
                 src = open(src, "rb")

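The C parser's TextReader only understands a plain method string, hence the small shim above. A standalone illustration of what it does to the keyword dict (the values here are hypothetical):

kwds = {"compression": {"method": "gzip", "compresslevel": 5}}

# reduce the normalized dict to the bare method name for TextReader
if isinstance(kwds.get("compression"), dict):
    kwds["compression"] = kwds["compression"]["method"]

assert kwds["compression"] == "gzip"
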
pandas/io/pickle.py (+2 -8)

@@ -92,11 +92,8 @@ def to_pickle(
         mode="wb",
         storage_options=storage_options,
     )
-    compression = ioargs.compression
-    if not isinstance(ioargs.filepath_or_buffer, str) and compression == "infer":
-        compression = None
     f, fh = get_handle(
-        ioargs.filepath_or_buffer, "wb", compression=compression, is_text=False
+        ioargs.filepath_or_buffer, "wb", compression=ioargs.compression, is_text=False
    )
     if protocol < 0:
         protocol = pickle.HIGHEST_PROTOCOL
@@ -196,11 +193,8 @@ def read_pickle(
     ioargs = get_filepath_or_buffer(
         filepath_or_buffer, compression=compression, storage_options=storage_options
     )
-    compression = ioargs.compression
-    if not isinstance(ioargs.filepath_or_buffer, str) and compression == "infer":
-        compression = None
     f, fh = get_handle(
-        ioargs.filepath_or_buffer, "rb", compression=compression, is_text=False
+        ioargs.filepath_or_buffer, "rb", compression=ioargs.compression, is_text=False
     )

     # 1) try standard library Pickle

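With the "infer on a buffer means no compression" special case removed, get_handle receives the normalized dict directly, so an explicit compression on a file object is actually applied. A sketch that checks the written bytes really are gzip-compressed (assumes pandas >= 1.2):

import gzip
import io

import pandas as pd

df = pd.DataFrame({"a": range(5)})

buf = io.BytesIO()
df.to_pickle(buf, compression="gzip")

# decompressing the buffer yields a plain pickle that round-trips the frame
raw = gzip.decompress(buf.getvalue())
assert pd.read_pickle(io.BytesIO(raw), compression=None).equals(df)
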
pandas/io/stata.py (+5 -25)

@@ -16,18 +16,7 @@
 from pathlib import Path
 import struct
 import sys
-from typing import (
-    Any,
-    AnyStr,
-    BinaryIO,
-    Dict,
-    List,
-    Mapping,
-    Optional,
-    Sequence,
-    Tuple,
-    Union,
-)
+from typing import Any, AnyStr, BinaryIO, Dict, List, Optional, Sequence, Tuple, Union
 import warnings

 from dateutil.relativedelta import relativedelta
@@ -58,13 +47,7 @@
 from pandas.core.indexes.base import Index
 from pandas.core.series import Series

-from pandas.io.common import (
-    get_compression_method,
-    get_filepath_or_buffer,
-    get_handle,
-    infer_compression,
-    stringify_path,
-)
+from pandas.io.common import get_filepath_or_buffer, get_handle, stringify_path

 _version_error = (
     "Version of given Stata file is {version}. pandas supports importing "
@@ -1976,9 +1959,6 @@ def _open_file_binary_write(
         return fname, False, None  # type: ignore[return-value]
     elif isinstance(fname, (str, Path)):
         # Extract compression mode as given, if dict
-        compression_typ, compression_args = get_compression_method(compression)
-        compression_typ = infer_compression(fname, compression_typ)
-        compression = dict(compression_args, method=compression_typ)
         ioargs = get_filepath_or_buffer(
             fname, mode="wb", compression=compression, storage_options=storage_options
         )
@@ -2235,7 +2215,7 @@ def __init__(
         time_stamp: Optional[datetime.datetime] = None,
         data_label: Optional[str] = None,
         variable_labels: Optional[Dict[Label, str]] = None,
-        compression: Union[str, Mapping[str, str], None] = "infer",
+        compression: CompressionOptions = "infer",
         storage_options: StorageOptions = None,
     ):
         super().__init__()
@@ -3118,7 +3098,7 @@ def __init__(
         data_label: Optional[str] = None,
         variable_labels: Optional[Dict[Label, str]] = None,
         convert_strl: Optional[Sequence[Label]] = None,
-        compression: Union[str, Mapping[str, str], None] = "infer",
+        compression: CompressionOptions = "infer",
         storage_options: StorageOptions = None,
     ):
         # Copy to new list since convert_strl might be modified later
@@ -3523,7 +3503,7 @@ def __init__(
         variable_labels: Optional[Dict[Label, str]] = None,
         convert_strl: Optional[Sequence[Label]] = None,
         version: Optional[int] = None,
-        compression: Union[str, Mapping[str, str], None] = "infer",
+        compression: CompressionOptions = "infer",
         storage_options: StorageOptions = None,
     ):
         if version is None:

pandas/tests/io/test_pickle.py (+29)

@@ -14,7 +14,9 @@
 import datetime
 import glob
 import gzip
+import io
 import os
+from pathlib import Path
 import pickle
 import shutil
 from warnings import catch_warnings, simplefilter
@@ -486,3 +488,30 @@ def test_read_pickle_with_subclass():

     tm.assert_series_equal(result[0], expected[0])
     assert isinstance(result[1], MyTz)
+
+
+def test_pickle_binary_object_compression(compression):
+    """
+    Read/write from binary file-objects w/wo compression.
+
+    GH 26237, GH 29054, and GH 29570
+    """
+    df = tm.makeDataFrame()
+
+    # reference for compression
+    with tm.ensure_clean() as path:
+        df.to_pickle(path, compression=compression)
+        reference = Path(path).read_bytes()
+
+    # write
+    buffer = io.BytesIO()
+    df.to_pickle(buffer, compression=compression)
+    buffer.seek(0)
+
+    # gzip and zip store the filename: cannot compare the compressed content
+    assert buffer.getvalue() == reference or compression in ("gzip", "zip")
+
+    # read
+    read_df = pd.read_pickle(buffer, compression=compression)
+    buffer.seek(0)
+    tm.assert_frame_equal(df, read_df)

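A note on the new test: the bare compression parameter is presumably supplied by pandas' shared pytest fixture (typically parametrized over None, "gzip", "bz2", "zip" and "xz"), so this single function exercises both the uncompressed and compressed paths through the binary buffer. Locally it can be run with, for example:

pytest pandas/tests/io/test_pickle.py::test_pickle_binary_object_compression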