diff --git a/doc/source/whatsnew/v2.3.0.rst b/doc/source/whatsnew/v2.3.0.rst index 0b7c2bac1be6a..c4e01a86ce843 100644 --- a/doc/source/whatsnew/v2.3.0.rst +++ b/doc/source/whatsnew/v2.3.0.rst @@ -37,6 +37,7 @@ Other enhancements updated to raise FutureWarning with NumPy >= 2 (:issue:`60340`) - :meth:`Series.str.decode` result now has ``StringDtype`` when ``future.infer_string`` is True (:issue:`60709`) - :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` (:issue:`60663`) +- The :meth:`Series.str.decode` has gained the argument ``dtype`` to control the dtype of the result (:issue:`60940`) - The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`) - The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`) diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 3c4cf60ab262a..c0e458f7968e7 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -33,6 +33,7 @@ is_list_like, is_object_dtype, is_re, + is_string_dtype, ) from pandas.core.dtypes.dtypes import ( ArrowDtype, @@ -1981,7 +1982,9 @@ def slice_replace(self, start=None, stop=None, repl=None): result = self._data.array._str_slice_replace(start, stop, repl) return self._wrap_result(result) - def decode(self, encoding, errors: str = "strict"): + def decode( + self, encoding, errors: str = "strict", dtype: str | DtypeObj | None = None + ): """ Decode character string in the Series/Index using indicated encoding. @@ -1992,6 +1995,14 @@ def decode(self, encoding, errors: str = "strict"): ---------- encoding : str errors : str, optional + Specifies the error handling scheme. + Possible values are those supported by :meth:`bytes.decode`. + dtype : str or dtype, optional + The dtype of the result. When not ``None``, must be either a string or + object dtype. When ``None``, the dtype of the result is determined by + ``pd.options.future.infer_string``. + + .. versionadded:: 2.3.0 Returns ------- @@ -2008,6 +2019,10 @@ def decode(self, encoding, errors: str = "strict"): 2 () dtype: object """ + if dtype is not None and not is_string_dtype(dtype): + raise ValueError(f"dtype must be string or object, got {dtype=}") + if dtype is None and get_option("future.infer_string"): + dtype = "str" # TODO: Add a similar _bytes interface. if encoding in _cpython_optimized_decoders: # CPython optimized implementation @@ -2017,7 +2032,6 @@ def decode(self, encoding, errors: str = "strict"): f = lambda x: decoder(x, errors)[0] arr = self._data.array result = arr._str_map(f) - dtype = "str" if get_option("future.infer_string") else None return self._wrap_result(result, dtype=dtype) @forbid_nonstring_types(["bytes"]) diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index d93a3f26934a0..65f95dab7b42f 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -4093,6 +4093,8 @@ def _create_axes( ordered = data_converted.ordered meta = "category" metadata = np.asarray(data_converted.categories).ravel() + elif isinstance(blk.dtype, StringDtype): + meta = str(blk.dtype) data, dtype_name = _get_data_and_dtype_name(data_converted) @@ -4360,7 +4362,9 @@ def read_column( encoding=self.encoding, errors=self.errors, ) - return Series(_set_tz(col_values[1], a.tz), name=column, copy=False) + cvs = _set_tz(col_values[1], a.tz) + dtype = getattr(self.table.attrs, f"{column}_meta", None) + return Series(cvs, name=column, copy=False, dtype=dtype) raise KeyError(f"column [{column}] not found in the table") @@ -4708,8 +4712,18 @@ def read( df = DataFrame._from_arrays([values], columns=cols_, index=index_) if not (using_string_dtype() and values.dtype.kind == "O"): assert (df.dtypes == values.dtype).all(), (df.dtypes, values.dtype) + + # If str / string dtype is stored in meta, use that. + converted = False + for column in cols_: + dtype = getattr(self.table.attrs, f"{column}_meta", None) + if dtype in ["str", "string"]: + df[column] = df[column].astype(dtype) + converted = True + # Otherwise try inference. if ( - using_string_dtype() + not converted + and using_string_dtype() and isinstance(values, np.ndarray) and is_string_array( values, diff --git a/pandas/tests/io/pytables/test_append.py b/pandas/tests/io/pytables/test_append.py index 39c203c558a5b..d0246c8f58d6a 100644 --- a/pandas/tests/io/pytables/test_append.py +++ b/pandas/tests/io/pytables/test_append.py @@ -5,8 +5,6 @@ import numpy as np import pytest -from pandas._config import using_string_dtype - from pandas._libs.tslibs import Timestamp import pandas.util._test_decorators as td @@ -507,7 +505,6 @@ def test_append_with_empty_string(setup_path): tm.assert_frame_equal(store.select("df"), df) -@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_append_with_data_columns(setup_path): with ensure_clean_store(setup_path) as store: df = DataFrame( diff --git a/pandas/tests/io/pytables/test_categorical.py b/pandas/tests/io/pytables/test_categorical.py index a875e19ea7f0e..449bc5cf1fc57 100644 --- a/pandas/tests/io/pytables/test_categorical.py +++ b/pandas/tests/io/pytables/test_categorical.py @@ -1,8 +1,6 @@ import numpy as np import pytest -from pandas._config import using_string_dtype - from pandas import ( Categorical, DataFrame, @@ -140,7 +138,6 @@ def test_categorical(setup_path): store.select("df3/meta/s/meta") -@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_categorical_conversion(tmp_path, setup_path): # GH13322 # Check that read_hdf with categorical columns doesn't return rows if diff --git a/pandas/tests/io/pytables/test_read.py b/pandas/tests/io/pytables/test_read.py index bfebf18c0e0ab..5bec673ad3c70 100644 --- a/pandas/tests/io/pytables/test_read.py +++ b/pandas/tests/io/pytables/test_read.py @@ -5,8 +5,6 @@ import numpy as np import pytest -from pandas._config import using_string_dtype - from pandas._libs.tslibs import Timestamp from pandas.compat import is_platform_windows @@ -74,7 +72,6 @@ def test_read_missing_key_opened_store(tmp_path, setup_path): read_hdf(store, "k1") -@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_read_column(setup_path): df = DataFrame( np.random.default_rng(2).standard_normal((10, 4)), diff --git a/pandas/tests/io/pytables/test_select.py b/pandas/tests/io/pytables/test_select.py index f781b6756fec9..e76934745f004 100644 --- a/pandas/tests/io/pytables/test_select.py +++ b/pandas/tests/io/pytables/test_select.py @@ -1,8 +1,6 @@ import numpy as np import pytest -from pandas._config import using_string_dtype - from pandas._libs.tslibs import Timestamp import pandas as pd @@ -651,7 +649,6 @@ def test_frame_select(setup_path): # store.select('frame', [crit1, crit2]) -@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_frame_select_complex(setup_path): # select via complex criteria @@ -965,7 +962,6 @@ def test_query_long_float_literal(setup_path): tm.assert_frame_equal(expected, result) -@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)") def test_query_compare_column_type(setup_path): # GH 15492 df = DataFrame( diff --git a/pandas/tests/strings/test_strings.py b/pandas/tests/strings/test_strings.py index 59a06a421f53e..c729b910d05a7 100644 --- a/pandas/tests/strings/test_strings.py +++ b/pandas/tests/strings/test_strings.py @@ -599,6 +599,30 @@ def test_decode_errors_kwarg(): tm.assert_series_equal(result, expected) +def test_decode_string_dtype(string_dtype): + # https://github.com/pandas-dev/pandas/pull/60940 + ser = Series([b"a", b"b"]) + result = ser.str.decode("utf-8", dtype=string_dtype) + expected = Series(["a", "b"], dtype=string_dtype) + tm.assert_series_equal(result, expected) + + +def test_decode_object_dtype(object_dtype): + # https://github.com/pandas-dev/pandas/pull/60940 + ser = Series([b"a", rb"\ud800"]) + result = ser.str.decode("utf-8", dtype=object_dtype) + expected = Series(["a", r"\ud800"], dtype=object_dtype) + tm.assert_series_equal(result, expected) + + +def test_decode_bad_dtype(): + # https://github.com/pandas-dev/pandas/pull/60940 + ser = Series([b"a", b"b"]) + msg = "dtype must be string or object, got dtype='int64'" + with pytest.raises(ValueError, match=msg): + ser.str.decode("utf-8", dtype="int64") + + @pytest.mark.parametrize( "form, expected", [