Skip to content

Commit ba223c8

Browse files
committed
Add Zstandard support to read_pickle/to_pickle
1 parent 0b671ad commit ba223c8

File tree

5 files changed

+56
-20
lines changed

5 files changed

+56
-20
lines changed

doc/source/getting_started/install.rst

+10
Original file line numberDiff line numberDiff line change
@@ -402,3 +402,13 @@ qtpy Clipboard I/O
402402
xclip Clipboard I/O on linux
403403
xsel Clipboard I/O on linux
404404
========================= ================== =============================================================
405+
406+
407+
Compression
408+
^^^^^^^^^^^
409+
410+
========================= ================== =============================================================
411+
Dependency Minimum Version Notes
412+
========================= ================== =============================================================
413+
Zstandard 0.15.2 Zstandard compression
414+
========================= ================== =============================================================

pandas/compat/_optional.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"xlwt": "1.3.0",
3535
"xlsxwriter": "1.2.2",
3636
"numba": "0.50.1",
37+
"zstandard": "0.15.2",
3738
}
3839

3940
# A mapping from import name to package name (on PyPI) for packages where

pandas/io/common.py

+27-10
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def _get_filepath_or_buffer(
241241
----------
242242
filepath_or_buffer : a url, filepath (str, py.path.local or pathlib.Path),
243243
or buffer
244-
compression : {{'gzip', 'bz2', 'zip', 'xz', None}}, optional
244+
compression : {{'gzip', 'bz2', 'zip', 'xz', 'zstd', None}}, optional
245245
encoding : the encoding to use to decode bytes, default is 'utf-8'
246246
mode : str, optional
247247
@@ -420,7 +420,7 @@ def file_path_to_url(path: str) -> str:
420420
return urljoin("file:", pathname2url(path))
421421

422422

423-
_compression_to_extension = {"gzip": ".gz", "bz2": ".bz2", "zip": ".zip", "xz": ".xz"}
423+
_compression_to_extension = {"gzip": ".gz", "bz2": ".bz2", "zip": ".zip", "xz": ".xz", "zstd": ".zst"}
424424

425425

426426
def get_compression_method(
@@ -471,10 +471,10 @@ def infer_compression(
471471
----------
472472
filepath_or_buffer : str or file handle
473473
File path or object.
474-
compression : {'infer', 'gzip', 'bz2', 'zip', 'xz', None}
474+
compression : {'infer', 'gzip', 'bz2', 'zip', 'xz', 'zstd', None}
475475
If 'infer' and `filepath_or_buffer` is path-like, then detect
476476
compression from the following extensions: '.gz', '.bz2', '.zip',
477-
or '.xz' (otherwise no compression).
477+
'.xz', or '.zst' (otherwise no compression).
478478
479479
Returns
480480
-------
@@ -556,11 +556,11 @@ def get_handle(
556556
compression : str or dict, default None
557557
If string, specifies compression mode. If dict, value at key 'method'
558558
specifies compression mode. Compression mode must be one of {'infer',
559-
'gzip', 'bz2', 'zip', 'xz', None}. If compression mode is 'infer'
560-
and `filepath_or_buffer` is path-like, then detect compression from
561-
the following extensions: '.gz', '.bz2', '.zip', or '.xz' (otherwise
562-
no compression). If dict and compression mode is one of
563-
{'zip', 'gzip', 'bz2'}, or inferred as one of the above,
559+
'gzip', 'bz2', 'zip', 'xz', 'zstd', None}. If compression mode is
560+
'infer' and `filepath_or_buffer` is path-like, then detect compression
561+
from the following extensions: '.gz', '.bz2', '.zip', '.xz', or '.zst'
562+
(otherwise no compression). If dict and compression mode is one of
563+
{'zip', 'gzip', 'bz2', 'zstd'}, or inferred as one of the above,
564564
other entries passed as additional compression options.
565565
566566
.. versionchanged:: 1.0.0
@@ -572,7 +572,7 @@ def get_handle(
572572
.. versionchanged:: 1.1.0
573573
574574
Passing compression options as keys in dict is now
575-
supported for compression modes 'gzip' and 'bz2' as well as 'zip'.
575+
supported for compression modes 'gzip', 'bz2', 'zstd' and 'zip'.
576576
577577
memory_map : bool, default False
578578
See parsers._parser_params for more information.
@@ -689,6 +689,23 @@ def get_handle(
689689
elif compression == "xz":
690690
handle = get_lzma_file(lzma)(handle, ioargs.mode)
691691

692+
# Zstd Compression
693+
elif compression == "zstd":
694+
zstd = import_optional_dependency("zstandard")
695+
open_args = {
696+
arg: compression_args.pop(arg, None)
697+
for arg in ["encoding", "errors", "newline"]
698+
}
699+
if "r" in ioargs.mode:
700+
open_args["dctx"] = zstd.ZstdDecompressor(**compression_args)
701+
else:
702+
open_args["cctx"] = zstd.ZstdCompressor(**compression_args)
703+
handle = zstd.open(
704+
handle,
705+
mode=ioargs.mode,
706+
**open_args,
707+
)
708+
692709
# Unrecognized Compression
693710
else:
694711
msg = f"Unrecognized compression type: {compression}"

pandas/io/pickle.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,13 @@ def to_pickle(
3737
.. versionchanged:: 1.0.0
3838
Accept URL. URL has to be of S3 or GCS.
3939
40-
compression : {{'infer', 'gzip', 'bz2', 'zip', 'xz', None}}, default 'infer'
40+
compression : {{'infer', 'gzip', 'bz2', 'zip', 'xz', 'zstd', None}},
41+
default 'infer'
4142
If 'infer' and 'path_or_url' is path-like, then detect compression from
42-
the following extensions: '.gz', '.bz2', '.zip', or '.xz' (otherwise no
43-
compression) If 'infer' and 'path_or_url' is not path-like, then use
44-
None (= no decompression).
43+
the following extensions: '.gz', '.bz2', '.zip', '.xz', or '.zst'
44+
(otherwise no compression).
45+
If 'infer' and 'path_or_url' is not path-like, then use
46+
None (= no compression).
4547
protocol : int
4648
Int which indicates which protocol should be used by the pickler,
4749
default HIGHEST_PROTOCOL (see [1], paragraph 12.1.2). The possible
@@ -142,11 +144,13 @@ def read_pickle(
142144
.. versionchanged:: 1.0.0
143145
Accept URL. URL is not limited to S3 and GCS.
144146
145-
compression : {{'infer', 'gzip', 'bz2', 'zip', 'xz', None}}, default 'infer'
147+
compression : {{'infer', 'gzip', 'bz2', 'zip', 'xz', 'zstd', None}},
148+
default 'infer'
146149
If 'infer' and 'path_or_url' is path-like, then detect compression from
147-
the following extensions: '.gz', '.bz2', '.zip', or '.xz' (otherwise no
148-
compression) If 'infer' and 'path_or_url' is not path-like, then use
149-
None (= no decompression).
150+
the following extensions: '.gz', '.bz2', '.zip', '.xz', or '.zst'
151+
(otherwise no compression).
152+
If 'infer' and 'path_or_url' is not path-like, then use
153+
None (= no compression).
150154
151155
{storage_options}
152156

pandas/tests/io/test_pickle.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ class TestCompression:
298298
"bz2": ".bz2",
299299
"zip": ".zip",
300300
"xz": ".xz",
301+
"zstd": ".zst",
301302
}
302303

303304
def compress_file(self, src_path, dest_path, compression):
@@ -314,6 +315,9 @@ def compress_file(self, src_path, dest_path, compression):
314315
f.write(src_path, os.path.basename(src_path))
315316
elif compression == "xz":
316317
f = get_lzma_file(lzma)(dest_path, "w")
318+
elif compression == "zstd":
319+
zstd = pytest.importorskip("zstandard")
320+
f = zstd.open(dest_path, "wb")
317321
else:
318322
msg = f"Unrecognized compression type: {compression}"
319323
raise ValueError(msg)
@@ -350,7 +354,7 @@ def test_write_explicit_bad(self, compression, get_random_path):
350354
df = tm.makeDataFrame()
351355
df.to_pickle(path, compression=compression)
352356

353-
@pytest.mark.parametrize("ext", ["", ".gz", ".bz2", ".no_compress", ".xz"])
357+
@pytest.mark.parametrize("ext", ["", ".gz", ".bz2", ".no_compress", ".xz", ".zst"])
354358
def test_write_infer(self, ext, get_random_path):
355359
base = get_random_path
356360
path1 = base + ext
@@ -396,7 +400,7 @@ def test_read_explicit(self, compression, get_random_path):
396400

397401
tm.assert_frame_equal(df, df2)
398402

399-
@pytest.mark.parametrize("ext", ["", ".gz", ".bz2", ".zip", ".no_compress", ".xz"])
403+
@pytest.mark.parametrize("ext", ["", ".gz", ".bz2", ".zip", ".no_compress", ".xz", ".zst"])
400404
def test_read_infer(self, ext, get_random_path):
401405
base = get_random_path
402406
path1 = base + ".raw"

0 commit comments

Comments
 (0)