Skip to content

Commit 1dcdffd

Browse files
authored
BUG: dictionary type astype categorical using dictionary as categories (#56672)
1 parent 9e87dc7 commit 1dcdffd

File tree

3 files changed: 45 additions (+45) and 18 deletions (-18)

doc/source/whatsnew/v2.2.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,7 @@ Categorical
740740
^^^^^^^^^^^
741741
- :meth:`Categorical.isin` raising ``InvalidIndexError`` for categorical containing overlapping :class:`Interval` values (:issue:`34974`)
742742
- Bug in :meth:`CategoricalDtype.__eq__` returning ``False`` for unordered categorical data with mixed types (:issue:`55468`)
743+
- Bug when casting a ``pa.dictionary``-typed array to :class:`CategoricalDtype` where the resulting categories were the dictionary-encoded ``pa.DictionaryArray`` itself rather than the values of its dictionary (:issue:`56672`)
743744

744745
Datetimelike
745746
^^^^^^^^^^^^

pandas/core/arrays/categorical.py

+28-18
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@
4444
pandas_dtype,
4545
)
4646
from pandas.core.dtypes.dtypes import (
47+
ArrowDtype,
4748
CategoricalDtype,
49+
CategoricalDtypeType,
4850
ExtensionDtype,
4951
)
5052
from pandas.core.dtypes.generic import (
@@ -443,24 +445,32 @@ def __init__(
443445
values = arr
444446

445447
if dtype.categories is None:
446-
if not isinstance(values, ABCIndex):
447-
# in particular RangeIndex xref test_index_equal_range_categories
448-
values = sanitize_array(values, None)
449-
try:
450-
codes, categories = factorize(values, sort=True)
451-
except TypeError as err:
452-
codes, categories = factorize(values, sort=False)
453-
if dtype.ordered:
454-
# raise, as we don't have a sortable data structure and so
455-
# the user should give us one by specifying categories
456-
raise TypeError(
457-
"'values' is not ordered, please "
458-
"explicitly specify the categories order "
459-
"by passing in a categories argument."
460-
) from err
461-
462-
# we're inferring from values
463-
dtype = CategoricalDtype(categories, dtype.ordered)
448+
if isinstance(values.dtype, ArrowDtype) and issubclass(
449+
values.dtype.type, CategoricalDtypeType
450+
):
451+
arr = values._pa_array.combine_chunks()
452+
categories = arr.dictionary.to_pandas(types_mapper=ArrowDtype)
453+
codes = arr.indices.to_numpy()
454+
dtype = CategoricalDtype(categories, values.dtype.pyarrow_dtype.ordered)
455+
else:
456+
if not isinstance(values, ABCIndex):
457+
# in particular RangeIndex xref test_index_equal_range_categories
458+
values = sanitize_array(values, None)
459+
try:
460+
codes, categories = factorize(values, sort=True)
461+
except TypeError as err:
462+
codes, categories = factorize(values, sort=False)
463+
if dtype.ordered:
464+
# raise, as we don't have a sortable data structure and so
465+
# the user should give us one by specifying categories
466+
raise TypeError(
467+
"'values' is not ordered, please "
468+
"explicitly specify the categories order "
469+
"by passing in a categories argument."
470+
) from err
471+
472+
# we're inferring from values
473+
dtype = CategoricalDtype(categories, dtype.ordered)
464474

465475
elif isinstance(values.dtype, CategoricalDtype):
466476
old_codes = extract_array(values)._codes

pandas/tests/extension/test_arrow.py

+16
Original file line numberDiff line numberDiff line change
@@ -3227,6 +3227,22 @@ def test_factorize_chunked_dictionary():
32273227
tm.assert_index_equal(res_uniques, exp_uniques)
32283228

32293229

3230+
def test_dictionary_astype_categorical():
3231+
# GH#56672
3232+
arrs = [
3233+
pa.array(np.array(["a", "x", "c", "a"])).dictionary_encode(),
3234+
pa.array(np.array(["a", "d", "c"])).dictionary_encode(),
3235+
]
3236+
ser = pd.Series(ArrowExtensionArray(pa.chunked_array(arrs)))
3237+
result = ser.astype("category")
3238+
categories = pd.Index(["a", "x", "c", "d"], dtype=ArrowDtype(pa.string()))
3239+
expected = pd.Series(
3240+
["a", "x", "c", "a", "a", "d", "c"],
3241+
dtype=pd.CategoricalDtype(categories=categories),
3242+
)
3243+
tm.assert_series_equal(result, expected)
3244+
3245+
32303246
def test_arrow_floordiv():
32313247
# GH 55561
32323248
a = pd.Series([-7], dtype="int64[pyarrow]")

0 commit comments

Comments
 (0)