Skip to content

Commit 2baa64e

Browse files
committed
Don't tupleize categories
1 parent 6f137b3 commit 2baa64e

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

pandas/core/dtypes/dtypes.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ class CategoricalDtype(ExtensionDtype):
146146
def __new__(cls, categories=None, ordered=False, fastpath=False):
147147
from pandas.core.indexes.base import Index
148148
if categories is not None:
149-
categories = Index(categories)
149+
categories = Index(categories, tupleize_cols=False)
150150
# validation
151151
cls._validate_categories(categories, fastpath=fastpath)
152152
cls._validate_ordered(ordered)
@@ -211,8 +211,18 @@ def __repr__(self):
211211

212212
@staticmethod
213213
def _hash_categories(categories, ordered=True):
214-
from pandas.core.util.hashing import hash_array, _combine_hash_arrays
215-
cat_array = hash_array(np.asarray(categories), categorize=False)
214+
from pandas.core.util.hashing import (
215+
hash_array, _combine_hash_arrays, hash_tuples
216+
)
217+
218+
categories = np.asarray(categories)
219+
if len(categories) and isinstance(categories[0], tuple):
220+
# assumes if any individual category is a tuple, then all our. ATM
221+
# I don't really want to support just some of the categories being
222+
# tuples.
223+
cat_array = hash_tuples(categories)
224+
else:
225+
cat_array = hash_array(np.asarray(categories), categorize=False)
216226
if ordered:
217227
cat_array = np.vstack([
218228
cat_array, np.arange(len(cat_array), dtype=cat_array.dtype)

pandas/tests/dtypes/test_dtypes.py

+5
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ def test_basic(self):
119119
assert not is_categorical(np.dtype('float64'))
120120
assert not is_categorical(1.0)
121121

122+
def test_tuple_categories(self):
123+
categories = [(1, 'a'), (2, 'b'), (3, 'c')]
124+
result = CategoricalDtype(categories)
125+
assert all(result.categories == categories)
126+
122127

123128
class TestDatetimeTZDtype(Base):
124129

0 commit comments

Comments
 (0)