diff --git a/doc/source/whatsnew/v0.21.0.txt b/doc/source/whatsnew/v0.21.0.txt index 36dffc3d3378b..593e96960ed34 100644 --- a/doc/source/whatsnew/v0.21.0.txt +++ b/doc/source/whatsnew/v0.21.0.txt @@ -27,7 +27,8 @@ New features Other Enhancements ^^^^^^^^^^^^^^^^^^ - +- :func:`to_pickle` has gained a protocol parameter (:issue:`16252`). By default, +this parameter is set to HIGHEST_PROTOCOL (see , 12.1.2). .. _whatsnew_0210.api_breaking: diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 2bc64795b5f20..175fd55e31fb3 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -49,7 +49,7 @@ from pandas import compat from pandas.compat.numpy import function as nv from pandas.compat import (map, zip, lzip, lrange, string_types, - isidentifier, set_function_name) + isidentifier, set_function_name, cPickle as pkl) import pandas.core.nanops as nanops from pandas.util.decorators import Appender, Substitution, deprecate_kwarg from pandas.util.validators import validate_bool_kwarg @@ -1344,7 +1344,8 @@ def to_sql(self, name, con, flavor=None, schema=None, if_exists='fail', if_exists=if_exists, index=index, index_label=index_label, chunksize=chunksize, dtype=dtype) - def to_pickle(self, path, compression='infer'): + def to_pickle(self, path, compression='infer', + protocol=pkl.HIGHEST_PROTOCOL): """ Pickle (serialize) object to input file path. @@ -1356,9 +1357,22 @@ def to_pickle(self, path, compression='infer'): a string representing the compression to use in the output file .. versionadded:: 0.20.0 + protocol : int + Int which indicates which protocol should be used by the pickler, + default HIGHEST_PROTOCOL (see [1], paragraph 12.1.2). The possible + values for this parameter depend on the version of Python. For + Python 2.x, possible values are 0, 1, 2. For Python>=3.0, 3 is a + valid value. For Python >= 3.4, 4 is a valid value.A negative value + for the protocol parameter is equivalent to setting its value to + HIGHEST_PROTOCOL. + + .. [1] https://docs.python.org/3/library/pickle.html + .. versionadded:: 0.21.0 + """ from pandas.io.pickle import to_pickle - return to_pickle(self, path, compression=compression) + return to_pickle(self, path, compression=compression, + protocol=protocol) def to_clipboard(self, excel=None, sep=None, **kwargs): """ diff --git a/pandas/io/pickle.py b/pandas/io/pickle.py index 0f91c407766fb..6f4c714931fc8 100644 --- a/pandas/io/pickle.py +++ b/pandas/io/pickle.py @@ -7,7 +7,7 @@ from pandas.io.common import _get_handle, _infer_compression -def to_pickle(obj, path, compression='infer'): +def to_pickle(obj, path, compression='infer', protocol=pkl.HIGHEST_PROTOCOL): """ Pickle (serialize) object to input file path @@ -20,13 +20,28 @@ def to_pickle(obj, path, compression='infer'): a string representing the compression to use in the output file .. versionadded:: 0.20.0 + protocol : int + Int which indicates which protocol should be used by the pickler, + default HIGHEST_PROTOCOL (see [1], paragraph 12.1.2). The possible + values for this parameter depend on the version of Python. For Python + 2.x, possible values are 0, 1, 2. For Python>=3.0, 3 is a valid value. + For Python >= 3.4, 4 is a valid value. A negative value for the + protocol parameter is equivalent to setting its value to + HIGHEST_PROTOCOL. + + .. [1] https://docs.python.org/3/library/pickle.html + .. versionadded:: 0.21.0 + + """ inferred_compression = _infer_compression(path, compression) f, fh = _get_handle(path, 'wb', compression=inferred_compression, is_text=False) + if protocol < 0: + protocol = pkl.HIGHEST_PROTOCOL try: - pkl.dump(obj, f, protocol=pkl.HIGHEST_PROTOCOL) + pkl.dump(obj, f, protocol=protocol) finally: for _f in fh: _f.close() diff --git a/pandas/tests/io/test_pickle.py b/pandas/tests/io/test_pickle.py index 875b5bd3055b9..b290a6f943d91 100644 --- a/pandas/tests/io/test_pickle.py +++ b/pandas/tests/io/test_pickle.py @@ -25,6 +25,7 @@ import pandas.util.testing as tm from pandas.tseries.offsets import Day, MonthEnd import shutil +import sys @pytest.fixture(scope='module') @@ -489,3 +490,38 @@ def test_read_infer(self, ext, get_random_path): df2 = pd.read_pickle(p2) tm.assert_frame_equal(df, df2) + + +# --------------------- +# test pickle compression +# --------------------- + +class TestProtocol(object): + + @pytest.mark.parametrize('protocol', [-1, 0, 1, 2]) + def test_read(self, protocol, get_random_path): + with tm.ensure_clean(get_random_path) as path: + df = tm.makeDataFrame() + df.to_pickle(path, protocol=protocol) + df2 = pd.read_pickle(path) + tm.assert_frame_equal(df, df2) + + @pytest.mark.parametrize('protocol', [3, 4]) + @pytest.mark.skipif(sys.version_info[:2] >= (3, 4), + reason="Testing invalid parameters for " + "Python 2.x and 3.y (y < 4).") + def test_read_bad_versions(self, protocol, get_random_path): + # For Python 2.x (respectively 3.y with y < 4), [expected] + # HIGHEST_PROTOCOL should be 2 (respectively 3). Hence, the protocol + # parameter should not exceed 2 (respectively 3). + if sys.version_info[:2] < (3, 0): + expect_hp = 2 + else: + expect_hp = 3 + with tm.assert_raises_regex(ValueError, + "pickle protocol %d asked for; the highest" + " available protocol is %d" % (protocol, + expect_hp)): + with tm.ensure_clean(get_random_path) as path: + df = tm.makeDataFrame() + df.to_pickle(path, protocol=protocol)