|
6 | 6 | from shutil import get_terminal_size
|
7 | 7 | from typing import (
|
8 | 8 | TYPE_CHECKING,
|
9 |
| - Any, |
10 | 9 | Hashable,
|
11 | 10 | Sequence,
|
12 | 11 | TypeVar,
|
|
38 | 37 | Dtype,
|
39 | 38 | NpDtype,
|
40 | 39 | Ordered,
|
41 |
| - PositionalIndexer2D, |
42 |
| - PositionalIndexerTuple, |
43 |
| - ScalarIndexer, |
44 |
| - SequenceIndexer, |
45 | 40 | Shape,
|
46 | 41 | npt,
|
47 | 42 | type_t,
|
|
102 | 97 | take_nd,
|
103 | 98 | unique1d,
|
104 | 99 | )
|
105 |
| -from pandas.core.arrays._mixins import NDArrayBackedExtensionArray |
| 100 | +from pandas.core.arrays._mixins import ( |
| 101 | + NDArrayBackedExtensionArray, |
| 102 | + ravel_compat, |
| 103 | +) |
106 | 104 | from pandas.core.base import (
|
107 | 105 | ExtensionArray,
|
108 | 106 | NoNewAttributesMixin,
|
|
113 | 111 | extract_array,
|
114 | 112 | sanitize_array,
|
115 | 113 | )
|
116 |
| -from pandas.core.indexers import deprecate_ndim_indexing |
117 | 114 | from pandas.core.ops.common import unpack_zerodim_and_defer
|
118 | 115 | from pandas.core.sorting import nargsort
|
119 | 116 | from pandas.core.strings.object_array import ObjectStringArrayMixin
|
@@ -424,13 +421,8 @@ def __init__(
|
424 | 421 | if null_mask.any():
|
425 | 422 | # We remove null values here, then below will re-insert
|
426 | 423 | # them, grep "full_codes"
|
427 |
| - |
428 |
| - # error: Incompatible types in assignment (expression has type |
429 |
| - # "List[Any]", variable has type "ExtensionArray") |
430 |
| - arr = [ # type: ignore[assignment] |
431 |
| - values[idx] for idx in np.where(~null_mask)[0] |
432 |
| - ] |
433 |
| - arr = sanitize_array(arr, None) |
| 424 | + arr_lst = [values[idx] for idx in np.where(~null_mask)[0]] |
| 425 | + arr = sanitize_array(arr_lst, None) |
434 | 426 | values = arr
|
435 | 427 |
|
436 | 428 | if dtype.categories is None:
|
@@ -1484,6 +1476,7 @@ def _validate_scalar(self, fill_value):
|
1484 | 1476 |
|
1485 | 1477 | # -------------------------------------------------------------
|
1486 | 1478 |
|
| 1479 | + @ravel_compat |
1487 | 1480 | def __array__(self, dtype: NpDtype | None = None) -> np.ndarray:
|
1488 | 1481 | """
|
1489 | 1482 | The numpy array interface.
|
@@ -1934,7 +1927,10 @@ def __iter__(self):
|
1934 | 1927 | """
|
1935 | 1928 | Returns an Iterator over the values of this Categorical.
|
1936 | 1929 | """
|
1937 |
| - return iter(self._internal_get_values().tolist()) |
| 1930 | + if self.ndim == 1: |
| 1931 | + return iter(self._internal_get_values().tolist()) |
| 1932 | + else: |
| 1933 | + return (self[n] for n in range(len(self))) |
1938 | 1934 |
|
1939 | 1935 | def __contains__(self, key) -> bool:
|
1940 | 1936 | """
|
@@ -2053,27 +2049,6 @@ def __repr__(self) -> str:
|
2053 | 2049 |
|
2054 | 2050 | # ------------------------------------------------------------------
|
2055 | 2051 |
|
2056 |
| - @overload |
2057 |
| - def __getitem__(self, key: ScalarIndexer) -> Any: |
2058 |
| - ... |
2059 |
| - |
2060 |
| - @overload |
2061 |
| - def __getitem__( |
2062 |
| - self: CategoricalT, |
2063 |
| - key: SequenceIndexer | PositionalIndexerTuple, |
2064 |
| - ) -> CategoricalT: |
2065 |
| - ... |
2066 |
| - |
2067 |
| - def __getitem__(self: CategoricalT, key: PositionalIndexer2D) -> CategoricalT | Any: |
2068 |
| - """ |
2069 |
| - Return an item. |
2070 |
| - """ |
2071 |
| - result = super().__getitem__(key) |
2072 |
| - if getattr(result, "ndim", 0) > 1: |
2073 |
| - result = result._ndarray |
2074 |
| - deprecate_ndim_indexing(result) |
2075 |
| - return result |
2076 |
| - |
2077 | 2052 | def _validate_listlike(self, value):
|
2078 | 2053 | # NB: here we assume scalar-like tuples have already been excluded
|
2079 | 2054 | value = extract_array(value, extract_numpy=True)
|
@@ -2311,7 +2286,19 @@ def _concat_same_type(
|
2311 | 2286 | ) -> CategoricalT:
|
2312 | 2287 | from pandas.core.dtypes.concat import union_categoricals
|
2313 | 2288 |
|
2314 |
| - return union_categoricals(to_concat) |
| 2289 | + result = union_categoricals(to_concat) |
| 2290 | + |
| 2291 | + # in case we are concatenating along axis != 0, we need to reshape |
| 2292 | + # the result from union_categoricals |
| 2293 | + first = to_concat[0] |
| 2294 | + if axis >= first.ndim: |
| 2295 | + raise ValueError |
| 2296 | + if axis == 1: |
| 2297 | + if not all(len(x) == len(first) for x in to_concat): |
| 2298 | + raise ValueError |
| 2299 | + # TODO: Will this get contiguity wrong? |
| 2300 | + result = result.reshape(-1, len(to_concat), order="F") |
| 2301 | + return result |
2315 | 2302 |
|
2316 | 2303 | # ------------------------------------------------------------------
|
2317 | 2304 |
|
@@ -2699,6 +2686,11 @@ def _get_codes_for_values(values, categories: Index) -> np.ndarray:
|
2699 | 2686 | """
|
2700 | 2687 | dtype_equal = is_dtype_equal(values.dtype, categories.dtype)
|
2701 | 2688 |
|
| 2689 | + if values.ndim > 1: |
| 2690 | + flat = values.ravel() |
| 2691 | + codes = _get_codes_for_values(flat, categories) |
| 2692 | + return codes.reshape(values.shape) |
| 2693 | + |
2702 | 2694 | if isinstance(categories.dtype, ExtensionDtype) and is_object_dtype(values):
|
2703 | 2695 | # Support inferring the correct extension dtype from an array of
|
2704 | 2696 | # scalar objects. e.g.
|
|
0 commit comments