Skip to content

Commit 9b0e866

Browse files
[backport 2.3.x] ENH: Enable pytables to round-trip with StringDtype (#60663) (#60771)
ENH: Enable pytables to round-trip with StringDtype (#60663) Co-authored-by: William Ayd <[email protected]> (cherry picked from commit 60325b8) Co-authored-by: Richard Shadrach <[email protected]>
1 parent a70b88b commit 9b0e866

File tree

3 files changed

+87
-20
lines changed

3 files changed

+87
-20
lines changed

doc/source/whatsnew/v2.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Other enhancements
3535
- The semantics for the ``copy`` keyword in ``__array__`` methods (i.e. called
3636
when using ``np.array()`` or ``np.asarray()`` on pandas objects) has been
3737
updated to raise FutureWarning with NumPy >= 2 (:issue:`60340`)
38+
- :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` (:issue:`60663`)
3839
- The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`)
3940
- The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)
4041

pandas/io/pytables.py

+30-6
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,16 @@
8686
DatetimeArray,
8787
PeriodArray,
8888
)
89+
from pandas.core.arrays.string_ import BaseStringArray
8990
import pandas.core.common as com
9091
from pandas.core.computation.pytables import (
9192
PyTablesExpr,
9293
maybe_expression,
9394
)
94-
from pandas.core.construction import extract_array
95+
from pandas.core.construction import (
96+
array as pd_array,
97+
extract_array,
98+
)
9599
from pandas.core.indexes.api import ensure_index
96100
from pandas.core.internals import (
97101
ArrayManager,
@@ -2955,6 +2959,9 @@ def read_array(self, key: str, start: int | None = None, stop: int | None = None
29552959

29562960
if isinstance(node, tables.VLArray):
29572961
ret = node[0][start:stop]
2962+
dtype = getattr(attrs, "value_type", None)
2963+
if dtype is not None:
2964+
ret = pd_array(ret, dtype=dtype)
29582965
else:
29592966
dtype = _ensure_decoded(getattr(attrs, "value_type", None))
29602967
shape = getattr(attrs, "shape", None)
@@ -3193,6 +3200,11 @@ def write_array(
31933200
elif lib.is_np_dtype(value.dtype, "m"):
31943201
self._handle.create_array(self.group, key, value.view("i8"))
31953202
getattr(self.group, key)._v_attrs.value_type = "timedelta64"
3203+
elif isinstance(value, BaseStringArray):
3204+
vlarr = self._handle.create_vlarray(self.group, key, _tables().ObjectAtom())
3205+
vlarr.append(value.to_numpy())
3206+
node = getattr(self.group, key)
3207+
node._v_attrs.value_type = str(value.dtype)
31963208
elif empty_array:
31973209
self.write_array_empty(key, value)
31983210
else:
@@ -3225,7 +3237,11 @@ def read(
32253237
index = self.read_index("index", start=start, stop=stop)
32263238
values = self.read_array("values", start=start, stop=stop)
32273239
result = Series(values, index=index, name=self.name, copy=False)
3228-
if using_string_dtype() and is_string_array(values, skipna=True):
3240+
if (
3241+
using_string_dtype()
3242+
and isinstance(values, np.ndarray)
3243+
and is_string_array(values, skipna=True)
3244+
):
32293245
result = result.astype(StringDtype(na_value=np.nan))
32303246
return result
32313247

@@ -3294,7 +3310,11 @@ def read(
32943310

32953311
columns = items[items.get_indexer(blk_items)]
32963312
df = DataFrame(values.T, columns=columns, index=axes[1], copy=False)
3297-
if using_string_dtype() and is_string_array(values, skipna=True):
3313+
if (
3314+
using_string_dtype()
3315+
and isinstance(values, np.ndarray)
3316+
and is_string_array(values, skipna=True)
3317+
):
32983318
df = df.astype(StringDtype(na_value=np.nan))
32993319
dfs.append(df)
33003320

@@ -4682,9 +4702,13 @@ def read(
46824702
df = DataFrame._from_arrays([values], columns=cols_, index=index_)
46834703
if not (using_string_dtype() and values.dtype.kind == "O"):
46844704
assert (df.dtypes == values.dtype).all(), (df.dtypes, values.dtype)
4685-
if using_string_dtype() and is_string_array(
4686-
values, # type: ignore[arg-type]
4687-
skipna=True,
4705+
if (
4706+
using_string_dtype()
4707+
and isinstance(values, np.ndarray)
4708+
and is_string_array(
4709+
values,
4710+
skipna=True,
4711+
)
46884712
):
46894713
df = df.astype(StringDtype(na_value=np.nan))
46904714
frames.append(df)

pandas/tests/io/pytables/test_put.py

+56-14
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import numpy as np
44
import pytest
55

6-
from pandas._config import using_string_dtype
7-
86
from pandas._libs.tslibs import Timestamp
97

108
import pandas as pd
@@ -26,7 +24,6 @@
2624

2725
pytestmark = [
2826
pytest.mark.single_cpu,
29-
pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False),
3027
]
3128

3229

@@ -54,8 +51,8 @@ def test_api_default_format(tmp_path, setup_path):
5451
with ensure_clean_store(setup_path) as store:
5552
df = DataFrame(
5653
1.1 * np.arange(120).reshape((30, 4)),
57-
columns=Index(list("ABCD"), dtype=object),
58-
index=Index([f"i-{i}" for i in range(30)], dtype=object),
54+
columns=Index(list("ABCD")),
55+
index=Index([f"i-{i}" for i in range(30)]),
5956
)
6057

6158
with pd.option_context("io.hdf.default_format", "fixed"):
@@ -79,8 +76,8 @@ def test_api_default_format(tmp_path, setup_path):
7976
path = tmp_path / setup_path
8077
df = DataFrame(
8178
1.1 * np.arange(120).reshape((30, 4)),
82-
columns=Index(list("ABCD"), dtype=object),
83-
index=Index([f"i-{i}" for i in range(30)], dtype=object),
79+
columns=Index(list("ABCD")),
80+
index=Index([f"i-{i}" for i in range(30)]),
8481
)
8582

8683
with pd.option_context("io.hdf.default_format", "fixed"):
@@ -106,7 +103,7 @@ def test_put(setup_path):
106103
)
107104
df = DataFrame(
108105
np.random.default_rng(2).standard_normal((20, 4)),
109-
columns=Index(list("ABCD"), dtype=object),
106+
columns=Index(list("ABCD")),
110107
index=date_range("2000-01-01", periods=20, freq="B"),
111108
)
112109
store["a"] = ts
@@ -166,7 +163,7 @@ def test_put_compression(setup_path):
166163
with ensure_clean_store(setup_path) as store:
167164
df = DataFrame(
168165
np.random.default_rng(2).standard_normal((10, 4)),
169-
columns=Index(list("ABCD"), dtype=object),
166+
columns=Index(list("ABCD")),
170167
index=date_range("2000-01-01", periods=10, freq="B"),
171168
)
172169

@@ -183,7 +180,7 @@ def test_put_compression(setup_path):
183180
def test_put_compression_blosc(setup_path):
184181
df = DataFrame(
185182
np.random.default_rng(2).standard_normal((10, 4)),
186-
columns=Index(list("ABCD"), dtype=object),
183+
columns=Index(list("ABCD")),
187184
index=date_range("2000-01-01", periods=10, freq="B"),
188185
)
189186

@@ -197,10 +194,20 @@ def test_put_compression_blosc(setup_path):
197194
tm.assert_frame_equal(store["c"], df)
198195

199196

200-
def test_put_mixed_type(setup_path):
197+
def test_put_datetime_ser(setup_path):
198+
# https://github.com/pandas-dev/pandas/pull/60663
199+
ser = Series(3 * [Timestamp("20010102").as_unit("ns")])
200+
with ensure_clean_store(setup_path) as store:
201+
store.put("ser", ser)
202+
expected = ser.copy()
203+
result = store.get("ser")
204+
tm.assert_series_equal(result, expected)
205+
206+
207+
def test_put_mixed_type(setup_path, using_infer_string):
201208
df = DataFrame(
202209
np.random.default_rng(2).standard_normal((10, 4)),
203-
columns=Index(list("ABCD"), dtype=object),
210+
columns=Index(list("ABCD")),
204211
index=date_range("2000-01-01", periods=10, freq="B"),
205212
)
206213
df["obj1"] = "foo"
@@ -220,13 +227,42 @@ def test_put_mixed_type(setup_path):
220227
with ensure_clean_store(setup_path) as store:
221228
_maybe_remove(store, "df")
222229

223-
with tm.assert_produces_warning(pd.errors.PerformanceWarning):
230+
warning = None if using_infer_string else pd.errors.PerformanceWarning
231+
with tm.assert_produces_warning(warning):
224232
store.put("df", df)
225233

226234
expected = store.get("df")
227235
tm.assert_frame_equal(expected, df)
228236

229237

238+
def test_put_str_frame(setup_path, string_dtype_arguments):
239+
# https://github.com/pandas-dev/pandas/pull/60663
240+
dtype = pd.StringDtype(*string_dtype_arguments)
241+
df = DataFrame({"a": pd.array(["x", pd.NA, "y"], dtype=dtype)})
242+
with ensure_clean_store(setup_path) as store:
243+
_maybe_remove(store, "df")
244+
245+
store.put("df", df)
246+
expected_dtype = "str" if dtype.na_value is np.nan else "string"
247+
expected = df.astype(expected_dtype)
248+
result = store.get("df")
249+
tm.assert_frame_equal(result, expected)
250+
251+
252+
def test_put_str_series(setup_path, string_dtype_arguments):
253+
# https://github.com/pandas-dev/pandas/pull/60663
254+
dtype = pd.StringDtype(*string_dtype_arguments)
255+
ser = Series(["x", pd.NA, "y"], dtype=dtype)
256+
with ensure_clean_store(setup_path) as store:
257+
_maybe_remove(store, "df")
258+
259+
store.put("ser", ser)
260+
expected_dtype = "str" if dtype.na_value is np.nan else "string"
261+
expected = ser.astype(expected_dtype)
262+
result = store.get("ser")
263+
tm.assert_series_equal(result, expected)
264+
265+
230266
@pytest.mark.parametrize("format", ["table", "fixed"])
231267
@pytest.mark.parametrize(
232268
"index",
@@ -253,7 +289,7 @@ def test_store_index_types(setup_path, format, index):
253289
tm.assert_frame_equal(df, store["df"])
254290

255291

256-
def test_column_multiindex(setup_path):
292+
def test_column_multiindex(setup_path, using_infer_string):
257293
# GH 4710
258294
# recreate multi-indexes properly
259295

@@ -264,6 +300,12 @@ def test_column_multiindex(setup_path):
264300
expected = df.set_axis(df.index.to_numpy())
265301

266302
with ensure_clean_store(setup_path) as store:
303+
if using_infer_string:
304+
# TODO(infer_string) make this work for string dtype
305+
msg = "Saving a MultiIndex with an extension dtype is not supported."
306+
with pytest.raises(NotImplementedError, match=msg):
307+
store.put("df", df)
308+
return
267309
store.put("df", df)
268310
tm.assert_frame_equal(
269311
store["df"], expected, check_index_type=True, check_column_type=True

0 commit comments

Comments
 (0)