Skip to content

Commit e1b657a

Browse files
authored
BUG: sort_values raising for dictionary arrow dtype (pandas-dev#53232)
1 parent f676c5f commit e1b657a

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

doc/source/whatsnew/v2.0.2.rst

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ Bug fixes
3232
- Bug in :func:`to_timedelta` raising ``ValueError`` with ``pandas.NA`` (:issue:`52909`)
3333
- Bug in :meth:`DataFrame.__getitem__` not preserving dtypes for :class:`MultiIndex` partial keys (:issue:`51895`)
3434
- Bug in :meth:`DataFrame.convert_dtypes` ignoring ``convert_*`` keywords set to ``False`` when ``dtype_backend="pyarrow"`` (:issue:`52872`)
35+
- Bug in :meth:`DataFrame.sort_values` raising for PyArrow ``dictionary`` dtype (:issue:`53232`)
3536
- Bug in :meth:`Series.describe` treating pyarrow-backed timestamps and timedeltas as categorical data (:issue:`53001`)
3637
- Bug in :meth:`Series.rename` not making a lazy copy when Copy-on-Write is enabled when a scalar is passed to it (:issue:`52450`)
3738
- Bug in :meth:`pd.array` raising for ``NumPy`` array and ``pa.large_string`` or ``pa.large_binary`` (:issue:`52590`)

pandas/core/arrays/arrow/array.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,10 @@ def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = Fal
266266
# GH50430: let pyarrow infer type, then cast
267267
scalars = pa.array(scalars, from_pandas=True)
268268
if pa_dtype and scalars.type != pa_dtype:
269-
scalars = scalars.cast(pa_dtype)
269+
if pa.types.is_dictionary(pa_dtype):
270+
scalars = scalars.dictionary_encode()
271+
else:
272+
scalars = scalars.cast(pa_dtype)
270273
arr = cls(scalars)
271274
if pa.types.is_duration(scalars.type) and scalars.null_count > 0:
272275
# GH52843: upstream bug for duration types when originally
@@ -878,7 +881,10 @@ def factorize(
878881
else:
879882
data = self._pa_array
880883

881-
encoded = data.dictionary_encode(null_encoding=null_encoding)
884+
if pa.types.is_dictionary(data.type):
885+
encoded = data
886+
else:
887+
encoded = data.dictionary_encode(null_encoding=null_encoding)
882888
if encoded.length() == 0:
883889
indices = np.array([], dtype=np.intp)
884890
uniques = type(self)(pa.chunked_array([], type=encoded.type.value_type))

pandas/tests/extension/test_arrow.py

+14
Original file line numberDiff line numberDiff line change
@@ -1811,6 +1811,20 @@ def test_searchsorted_with_na_raises(data_for_sorting, as_series):
18111811
arr.searchsorted(b)
18121812

18131813

1814+
def test_sort_values_dictionary():
1815+
df = pd.DataFrame(
1816+
{
1817+
"a": pd.Series(
1818+
["x", "y"], dtype=ArrowDtype(pa.dictionary(pa.int32(), pa.string()))
1819+
),
1820+
"b": [1, 2],
1821+
},
1822+
)
1823+
expected = df.copy()
1824+
result = df.sort_values(by=["a", "b"])
1825+
tm.assert_frame_equal(result, expected)
1826+
1827+
18141828
@pytest.mark.parametrize("pat", ["abc", "a[a-z]{2}"])
18151829
def test_str_count(pat):
18161830
ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))

0 commit comments

Comments
 (0)