diff --git a/doc/source/whatsnew/v1.2.0.rst b/doc/source/whatsnew/v1.2.0.rst index 8b28a4439e1da..44dd5ba122acd 100644 --- a/doc/source/whatsnew/v1.2.0.rst +++ b/doc/source/whatsnew/v1.2.0.rst @@ -295,6 +295,7 @@ I/O - :meth:`to_csv` passes compression arguments for `'gzip'` always to `gzip.GzipFile` (:issue:`28103`) - :meth:`to_csv` did not support zip compression for binary file object not having a filename (:issue: `35058`) - :meth:`to_csv` and :meth:`read_csv` did not honor `compression` and `encoding` for path-like objects that are internally converted to file-like objects (:issue:`35677`, :issue:`26124`, and :issue:`32392`) +- :meth:`to_picke` and :meth:`read_pickle` did not support compression for file-objects (:issue:`26237`, :issue:`29054`, and :issue:`29570`) Plotting ^^^^^^^^ diff --git a/pandas/_typing.py b/pandas/_typing.py index 74bfc9134c3af..b237013ac7805 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -116,7 +116,7 @@ # compression keywords and compression -CompressionDict = Mapping[str, Optional[Union[str, int, bool]]] +CompressionDict = Dict[str, Any] CompressionOptions = Optional[Union[str, CompressionDict]] @@ -138,6 +138,6 @@ class IOargs(Generic[ModeVar, EncodingVar]): filepath_or_buffer: FileOrBuffer encoding: EncodingVar - compression: CompressionOptions + compression: CompressionDict should_close: bool mode: Union[ModeVar, str] diff --git a/pandas/core/frame.py b/pandas/core/frame.py index c48bec9b670ad..1713743b98bff 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -27,7 +27,6 @@ Iterable, Iterator, List, - Mapping, Optional, Sequence, Set, @@ -49,6 +48,7 @@ ArrayLike, Axes, Axis, + CompressionOptions, Dtype, FilePathOrBuffer, FrameOrSeriesUnion, @@ -2062,7 +2062,7 @@ def to_stata( variable_labels: Optional[Dict[Label, str]] = None, version: Optional[int] = 114, convert_strl: Optional[Sequence[Label]] = None, - compression: Union[str, Mapping[str, str], None] = "infer", + compression: CompressionOptions = "infer", storage_options: StorageOptions = None, ) -> None: """ diff --git a/pandas/io/common.py b/pandas/io/common.py index 2b13d54ec3aed..a80b89569f429 100644 --- a/pandas/io/common.py +++ b/pandas/io/common.py @@ -205,11 +205,13 @@ def get_filepath_or_buffer( """ filepath_or_buffer = stringify_path(filepath_or_buffer) + # handle compression dict + compression_method, compression = get_compression_method(compression) + compression_method = infer_compression(filepath_or_buffer, compression_method) + compression = dict(compression, method=compression_method) + # bz2 and xz do not write the byte order mark for utf-16 and utf-32 # print a warning when writing such files - compression_method = infer_compression( - filepath_or_buffer, get_compression_method(compression)[0] - ) if ( mode and "w" in mode @@ -238,7 +240,7 @@ def get_filepath_or_buffer( content_encoding = req.headers.get("Content-Encoding", None) if content_encoding == "gzip": # Override compression based on Content-Encoding header - compression = "gzip" + compression = {"method": "gzip"} reader = BytesIO(req.read()) req.close() return IOargs( @@ -374,11 +376,7 @@ def get_compression_method( if isinstance(compression, Mapping): compression_args = dict(compression) try: - # error: Incompatible types in assignment (expression has type - # "Union[str, int, None]", variable has type "Optional[str]") - compression_method = compression_args.pop( # type: ignore[assignment] - "method" - ) + compression_method = compression_args.pop("method") except KeyError as err: raise ValueError("If mapping, compression must have key 'method'") from err else: @@ -652,12 +650,8 @@ def __init__( super().__init__(file, mode, **kwargs_zip) # type: ignore[arg-type] def write(self, data): - archive_name = self.filename - if self.archive_name is not None: - archive_name = self.archive_name - if archive_name is None: - # ZipFile needs a non-empty string - archive_name = "zip" + # ZipFile needs a non-empty string + archive_name = self.archive_name or self.filename or "zip" super().writestr(archive_name, data) @property diff --git a/pandas/io/formats/csvs.py b/pandas/io/formats/csvs.py index 270caec022fef..15cd5c026c6b6 100644 --- a/pandas/io/formats/csvs.py +++ b/pandas/io/formats/csvs.py @@ -21,12 +21,7 @@ ) from pandas.core.dtypes.missing import notna -from pandas.io.common import ( - get_compression_method, - get_filepath_or_buffer, - get_handle, - infer_compression, -) +from pandas.io.common import get_filepath_or_buffer, get_handle class CSVFormatter: @@ -60,17 +55,15 @@ def __init__( if path_or_buf is None: path_or_buf = StringIO() - # Extract compression mode as given, if dict - compression, self.compression_args = get_compression_method(compression) - self.compression = infer_compression(path_or_buf, compression) - ioargs = get_filepath_or_buffer( path_or_buf, encoding=encoding, - compression=self.compression, + compression=compression, mode=mode, storage_options=storage_options, ) + self.compression = ioargs.compression.pop("method") + self.compression_args = ioargs.compression self.path_or_buf = ioargs.filepath_or_buffer self.should_close = ioargs.should_close self.mode = ioargs.mode diff --git a/pandas/io/json/_json.py b/pandas/io/json/_json.py index 7a3b76ff7e3d0..a4d923fdbe45a 100644 --- a/pandas/io/json/_json.py +++ b/pandas/io/json/_json.py @@ -19,12 +19,7 @@ from pandas.core.construction import create_series_with_explicit_dtype from pandas.core.reshape.concat import concat -from pandas.io.common import ( - get_compression_method, - get_filepath_or_buffer, - get_handle, - infer_compression, -) +from pandas.io.common import get_compression_method, get_filepath_or_buffer, get_handle from pandas.io.json._normalize import convert_to_line_delimits from pandas.io.json._table_schema import build_table_schema, parse_table_schema from pandas.io.parsers import _validate_integer @@ -66,6 +61,7 @@ def to_json( ) path_or_buf = ioargs.filepath_or_buffer should_close = ioargs.should_close + compression = ioargs.compression if lines and orient != "records": raise ValueError("'lines' keyword only valid when 'orient' is records") @@ -616,9 +612,6 @@ def read_json( if encoding is None: encoding = "utf-8" - compression_method, compression = get_compression_method(compression) - compression_method = infer_compression(path_or_buf, compression_method) - compression = dict(compression, method=compression_method) ioargs = get_filepath_or_buffer( path_or_buf, encoding=encoding, diff --git a/pandas/io/parsers.py b/pandas/io/parsers.py index c6ef5221e7ead..a0466c5ac6b57 100644 --- a/pandas/io/parsers.py +++ b/pandas/io/parsers.py @@ -63,12 +63,7 @@ from pandas.core.series import Series from pandas.core.tools import datetimes as tools -from pandas.io.common import ( - get_filepath_or_buffer, - get_handle, - infer_compression, - validate_header_arg, -) +from pandas.io.common import get_filepath_or_buffer, get_handle, validate_header_arg from pandas.io.date_converters import generic_parser # BOM character (byte order mark) @@ -424,9 +419,7 @@ def _read(filepath_or_buffer: FilePathOrBuffer, kwds): if encoding is not None: encoding = re.sub("_", "-", encoding).lower() kwds["encoding"] = encoding - compression = kwds.get("compression", "infer") - compression = infer_compression(filepath_or_buffer, compression) # TODO: get_filepath_or_buffer could return # Union[FilePathOrBuffer, s3fs.S3File, gcsfs.GCSFile] @@ -1976,6 +1969,10 @@ def __init__(self, src, **kwds): encoding = kwds.get("encoding") + # parsers.TextReader doesn't support compression dicts + if isinstance(kwds.get("compression"), dict): + kwds["compression"] = kwds["compression"]["method"] + if kwds.get("compression") is None and encoding: if isinstance(src, str): src = open(src, "rb") diff --git a/pandas/io/pickle.py b/pandas/io/pickle.py index 857a2d1b69be4..655deb5ca3779 100644 --- a/pandas/io/pickle.py +++ b/pandas/io/pickle.py @@ -92,11 +92,8 @@ def to_pickle( mode="wb", storage_options=storage_options, ) - compression = ioargs.compression - if not isinstance(ioargs.filepath_or_buffer, str) and compression == "infer": - compression = None f, fh = get_handle( - ioargs.filepath_or_buffer, "wb", compression=compression, is_text=False + ioargs.filepath_or_buffer, "wb", compression=ioargs.compression, is_text=False ) if protocol < 0: protocol = pickle.HIGHEST_PROTOCOL @@ -196,11 +193,8 @@ def read_pickle( ioargs = get_filepath_or_buffer( filepath_or_buffer, compression=compression, storage_options=storage_options ) - compression = ioargs.compression - if not isinstance(ioargs.filepath_or_buffer, str) and compression == "infer": - compression = None f, fh = get_handle( - ioargs.filepath_or_buffer, "rb", compression=compression, is_text=False + ioargs.filepath_or_buffer, "rb", compression=ioargs.compression, is_text=False ) # 1) try standard library Pickle diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 34d520004cc65..b3b16e04a5d9e 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -16,18 +16,7 @@ from pathlib import Path import struct import sys -from typing import ( - Any, - AnyStr, - BinaryIO, - Dict, - List, - Mapping, - Optional, - Sequence, - Tuple, - Union, -) +from typing import Any, AnyStr, BinaryIO, Dict, List, Optional, Sequence, Tuple, Union import warnings from dateutil.relativedelta import relativedelta @@ -58,13 +47,7 @@ from pandas.core.indexes.base import Index from pandas.core.series import Series -from pandas.io.common import ( - get_compression_method, - get_filepath_or_buffer, - get_handle, - infer_compression, - stringify_path, -) +from pandas.io.common import get_filepath_or_buffer, get_handle, stringify_path _version_error = ( "Version of given Stata file is {version}. pandas supports importing " @@ -1976,9 +1959,6 @@ def _open_file_binary_write( return fname, False, None # type: ignore[return-value] elif isinstance(fname, (str, Path)): # Extract compression mode as given, if dict - compression_typ, compression_args = get_compression_method(compression) - compression_typ = infer_compression(fname, compression_typ) - compression = dict(compression_args, method=compression_typ) ioargs = get_filepath_or_buffer( fname, mode="wb", compression=compression, storage_options=storage_options ) @@ -2235,7 +2215,7 @@ def __init__( time_stamp: Optional[datetime.datetime] = None, data_label: Optional[str] = None, variable_labels: Optional[Dict[Label, str]] = None, - compression: Union[str, Mapping[str, str], None] = "infer", + compression: CompressionOptions = "infer", storage_options: StorageOptions = None, ): super().__init__() @@ -3118,7 +3098,7 @@ def __init__( data_label: Optional[str] = None, variable_labels: Optional[Dict[Label, str]] = None, convert_strl: Optional[Sequence[Label]] = None, - compression: Union[str, Mapping[str, str], None] = "infer", + compression: CompressionOptions = "infer", storage_options: StorageOptions = None, ): # Copy to new list since convert_strl might be modified later @@ -3523,7 +3503,7 @@ def __init__( variable_labels: Optional[Dict[Label, str]] = None, convert_strl: Optional[Sequence[Label]] = None, version: Optional[int] = None, - compression: Union[str, Mapping[str, str], None] = "infer", + compression: CompressionOptions = "infer", storage_options: StorageOptions = None, ): if version is None: diff --git a/pandas/tests/io/test_pickle.py b/pandas/tests/io/test_pickle.py index 6331113ab8945..d1c6705dd7a6f 100644 --- a/pandas/tests/io/test_pickle.py +++ b/pandas/tests/io/test_pickle.py @@ -14,7 +14,9 @@ import datetime import glob import gzip +import io import os +from pathlib import Path import pickle import shutil from warnings import catch_warnings, simplefilter @@ -486,3 +488,30 @@ def test_read_pickle_with_subclass(): tm.assert_series_equal(result[0], expected[0]) assert isinstance(result[1], MyTz) + + +def test_pickle_binary_object_compression(compression): + """ + Read/write from binary file-objects w/wo compression. + + GH 26237, GH 29054, and GH 29570 + """ + df = tm.makeDataFrame() + + # reference for compression + with tm.ensure_clean() as path: + df.to_pickle(path, compression=compression) + reference = Path(path).read_bytes() + + # write + buffer = io.BytesIO() + df.to_pickle(buffer, compression=compression) + buffer.seek(0) + + # gzip and zip safe the filename: cannot compare the compressed content + assert buffer.getvalue() == reference or compression in ("gzip", "zip") + + # read + read_df = pd.read_pickle(buffer, compression=compression) + buffer.seek(0) + tm.assert_frame_equal(df, read_df)