Skip to content

Commit a393c31

Browse files
authored
ENH(string dtype): fallback for HDF5 with UTF-8 surrogates (pandas-dev#60993)
1 parent d739c92 commit a393c31

File tree

2 files changed

+96
-39
lines changed

2 files changed

+96
-39
lines changed

pandas/io/pytables.py

+85-29
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
)
4040
from pandas._libs.lib import is_string_array
4141
from pandas._libs.tslibs import timezones
42+
from pandas.compat import HAS_PYARROW
4243
from pandas.compat._optional import import_optional_dependency
4344
from pandas.compat.pickle_compat import patch_pickle
4445
from pandas.errors import (
@@ -381,6 +382,13 @@ def read_hdf(
381382
DataFrame.to_hdf : Write a HDF file from a DataFrame.
382383
HDFStore : Low-level access to HDF files.
383384
385+
Notes
386+
-----
387+
When ``errors="surrogatepass"``, ``pd.options.future.infer_string`` is true,
388+
and PyArrow is installed, if a UTF-16 surrogate is encountered when decoding
389+
to UTF-8, the resulting dtype will be
390+
``pd.StringDtype(storage="python", na_value=np.nan)``.
391+
384392
Examples
385393
--------
386394
>>> df = pd.DataFrame([[1, 1.0, "a"]], columns=["x", "y", "z"]) # doctest: +SKIP
@@ -2257,6 +2265,20 @@ def convert(
22572265
# making an Index instance could throw a number of different errors
22582266
try:
22592267
new_pd_index = factory(values, **kwargs)
2268+
except UnicodeEncodeError as err:
2269+
if (
2270+
errors == "surrogatepass"
2271+
and get_option("future.infer_string")
2272+
and str(err).endswith("surrogates not allowed")
2273+
and HAS_PYARROW
2274+
):
2275+
new_pd_index = factory(
2276+
values,
2277+
dtype=StringDtype(storage="python", na_value=np.nan),
2278+
**kwargs,
2279+
)
2280+
else:
2281+
raise
22602282
except ValueError:
22612283
# if the output freq is different than what we recorded,
22622284
# it should be None (see also 'doc example part 2')
@@ -3170,12 +3192,29 @@ def read_index_node(
31703192
**kwargs,
31713193
)
31723194
else:
3173-
index = factory(
3174-
_unconvert_index(
3175-
data, kind, encoding=self.encoding, errors=self.errors
3176-
),
3177-
**kwargs,
3178-
)
3195+
try:
3196+
index = factory(
3197+
_unconvert_index(
3198+
data, kind, encoding=self.encoding, errors=self.errors
3199+
),
3200+
**kwargs,
3201+
)
3202+
except UnicodeEncodeError as err:
3203+
if (
3204+
self.errors == "surrogatepass"
3205+
and get_option("future.infer_string")
3206+
and str(err).endswith("surrogates not allowed")
3207+
and HAS_PYARROW
3208+
):
3209+
index = factory(
3210+
_unconvert_index(
3211+
data, kind, encoding=self.encoding, errors=self.errors
3212+
),
3213+
dtype=StringDtype(storage="python", na_value=np.nan),
3214+
**kwargs,
3215+
)
3216+
else:
3217+
raise
31793218

31803219
index.name = name
31813220

@@ -3311,13 +3350,24 @@ def read(
33113350
self.validate_read(columns, where)
33123351
index = self.read_index("index", start=start, stop=stop)
33133352
values = self.read_array("values", start=start, stop=stop)
3314-
result = Series(values, index=index, name=self.name, copy=False)
3315-
if (
3316-
using_string_dtype()
3317-
and isinstance(values, np.ndarray)
3318-
and is_string_array(values, skipna=True)
3319-
):
3320-
result = result.astype(StringDtype(na_value=np.nan))
3353+
try:
3354+
result = Series(values, index=index, name=self.name, copy=False)
3355+
except UnicodeEncodeError as err:
3356+
if (
3357+
self.errors == "surrogatepass"
3358+
and get_option("future.infer_string")
3359+
and str(err).endswith("surrogates not allowed")
3360+
and HAS_PYARROW
3361+
):
3362+
result = Series(
3363+
values,
3364+
index=index,
3365+
name=self.name,
3366+
copy=False,
3367+
dtype=StringDtype(storage="python", na_value=np.nan),
3368+
)
3369+
else:
3370+
raise
33213371
return result
33223372

33233373
def write(self, obj, **kwargs) -> None:
@@ -4764,7 +4814,24 @@ def read(
47644814
values = values.reshape((1, values.shape[0]))
47654815

47664816
if isinstance(values, (np.ndarray, DatetimeArray)):
4767-
df = DataFrame(values.T, columns=cols_, index=index_, copy=False)
4817+
try:
4818+
df = DataFrame(values.T, columns=cols_, index=index_, copy=False)
4819+
except UnicodeEncodeError as err:
4820+
if (
4821+
self.errors == "surrogatepass"
4822+
and get_option("future.infer_string")
4823+
and str(err).endswith("surrogates not allowed")
4824+
and HAS_PYARROW
4825+
):
4826+
df = DataFrame(
4827+
values.T,
4828+
columns=cols_,
4829+
index=index_,
4830+
copy=False,
4831+
dtype=StringDtype(storage="python", na_value=np.nan),
4832+
)
4833+
else:
4834+
raise
47684835
elif isinstance(values, Index):
47694836
df = DataFrame(values, columns=cols_, index=index_)
47704837
else:
@@ -4774,23 +4841,10 @@ def read(
47744841
assert (df.dtypes == values.dtype).all(), (df.dtypes, values.dtype)
47754842

47764843
# If str / string dtype is stored in meta, use that.
4777-
converted = False
47784844
for column in cols_:
47794845
dtype = getattr(self.table.attrs, f"{column}_meta", None)
47804846
if dtype in ["str", "string"]:
47814847
df[column] = df[column].astype(dtype)
4782-
converted = True
4783-
# Otherwise try inference.
4784-
if (
4785-
not converted
4786-
and using_string_dtype()
4787-
and isinstance(values, np.ndarray)
4788-
and is_string_array(
4789-
values,
4790-
skipna=True,
4791-
)
4792-
):
4793-
df = df.astype(StringDtype(na_value=np.nan))
47944848
frames.append(df)
47954849

47964850
if len(frames) == 1:
@@ -5224,7 +5278,7 @@ def _convert_string_array(data: np.ndarray, encoding: str, errors: str) -> np.nd
52245278
# encode if needed
52255279
if len(data):
52265280
data = (
5227-
Series(data.ravel(), copy=False)
5281+
Series(data.ravel(), copy=False, dtype="object")
52285282
.str.encode(encoding, errors)
52295283
._values.reshape(data.shape)
52305284
)
@@ -5264,7 +5318,9 @@ def _unconvert_string_array(
52645318
dtype = f"U{itemsize}"
52655319

52665320
if isinstance(data[0], bytes):
5267-
ser = Series(data, copy=False).str.decode(encoding, errors=errors)
5321+
ser = Series(data, copy=False).str.decode(
5322+
encoding, errors=errors, dtype="object"
5323+
)
52685324
data = ser.to_numpy()
52695325
data.flags.writeable = True
52705326
else:

pandas/tests/io/pytables/test_store.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
import numpy as np
88
import pytest
99

10-
from pandas._config import using_string_dtype
11-
1210
from pandas.compat import PY312
1311

1412
import pandas as pd
@@ -25,7 +23,6 @@
2523
timedelta_range,
2624
)
2725
import pandas._testing as tm
28-
from pandas.conftest import has_pyarrow
2926
from pandas.tests.io.pytables.common import (
3027
_maybe_remove,
3128
ensure_clean_store,
@@ -385,20 +382,24 @@ def test_to_hdf_with_min_itemsize(tmp_path, setup_path):
385382
tm.assert_series_equal(read_hdf(path, "ss4"), concat([df["B"], df2["B"]]))
386383

387384

388-
@pytest.mark.xfail(
389-
using_string_dtype() and has_pyarrow,
390-
reason="TODO(infer_string): can't encode '\ud800': surrogates not allowed",
391-
)
392385
@pytest.mark.parametrize("format", ["fixed", "table"])
393-
def test_to_hdf_errors(tmp_path, format, setup_path):
386+
def test_to_hdf_errors(tmp_path, format, setup_path, using_infer_string):
394387
data = ["\ud800foo"]
395-
ser = Series(data, index=Index(data))
388+
ser = Series(data, index=Index(data, dtype="object"), dtype="object")
396389
path = tmp_path / setup_path
397390
# GH 20835
398391
ser.to_hdf(path, key="table", format=format, errors="surrogatepass")
399392

400393
result = read_hdf(path, "table", errors="surrogatepass")
401-
tm.assert_series_equal(result, ser)
394+
395+
if using_infer_string:
396+
# https://github.com/pandas-dev/pandas/pull/60993
397+
# Surrogates fall back to python storage.
398+
dtype = pd.StringDtype(storage="python", na_value=np.nan)
399+
else:
400+
dtype = "object"
401+
expected = Series(data, index=Index(data, dtype=dtype), dtype=dtype)
402+
tm.assert_series_equal(result, expected)
402403

403404

404405
def test_create_table_index(setup_path):

0 commit comments

Comments
 (0)