From 91c86993a12154d253c195c2e8b4df4ba649483a Mon Sep 17 00:00:00 2001 From: richard Date: Sun, 5 Jan 2025 15:56:20 -0500 Subject: [PATCH 1/6] ENH: Enable pytables to round-trip with StringDtype --- pandas/io/pytables.py | 55 ++++++++++++++++++++--- pandas/tests/io/pytables/test_put.py | 66 ++++++++++++++++++++++------ 2 files changed, 101 insertions(+), 20 deletions(-) diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index b75dc6c3a43b4..0a51b80edb53c 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -38,6 +38,7 @@ writers as libwriters, ) from pandas._libs.lib import is_string_array +from pandas._libs.missing import NA from pandas._libs.tslibs import timezones from pandas.compat._optional import import_optional_dependency from pandas.compat.pickle_compat import patch_pickle @@ -91,7 +92,10 @@ PyTablesExpr, maybe_expression, ) -from pandas.core.construction import extract_array +from pandas.core.construction import ( + array as pd_array, + extract_array, +) from pandas.core.indexes.api import ensure_index from pandas.io.common import stringify_path @@ -3023,6 +3027,18 @@ def read_array(self, key: str, start: int | None = None, stop: int | None = None if isinstance(node, tables.VLArray): ret = node[0][start:stop] + dtype = getattr(attrs, "value_type", None) + if dtype is not None: + if dtype == "str[python]": + dtype = StringDtype("python", np.nan) + elif dtype == "string[python]": + dtype = StringDtype("python", NA) + elif dtype == "str[pyarrow]": + dtype = StringDtype("pyarrow", np.nan) + else: + assert dtype == "string[pyarrow]" + dtype = StringDtype("pyarrow", NA) + ret = pd_array(ret, dtype=dtype) else: dtype = getattr(attrs, "value_type", None) shape = getattr(attrs, "shape", None) @@ -3210,6 +3226,8 @@ def write_array( # get the atom for this datatype atom = _tables().Atom.from_dtype(value.dtype) + from pandas.core.arrays.string_ import BaseStringArray + if atom is not None: # We only get here if self._filters is non-None and # the Atom.from_dtype call succeeded @@ -3262,6 +3280,19 @@ def write_array( elif lib.is_np_dtype(value.dtype, "m"): self._handle.create_array(self.group, key, value.view("i8")) getattr(self.group, key)._v_attrs.value_type = "timedelta64" + elif isinstance(value, BaseStringArray): + vlarr = self._handle.create_vlarray(self.group, key, _tables().ObjectAtom()) + vlarr.append(value.to_numpy()) + node = getattr(self.group, key) + if value.dtype == StringDtype("python", np.nan): + node._v_attrs.value_type = "str[python]" + elif value.dtype == StringDtype("python", NA): + node._v_attrs.value_type = "string[python]" + elif value.dtype == StringDtype("pyarrow", np.nan): + node._v_attrs.value_type = "str[pyarrow]" + else: + assert value.dtype == StringDtype("pyarrow", NA) + node._v_attrs.value_type = "string[pyarrow]" elif empty_array: self.write_array_empty(key, value) else: @@ -3294,7 +3325,11 @@ def read( index = self.read_index("index", start=start, stop=stop) values = self.read_array("values", start=start, stop=stop) result = Series(values, index=index, name=self.name, copy=False) - if using_string_dtype() and is_string_array(values, skipna=True): + if ( + using_string_dtype() + and isinstance(values, np.ndarray) + and is_string_array(values, skipna=True) + ): result = result.astype(StringDtype(na_value=np.nan)) return result @@ -3363,7 +3398,11 @@ def read( columns = items[items.get_indexer(blk_items)] df = DataFrame(values.T, columns=columns, index=axes[1], copy=False) - if using_string_dtype() and is_string_array(values, skipna=True): + if ( + using_string_dtype() + and isinstance(values, np.ndarray) + and is_string_array(values, skipna=True) + ): df = df.astype(StringDtype(na_value=np.nan)) dfs.append(df) @@ -4737,9 +4776,13 @@ def read( df = DataFrame._from_arrays([values], columns=cols_, index=index_) if not (using_string_dtype() and values.dtype.kind == "O"): assert (df.dtypes == values.dtype).all(), (df.dtypes, values.dtype) - if using_string_dtype() and is_string_array( - values, # type: ignore[arg-type] - skipna=True, + if ( + using_string_dtype() + and isinstance(values, np.ndarray) + and is_string_array( + values, # type: ignore[arg-type] + skipna=True, + ) ): df = df.astype(StringDtype(na_value=np.nan)) frames.append(df) diff --git a/pandas/tests/io/pytables/test_put.py b/pandas/tests/io/pytables/test_put.py index a4257b54dd6db..6a80fa224ce85 100644 --- a/pandas/tests/io/pytables/test_put.py +++ b/pandas/tests/io/pytables/test_put.py @@ -3,8 +3,6 @@ import numpy as np import pytest -from pandas._config import using_string_dtype - from pandas._libs.tslibs import Timestamp import pandas as pd @@ -26,7 +24,6 @@ pytestmark = [ pytest.mark.single_cpu, - pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False), ] @@ -54,8 +51,8 @@ def test_api_default_format(tmp_path, setup_path): with ensure_clean_store(setup_path) as store: df = DataFrame( 1.1 * np.arange(120).reshape((30, 4)), - columns=Index(list("ABCD"), dtype=object), - index=Index([f"i-{i}" for i in range(30)], dtype=object), + columns=Index(list("ABCD")), + index=Index([f"i-{i}" for i in range(30)]), ) with pd.option_context("io.hdf.default_format", "fixed"): @@ -79,8 +76,8 @@ def test_api_default_format(tmp_path, setup_path): path = tmp_path / setup_path df = DataFrame( 1.1 * np.arange(120).reshape((30, 4)), - columns=Index(list("ABCD"), dtype=object), - index=Index([f"i-{i}" for i in range(30)], dtype=object), + columns=Index(list("ABCD")), + index=Index([f"i-{i}" for i in range(30)]), ) with pd.option_context("io.hdf.default_format", "fixed"): @@ -106,7 +103,7 @@ def test_put(setup_path): ) df = DataFrame( np.random.default_rng(2).standard_normal((20, 4)), - columns=Index(list("ABCD"), dtype=object), + columns=Index(list("ABCD")), index=date_range("2000-01-01", periods=20, freq="B"), ) store["a"] = ts @@ -166,7 +163,7 @@ def test_put_compression(setup_path): with ensure_clean_store(setup_path) as store: df = DataFrame( np.random.default_rng(2).standard_normal((10, 4)), - columns=Index(list("ABCD"), dtype=object), + columns=Index(list("ABCD")), index=date_range("2000-01-01", periods=10, freq="B"), ) @@ -183,7 +180,7 @@ def test_put_compression(setup_path): def test_put_compression_blosc(setup_path): df = DataFrame( np.random.default_rng(2).standard_normal((10, 4)), - columns=Index(list("ABCD"), dtype=object), + columns=Index(list("ABCD")), index=date_range("2000-01-01", periods=10, freq="B"), ) @@ -197,10 +194,20 @@ def test_put_compression_blosc(setup_path): tm.assert_frame_equal(store["c"], df) -def test_put_mixed_type(setup_path, performance_warning): +def test_put_datetime_ser(setup_path, performance_warning, using_infer_string): + # https://github.com/pandas-dev/pandas/pull/??? + ser = Series(3 * [Timestamp("20010102").as_unit("ns")]) + with ensure_clean_store(setup_path) as store: + store.put("ser", ser) + expected = ser.copy() + result = store.get("ser") + tm.assert_series_equal(result, expected) + + +def test_put_mixed_type(setup_path, performance_warning, using_infer_string): df = DataFrame( np.random.default_rng(2).standard_normal((10, 4)), - columns=Index(list("ABCD"), dtype=object), + columns=Index(list("ABCD")), index=date_range("2000-01-01", periods=10, freq="B"), ) df["obj1"] = "foo" @@ -220,13 +227,38 @@ def test_put_mixed_type(setup_path, performance_warning): with ensure_clean_store(setup_path) as store: _maybe_remove(store, "df") - with tm.assert_produces_warning(performance_warning): + warning = None if using_infer_string else performance_warning + with tm.assert_produces_warning(warning): store.put("df", df) expected = store.get("df") tm.assert_frame_equal(expected, df) +def test_put_str_frame(setup_path, performance_warning, string_dtype_arguments): + dtype = pd.StringDtype(*string_dtype_arguments) + df = DataFrame({"a": pd.array(["x", pd.NA, "y"], dtype=dtype)}) + with ensure_clean_store(setup_path) as store: + _maybe_remove(store, "df") + + store.put("df", df) + expected = df + result = store.get("df") + tm.assert_frame_equal(result, expected) + + +def test_put_str_series(setup_path, performance_warning, string_dtype_arguments): + dtype = pd.StringDtype(*string_dtype_arguments) + ser = Series(["x", pd.NA, "y"], dtype=dtype) + with ensure_clean_store(setup_path) as store: + _maybe_remove(store, "df") + + store.put("ser", ser) + expected = ser + result = store.get("ser") + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("format", ["table", "fixed"]) @pytest.mark.parametrize( "index", @@ -253,7 +285,7 @@ def test_store_index_types(setup_path, format, index): tm.assert_frame_equal(df, store["df"]) -def test_column_multiindex(setup_path): +def test_column_multiindex(setup_path, using_infer_string): # GH 4710 # recreate multi-indexes properly @@ -264,6 +296,12 @@ def test_column_multiindex(setup_path): expected = df.set_axis(df.index.to_numpy()) with ensure_clean_store(setup_path) as store: + if using_infer_string: + # TODO(infer_string) make this work for string dtype + msg = "Saving a MultiIndex with an extension dtype is not supported." + with pytest.raises(NotImplementedError, match=msg): + store.put("df", df) + return store.put("df", df) tm.assert_frame_equal( store["df"], expected, check_index_type=True, check_column_type=True From 621ea78878496e0316110feb692b4670f38f17f2 Mon Sep 17 00:00:00 2001 From: richard Date: Sun, 5 Jan 2025 16:00:20 -0500 Subject: [PATCH 2/6] Cleanups --- pandas/io/pytables.py | 3 +-- pandas/tests/io/pytables/test_put.py | 4 +++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index 0a51b80edb53c..20be83c975497 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -87,6 +87,7 @@ PeriodArray, ) from pandas.core.arrays.datetimes import tz_to_dtype +from pandas.core.arrays.string_ import BaseStringArray import pandas.core.common as com from pandas.core.computation.pytables import ( PyTablesExpr, @@ -3226,8 +3227,6 @@ def write_array( # get the atom for this datatype atom = _tables().Atom.from_dtype(value.dtype) - from pandas.core.arrays.string_ import BaseStringArray - if atom is not None: # We only get here if self._filters is non-None and # the Atom.from_dtype call succeeded diff --git a/pandas/tests/io/pytables/test_put.py b/pandas/tests/io/pytables/test_put.py index 6a80fa224ce85..3c764b2298f86 100644 --- a/pandas/tests/io/pytables/test_put.py +++ b/pandas/tests/io/pytables/test_put.py @@ -195,7 +195,7 @@ def test_put_compression_blosc(setup_path): def test_put_datetime_ser(setup_path, performance_warning, using_infer_string): - # https://github.com/pandas-dev/pandas/pull/??? + # https://github.com/pandas-dev/pandas/pull/60663 ser = Series(3 * [Timestamp("20010102").as_unit("ns")]) with ensure_clean_store(setup_path) as store: store.put("ser", ser) @@ -236,6 +236,7 @@ def test_put_mixed_type(setup_path, performance_warning, using_infer_string): def test_put_str_frame(setup_path, performance_warning, string_dtype_arguments): + # https://github.com/pandas-dev/pandas/pull/60663 dtype = pd.StringDtype(*string_dtype_arguments) df = DataFrame({"a": pd.array(["x", pd.NA, "y"], dtype=dtype)}) with ensure_clean_store(setup_path) as store: @@ -248,6 +249,7 @@ def test_put_str_frame(setup_path, performance_warning, string_dtype_arguments): def test_put_str_series(setup_path, performance_warning, string_dtype_arguments): + # https://github.com/pandas-dev/pandas/pull/60663 dtype = pd.StringDtype(*string_dtype_arguments) ser = Series(["x", pd.NA, "y"], dtype=dtype) with ensure_clean_store(setup_path) as store: From 11f8c0db74afce757f48966e7bd812da9be093f1 Mon Sep 17 00:00:00 2001 From: richard Date: Sun, 5 Jan 2025 17:06:48 -0500 Subject: [PATCH 3/6] Remove type-ignore --- pandas/io/pytables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index 20be83c975497..342ada2e095dd 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -4779,7 +4779,7 @@ def read( using_string_dtype() and isinstance(values, np.ndarray) and is_string_array( - values, # type: ignore[arg-type] + values, skipna=True, ) ): From 52dbb5e67a632592dd5d4d6e21383056d87685bc Mon Sep 17 00:00:00 2001 From: Richard Shadrach Date: Wed, 8 Jan 2025 16:45:39 -0500 Subject: [PATCH 4/6] Add whatsnew note --- doc/source/whatsnew/v2.3.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v2.3.0.rst b/doc/source/whatsnew/v2.3.0.rst index b107a5d3ba100..0a0e1dbb1380e 100644 --- a/doc/source/whatsnew/v2.3.0.rst +++ b/doc/source/whatsnew/v2.3.0.rst @@ -35,8 +35,8 @@ Other enhancements - The semantics for the ``copy`` keyword in ``__array__`` methods (i.e. called when using ``np.array()`` or ``np.asarray()`` on pandas objects) has been updated to work correctly with NumPy >= 2 (:issue:`57739`) +- :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` (:issue:`60663`) - The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`) -- .. --------------------------------------------------------------------------- .. _whatsnew_230.notable_bug_fixes: From 4c20a6aae45c8daafa1fd5d816a81285e8082c7e Mon Sep 17 00:00:00 2001 From: richard Date: Tue, 21 Jan 2025 20:51:37 -0500 Subject: [PATCH 5/6] Rework roundtripping logic --- doc/source/whatsnew/v2.3.0.rst | 2 +- pandas/io/pytables.py | 20 +------------------- pandas/tests/io/pytables/test_put.py | 6 ++++-- 3 files changed, 6 insertions(+), 22 deletions(-) diff --git a/doc/source/whatsnew/v2.3.0.rst b/doc/source/whatsnew/v2.3.0.rst index e7e6d8d7d2610..ef4d97d11225d 100644 --- a/doc/source/whatsnew/v2.3.0.rst +++ b/doc/source/whatsnew/v2.3.0.rst @@ -35,7 +35,7 @@ Other enhancements - The semantics for the ``copy`` keyword in ``__array__`` methods (i.e. called when using ``np.array()`` or ``np.asarray()`` on pandas objects) has been updated to work correctly with NumPy >= 2 (:issue:`57739`) -- :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` (:issue:`60663`) +- :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` preserving the ``na_value`` but not necessarily the storage (:issue:`60663`) - The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`) - The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`) diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index 342ada2e095dd..2f8096746318b 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -38,7 +38,6 @@ writers as libwriters, ) from pandas._libs.lib import is_string_array -from pandas._libs.missing import NA from pandas._libs.tslibs import timezones from pandas.compat._optional import import_optional_dependency from pandas.compat.pickle_compat import patch_pickle @@ -3030,15 +3029,6 @@ def read_array(self, key: str, start: int | None = None, stop: int | None = None ret = node[0][start:stop] dtype = getattr(attrs, "value_type", None) if dtype is not None: - if dtype == "str[python]": - dtype = StringDtype("python", np.nan) - elif dtype == "string[python]": - dtype = StringDtype("python", NA) - elif dtype == "str[pyarrow]": - dtype = StringDtype("pyarrow", np.nan) - else: - assert dtype == "string[pyarrow]" - dtype = StringDtype("pyarrow", NA) ret = pd_array(ret, dtype=dtype) else: dtype = getattr(attrs, "value_type", None) @@ -3283,15 +3273,7 @@ def write_array( vlarr = self._handle.create_vlarray(self.group, key, _tables().ObjectAtom()) vlarr.append(value.to_numpy()) node = getattr(self.group, key) - if value.dtype == StringDtype("python", np.nan): - node._v_attrs.value_type = "str[python]" - elif value.dtype == StringDtype("python", NA): - node._v_attrs.value_type = "string[python]" - elif value.dtype == StringDtype("pyarrow", np.nan): - node._v_attrs.value_type = "str[pyarrow]" - else: - assert value.dtype == StringDtype("pyarrow", NA) - node._v_attrs.value_type = "string[pyarrow]" + node._v_attrs.value_type = str(value.dtype) elif empty_array: self.write_array_empty(key, value) else: diff --git a/pandas/tests/io/pytables/test_put.py b/pandas/tests/io/pytables/test_put.py index 3c764b2298f86..66596f1138b96 100644 --- a/pandas/tests/io/pytables/test_put.py +++ b/pandas/tests/io/pytables/test_put.py @@ -243,7 +243,8 @@ def test_put_str_frame(setup_path, performance_warning, string_dtype_arguments): _maybe_remove(store, "df") store.put("df", df) - expected = df + expected_dtype = "str" if dtype.na_value is np.nan else "string" + expected = df.astype(expected_dtype) result = store.get("df") tm.assert_frame_equal(result, expected) @@ -256,7 +257,8 @@ def test_put_str_series(setup_path, performance_warning, string_dtype_arguments) _maybe_remove(store, "df") store.put("ser", ser) - expected = ser + expected_dtype = "str" if dtype.na_value is np.nan else "string" + expected = ser.astype(expected_dtype) result = store.get("ser") tm.assert_series_equal(result, expected) From 275c3b89841e00666c8a26aefff7689f7be83e8a Mon Sep 17 00:00:00 2001 From: Richard Shadrach <45562402+rhshadrach@users.noreply.github.com> Date: Wed, 22 Jan 2025 17:44:27 -0500 Subject: [PATCH 6/6] Update doc/source/whatsnew/v2.3.0.rst Co-authored-by: William Ayd --- doc/source/whatsnew/v2.3.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v2.3.0.rst b/doc/source/whatsnew/v2.3.0.rst index ef4d97d11225d..bed4f5d74f464 100644 --- a/doc/source/whatsnew/v2.3.0.rst +++ b/doc/source/whatsnew/v2.3.0.rst @@ -35,7 +35,7 @@ Other enhancements - The semantics for the ``copy`` keyword in ``__array__`` methods (i.e. called when using ``np.array()`` or ``np.asarray()`` on pandas objects) has been updated to work correctly with NumPy >= 2 (:issue:`57739`) -- :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` preserving the ``na_value`` but not necessarily the storage (:issue:`60663`) +- :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` (:issue:`60663`) - The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`) - The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)