Skip to content

Commit 25384ba

Browse files
topper-123jreback
authored andcommitted
Let _get_dtype accept Categoricals and CategoricalIndex (#16887)
1 parent 63536f4 commit 25384ba

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

doc/source/whatsnew/v0.21.0.txt

-1
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,6 @@ Conversion
149149
^^^^^^^^^^
150150

151151

152-
153152
Indexing
154153
^^^^^^^^
155154

pandas/core/dtypes/common.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
ExtensionDtype)
1212
from .generic import (ABCCategorical, ABCPeriodIndex,
1313
ABCDatetimeIndex, ABCSeries,
14-
ABCSparseArray, ABCSparseSeries)
14+
ABCSparseArray, ABCSparseSeries, ABCCategoricalIndex)
1515
from .inference import is_string_like
1616
from .inference import * # noqa
1717

@@ -1713,6 +1713,8 @@ def _get_dtype(arr_or_dtype):
17131713
return PeriodDtype.construct_from_string(arr_or_dtype)
17141714
elif is_interval_dtype(arr_or_dtype):
17151715
return IntervalDtype.construct_from_string(arr_or_dtype)
1716+
elif isinstance(arr_or_dtype, (ABCCategorical, ABCCategoricalIndex)):
1717+
return arr_or_dtype.dtype
17161718

17171719
if hasattr(arr_or_dtype, 'dtype'):
17181720
arr_or_dtype = arr_or_dtype.dtype

pandas/tests/dtypes/test_common.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -532,16 +532,16 @@ def test_is_complex_dtype():
532532
(float, np.dtype(float)),
533533
('float64', np.dtype('float64')),
534534
(np.dtype('float64'), np.dtype('float64')),
535-
pytest.mark.xfail((str, np.dtype('<U')), ),
535+
(str, np.dtype(str)),
536536
(pd.Series([1, 2], dtype=np.dtype('int16')), np.dtype('int16')),
537537
(pd.Series(['a', 'b']), np.dtype(object)),
538538
(pd.Index([1, 2]), np.dtype('int64')),
539539
(pd.Index(['a', 'b']), np.dtype(object)),
540540
('category', 'category'),
541541
(pd.Categorical(['a', 'b']).dtype, CategoricalDtype()),
542-
pytest.mark.xfail((pd.Categorical(['a', 'b']), CategoricalDtype()),),
542+
(pd.Categorical(['a', 'b']), CategoricalDtype()),
543543
(pd.CategoricalIndex(['a', 'b']).dtype, CategoricalDtype()),
544-
pytest.mark.xfail((pd.CategoricalIndex(['a', 'b']), CategoricalDtype()),),
544+
(pd.CategoricalIndex(['a', 'b']), CategoricalDtype()),
545545
(pd.DatetimeIndex([1, 2]), np.dtype('<M8[ns]')),
546546
(pd.DatetimeIndex([1, 2]).dtype, np.dtype('<M8[ns]')),
547547
('<M8[ns]', np.dtype('<M8[ns]')),

0 commit comments

Comments
 (0)