diff --git a/doc/source/whatsnew/v1.0.2.rst b/doc/source/whatsnew/v1.0.2.rst
index 1b6098e6b6ac1..808e6ae709ce9 100644
--- a/doc/source/whatsnew/v1.0.2.rst
+++ b/doc/source/whatsnew/v1.0.2.rst
@@ -74,6 +74,7 @@ Bug fixes
 **I/O**
 
 - Using ``pd.NA`` with :meth:`DataFrame.to_json` now correctly outputs a null value instead of an empty object (:issue:`31615`)
+- Fixed pickling of ``pandas.NA``. Previously a new object was returned, which broke computations relying on ``NA`` being a singleton (:issue:`31847`)
 - Fixed bug in parquet roundtrip with nullable unsigned integer dtypes (:issue:`31896`).
 
 **Experimental dtypes**
diff --git a/pandas/_libs/missing.pyx b/pandas/_libs/missing.pyx
index 4d17a6f883c1c..c54cb652d7b21 100644
--- a/pandas/_libs/missing.pyx
+++ b/pandas/_libs/missing.pyx
@@ -364,6 +364,9 @@ class NAType(C_NAType):
         exponent = 31 if is_32bit else 61
         return 2 ** exponent - 1
 
+    def __reduce__(self):
+        return "NA"
+
     # Binary arithmetic and comparison ops -> propagate
 
     __add__ = _create_binary_propagating_op("__add__")
diff --git a/pandas/tests/scalar/test_na_scalar.py b/pandas/tests/scalar/test_na_scalar.py
index dcb9d66708724..07656de2e9062 100644
--- a/pandas/tests/scalar/test_na_scalar.py
+++ b/pandas/tests/scalar/test_na_scalar.py
@@ -1,3 +1,5 @@
+import pickle
+
 import numpy as np
 import pytest
 
@@ -267,3 +269,26 @@ def test_integer_hash_collision_set():
     assert len(result) == 2
     assert NA in result
     assert hash(NA) in result
+
+
+def test_pickle_roundtrip():
+    # https://github.com/pandas-dev/pandas/issues/31847
+    result = pickle.loads(pickle.dumps(pd.NA))
+    assert result is pd.NA
+
+
+def test_pickle_roundtrip_pandas():
+    result = tm.round_trip_pickle(pd.NA)
+    assert result is pd.NA
+
+
+@pytest.mark.parametrize(
+    "values, dtype", [([1, 2, pd.NA], "Int64"), (["A", "B", pd.NA], "string")]
+)
+@pytest.mark.parametrize("as_frame", [True, False])
+def test_pickle_roundtrip_containers(as_frame, values, dtype):
+    s = pd.Series(pd.array(values, dtype=dtype))
+    if as_frame:
+        s = s.to_frame(name="A")
+    result = tm.round_trip_pickle(s)
+    tm.assert_equal(result, s)