Skip to content

Commit 3cefbcc

Browse files
committed
changed arr to arr_lst
1 parent dfd154d commit 3cefbcc

File tree

1 file changed

+29
-37
lines changed

1 file changed

+29
-37
lines changed

pandas/core/arrays/categorical.py

+29-37
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from shutil import get_terminal_size
77
from typing import (
88
TYPE_CHECKING,
9-
Any,
109
Hashable,
1110
Sequence,
1211
TypeVar,
@@ -38,10 +37,6 @@
3837
Dtype,
3938
NpDtype,
4039
Ordered,
41-
PositionalIndexer2D,
42-
PositionalIndexerTuple,
43-
ScalarIndexer,
44-
SequenceIndexer,
4540
Shape,
4641
npt,
4742
type_t,
@@ -102,7 +97,10 @@
10297
take_nd,
10398
unique1d,
10499
)
105-
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
100+
from pandas.core.arrays._mixins import (
101+
NDArrayBackedExtensionArray,
102+
ravel_compat,
103+
)
106104
from pandas.core.base import (
107105
ExtensionArray,
108106
NoNewAttributesMixin,
@@ -113,7 +111,6 @@
113111
extract_array,
114112
sanitize_array,
115113
)
116-
from pandas.core.indexers import deprecate_ndim_indexing
117114
from pandas.core.ops.common import unpack_zerodim_and_defer
118115
from pandas.core.sorting import nargsort
119116
from pandas.core.strings.object_array import ObjectStringArrayMixin
@@ -424,13 +421,8 @@ def __init__(
424421
if null_mask.any():
425422
# We remove null values here, then below will re-insert
426423
# them, grep "full_codes"
427-
428-
# error: Incompatible types in assignment (expression has type
429-
# "List[Any]", variable has type "ExtensionArray")
430-
arr = [ # type: ignore[assignment]
431-
values[idx] for idx in np.where(~null_mask)[0]
432-
]
433-
arr = sanitize_array(arr, None)
424+
arr_lst = [values[idx] for idx in np.where(~null_mask)[0]]
425+
arr = sanitize_array(arr_lst, None)
434426
values = arr
435427

436428
if dtype.categories is None:
@@ -1484,6 +1476,7 @@ def _validate_scalar(self, fill_value):
14841476

14851477
# -------------------------------------------------------------
14861478

1479+
@ravel_compat
14871480
def __array__(self, dtype: NpDtype | None = None) -> np.ndarray:
14881481
"""
14891482
The numpy array interface.
@@ -1934,7 +1927,10 @@ def __iter__(self):
19341927
"""
19351928
Returns an Iterator over the values of this Categorical.
19361929
"""
1937-
return iter(self._internal_get_values().tolist())
1930+
if self.ndim == 1:
1931+
return iter(self._internal_get_values().tolist())
1932+
else:
1933+
return (self[n] for n in range(len(self)))
19381934

19391935
def __contains__(self, key) -> bool:
19401936
"""
@@ -2053,27 +2049,6 @@ def __repr__(self) -> str:
20532049

20542050
# ------------------------------------------------------------------
20552051

2056-
@overload
2057-
def __getitem__(self, key: ScalarIndexer) -> Any:
2058-
...
2059-
2060-
@overload
2061-
def __getitem__(
2062-
self: CategoricalT,
2063-
key: SequenceIndexer | PositionalIndexerTuple,
2064-
) -> CategoricalT:
2065-
...
2066-
2067-
def __getitem__(self: CategoricalT, key: PositionalIndexer2D) -> CategoricalT | Any:
2068-
"""
2069-
Return an item.
2070-
"""
2071-
result = super().__getitem__(key)
2072-
if getattr(result, "ndim", 0) > 1:
2073-
result = result._ndarray
2074-
deprecate_ndim_indexing(result)
2075-
return result
2076-
20772052
def _validate_listlike(self, value):
20782053
# NB: here we assume scalar-like tuples have already been excluded
20792054
value = extract_array(value, extract_numpy=True)
@@ -2311,7 +2286,19 @@ def _concat_same_type(
23112286
) -> CategoricalT:
23122287
from pandas.core.dtypes.concat import union_categoricals
23132288

2314-
return union_categoricals(to_concat)
2289+
result = union_categoricals(to_concat)
2290+
2291+
# in case we are concatenating along axis != 0, we need to reshape
2292+
# the result from union_categoricals
2293+
first = to_concat[0]
2294+
if axis >= first.ndim:
2295+
raise ValueError
2296+
if axis == 1:
2297+
if not all(len(x) == len(first) for x in to_concat):
2298+
raise ValueError
2299+
# TODO: Will this get contiguity wrong?
2300+
result = result.reshape(-1, len(to_concat), order="F")
2301+
return result
23152302

23162303
# ------------------------------------------------------------------
23172304

@@ -2699,6 +2686,11 @@ def _get_codes_for_values(values, categories: Index) -> np.ndarray:
26992686
"""
27002687
dtype_equal = is_dtype_equal(values.dtype, categories.dtype)
27012688

2689+
if values.ndim > 1:
2690+
flat = values.ravel()
2691+
codes = _get_codes_for_values(flat, categories)
2692+
return codes.reshape(values.shape)
2693+
27022694
if isinstance(categories.dtype, ExtensionDtype) and is_object_dtype(values):
27032695
# Support inferring the correct extension dtype from an array of
27042696
# scalar objects. e.g.

0 commit comments

Comments
 (0)