Skip to content

Commit c3cefa7

Browse files
jbrockmendel authored and nickleus27 committed
TST: enable 2D tests for Categorical (pandas-dev#44206)
1 parent 2d9bddf commit c3cefa7

File tree

2 files changed

+38
-30
lines changed

2 files changed

+38
-30
lines changed

pandas/core/arrays/categorical.py

+27-30
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
from shutil import get_terminal_size
77
from typing import (
88
TYPE_CHECKING,
9-
Any,
109
Hashable,
1110
Sequence,
1211
TypeVar,
@@ -38,10 +37,6 @@
3837
Dtype,
3938
NpDtype,
4039
Ordered,
41-
PositionalIndexer2D,
42-
PositionalIndexerTuple,
43-
ScalarIndexer,
44-
SequenceIndexer,
4540
Shape,
4641
npt,
4742
type_t,
@@ -102,7 +97,10 @@
10297
take_nd,
10398
unique1d,
10499
)
105-
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
100+
from pandas.core.arrays._mixins import (
101+
NDArrayBackedExtensionArray,
102+
ravel_compat,
103+
)
106104
from pandas.core.base import (
107105
ExtensionArray,
108106
NoNewAttributesMixin,
@@ -113,7 +111,6 @@
113111
extract_array,
114112
sanitize_array,
115113
)
116-
from pandas.core.indexers import deprecate_ndim_indexing
117114
from pandas.core.ops.common import unpack_zerodim_and_defer
118115
from pandas.core.sorting import nargsort
119116
from pandas.core.strings.object_array import ObjectStringArrayMixin
@@ -1479,6 +1476,7 @@ def _validate_scalar(self, fill_value):
14791476

14801477
# -------------------------------------------------------------
14811478

1479+
@ravel_compat
14821480
def __array__(self, dtype: NpDtype | None = None) -> np.ndarray:
14831481
"""
14841482
The numpy array interface.
@@ -1929,7 +1927,10 @@ def __iter__(self):
19291927
"""
19301928
Returns an Iterator over the values of this Categorical.
19311929
"""
1932-
return iter(self._internal_get_values().tolist())
1930+
if self.ndim == 1:
1931+
return iter(self._internal_get_values().tolist())
1932+
else:
1933+
return (self[n] for n in range(len(self)))
19331934

19341935
def __contains__(self, key) -> bool:
19351936
"""
@@ -2048,27 +2049,6 @@ def __repr__(self) -> str:
20482049

20492050
# ------------------------------------------------------------------
20502051

2051-
@overload
2052-
def __getitem__(self, key: ScalarIndexer) -> Any:
2053-
...
2054-
2055-
@overload
2056-
def __getitem__(
2057-
self: CategoricalT,
2058-
key: SequenceIndexer | PositionalIndexerTuple,
2059-
) -> CategoricalT:
2060-
...
2061-
2062-
def __getitem__(self: CategoricalT, key: PositionalIndexer2D) -> CategoricalT | Any:
2063-
"""
2064-
Return an item.
2065-
"""
2066-
result = super().__getitem__(key)
2067-
if getattr(result, "ndim", 0) > 1:
2068-
result = result._ndarray
2069-
deprecate_ndim_indexing(result)
2070-
return result
2071-
20722052
def _validate_listlike(self, value):
20732053
# NB: here we assume scalar-like tuples have already been excluded
20742054
value = extract_array(value, extract_numpy=True)
@@ -2306,7 +2286,19 @@ def _concat_same_type(
23062286
) -> CategoricalT:
23072287
from pandas.core.dtypes.concat import union_categoricals
23082288

2309-
return union_categoricals(to_concat)
2289+
result = union_categoricals(to_concat)
2290+
2291+
# in case we are concatenating along axis != 0, we need to reshape
2292+
# the result from union_categoricals
2293+
first = to_concat[0]
2294+
if axis >= first.ndim:
2295+
raise ValueError
2296+
if axis == 1:
2297+
if not all(len(x) == len(first) for x in to_concat):
2298+
raise ValueError
2299+
# TODO: Will this get contiguity wrong?
2300+
result = result.reshape(-1, len(to_concat), order="F")
2301+
return result
23102302

23112303
# ------------------------------------------------------------------
23122304

@@ -2694,6 +2686,11 @@ def _get_codes_for_values(values, categories: Index) -> np.ndarray:
26942686
"""
26952687
dtype_equal = is_dtype_equal(values.dtype, categories.dtype)
26962688

2689+
if values.ndim > 1:
2690+
flat = values.ravel()
2691+
codes = _get_codes_for_values(flat, categories)
2692+
return codes.reshape(values.shape)
2693+
26972694
if isinstance(categories.dtype, ExtensionDtype) and is_object_dtype(values):
26982695
# Support inferring the correct extension dtype from an array of
26992696
# scalar objects. e.g.

pandas/tests/extension/test_categorical.py

+11
Original file line number | Diff line number | Diff line change
@@ -303,3 +303,14 @@ def test_not_equal_with_na(self, categories):
303303

304304
class TestParsing(base.BaseParsingTests):
305305
pass
306+
307+
308+
class Test2DCompat(base.Dim2CompatTests):
309+
def test_repr_2d(self, data):
310+
# Categorical __repr__ doesn't include "Categorical", so we need
311+
# to special-case
312+
res = repr(data.reshape(1, -1))
313+
assert res.count("\nCategories") == 1
314+
315+
res = repr(data.reshape(-1, 1))
316+
assert res.count("\nCategories") == 1

0 commit comments

Comments
 (0)