|
6 | 6 | from shutil import get_terminal_size
|
7 | 7 | from typing import (
|
8 | 8 | TYPE_CHECKING,
|
9 |
| - Any, |
10 | 9 | Hashable,
|
11 | 10 | Sequence,
|
12 | 11 | TypeVar,
|
|
38 | 37 | Dtype,
|
39 | 38 | NpDtype,
|
40 | 39 | Ordered,
|
41 |
| - PositionalIndexer2D, |
42 |
| - PositionalIndexerTuple, |
43 |
| - ScalarIndexer, |
44 |
| - SequenceIndexer, |
45 | 40 | Shape,
|
46 | 41 | npt,
|
47 | 42 | type_t,
|
|
102 | 97 | take_nd,
|
103 | 98 | unique1d,
|
104 | 99 | )
|
105 |
| -from pandas.core.arrays._mixins import NDArrayBackedExtensionArray |
| 100 | +from pandas.core.arrays._mixins import ( |
| 101 | + NDArrayBackedExtensionArray, |
| 102 | + ravel_compat, |
| 103 | +) |
106 | 104 | from pandas.core.base import (
|
107 | 105 | ExtensionArray,
|
108 | 106 | NoNewAttributesMixin,
|
|
113 | 111 | extract_array,
|
114 | 112 | sanitize_array,
|
115 | 113 | )
|
116 |
| -from pandas.core.indexers import deprecate_ndim_indexing |
117 | 114 | from pandas.core.ops.common import unpack_zerodim_and_defer
|
118 | 115 | from pandas.core.sorting import nargsort
|
119 | 116 | from pandas.core.strings.object_array import ObjectStringArrayMixin
|
@@ -1479,6 +1476,7 @@ def _validate_scalar(self, fill_value):
|
1479 | 1476 |
|
1480 | 1477 | # -------------------------------------------------------------
|
1481 | 1478 |
|
| 1479 | + @ravel_compat |
1482 | 1480 | def __array__(self, dtype: NpDtype | None = None) -> np.ndarray:
|
1483 | 1481 | """
|
1484 | 1482 | The numpy array interface.
|
@@ -1929,7 +1927,10 @@ def __iter__(self):
|
1929 | 1927 | """
|
1930 | 1928 | Returns an Iterator over the values of this Categorical.
|
1931 | 1929 | """
|
1932 |
| - return iter(self._internal_get_values().tolist()) |
| 1930 | + if self.ndim == 1: |
| 1931 | + return iter(self._internal_get_values().tolist()) |
| 1932 | + else: |
| 1933 | + return (self[n] for n in range(len(self))) |
1933 | 1934 |
|
1934 | 1935 | def __contains__(self, key) -> bool:
|
1935 | 1936 | """
|
@@ -2048,27 +2049,6 @@ def __repr__(self) -> str:
|
2048 | 2049 |
|
2049 | 2050 | # ------------------------------------------------------------------
|
2050 | 2051 |
|
2051 |
| - @overload |
2052 |
| - def __getitem__(self, key: ScalarIndexer) -> Any: |
2053 |
| - ... |
2054 |
| - |
2055 |
| - @overload |
2056 |
| - def __getitem__( |
2057 |
| - self: CategoricalT, |
2058 |
| - key: SequenceIndexer | PositionalIndexerTuple, |
2059 |
| - ) -> CategoricalT: |
2060 |
| - ... |
2061 |
| - |
2062 |
| - def __getitem__(self: CategoricalT, key: PositionalIndexer2D) -> CategoricalT | Any: |
2063 |
| - """ |
2064 |
| - Return an item. |
2065 |
| - """ |
2066 |
| - result = super().__getitem__(key) |
2067 |
| - if getattr(result, "ndim", 0) > 1: |
2068 |
| - result = result._ndarray |
2069 |
| - deprecate_ndim_indexing(result) |
2070 |
| - return result |
2071 |
| - |
2072 | 2052 | def _validate_listlike(self, value):
|
2073 | 2053 | # NB: here we assume scalar-like tuples have already been excluded
|
2074 | 2054 | value = extract_array(value, extract_numpy=True)
|
@@ -2306,7 +2286,19 @@ def _concat_same_type(
|
2306 | 2286 | ) -> CategoricalT:
|
2307 | 2287 | from pandas.core.dtypes.concat import union_categoricals
|
2308 | 2288 |
|
2309 |
| - return union_categoricals(to_concat) |
| 2289 | + result = union_categoricals(to_concat) |
| 2290 | + |
| 2291 | + # in case we are concatenating along axis != 0, we need to reshape |
| 2292 | + # the result from union_categoricals |
| 2293 | + first = to_concat[0] |
| 2294 | + if axis >= first.ndim: |
| 2295 | + raise ValueError |
| 2296 | + if axis == 1: |
| 2297 | + if not all(len(x) == len(first) for x in to_concat): |
| 2298 | + raise ValueError |
| 2299 | + # TODO: Will this get contiguity wrong? |
| 2300 | + result = result.reshape(-1, len(to_concat), order="F") |
| 2301 | + return result |
2310 | 2302 |
|
2311 | 2303 | # ------------------------------------------------------------------
|
2312 | 2304 |
|
@@ -2694,6 +2686,11 @@ def _get_codes_for_values(values, categories: Index) -> np.ndarray:
|
2694 | 2686 | """
|
2695 | 2687 | dtype_equal = is_dtype_equal(values.dtype, categories.dtype)
|
2696 | 2688 |
|
| 2689 | + if values.ndim > 1: |
| 2690 | + flat = values.ravel() |
| 2691 | + codes = _get_codes_for_values(flat, categories) |
| 2692 | + return codes.reshape(values.shape) |
| 2693 | + |
2697 | 2694 | if isinstance(categories.dtype, ExtensionDtype) and is_object_dtype(values):
|
2698 | 2695 | # Support inferring the correct extension dtype from an array of
|
2699 | 2696 | # scalar objects. e.g.
|
|
0 commit comments