Skip to content

Commit 405b5c5

Browse files
authored
REF: implement Categorical._validate_setitem_value (#36180)
1 parent c1d7bbd commit 405b5c5

File tree

2 files changed

+28
-23
lines changed

2 files changed

+28
-23
lines changed

pandas/core/arrays/categorical.py

+18 -17
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99

1010
from pandas._config import get_option
1111

12-
from pandas._libs import NaT, algos as libalgos, hashtable as htable
12+
from pandas._libs import NaT, algos as libalgos, hashtable as htable, lib
1313
from pandas._typing import ArrayLike, Dtype, Ordered, Scalar
1414
from pandas.compat.numpy import function as nv
1515
from pandas.util._decorators import cache_readonly, deprecate_kwarg, doc
@@ -1868,14 +1868,6 @@ def __repr__(self) -> str:
18681868

18691869
# ------------------------------------------------------------------
18701870

1871-
def _maybe_coerce_indexer(self, indexer):
1872-
"""
1873-
return an indexer coerced to the codes dtype
1874-
"""
1875-
if isinstance(indexer, np.ndarray) and indexer.dtype.kind == "i":
1876-
indexer = indexer.astype(self._codes.dtype)
1877-
return indexer
1878-
18791871
def __getitem__(self, key):
18801872
"""
18811873
Return an item.
@@ -1905,6 +1897,11 @@ def __setitem__(self, key, value):
19051897
If (one or more) Value is not in categories or if a assigned
19061898
`Categorical` does not have the same categories
19071899
"""
1900+
key = self._validate_setitem_key(key)
1901+
value = self._validate_setitem_value(value)
1902+
self._ndarray[key] = value
1903+
1904+
def _validate_setitem_value(self, value):
19081905
value = extract_array(value, extract_numpy=True)
19091906

19101907
# require identical categories set
@@ -1934,12 +1931,19 @@ def __setitem__(self, key, value):
19341931
"category, set the categories first"
19351932
)
19361933

1937-
# set by position
1938-
if isinstance(key, (int, np.integer)):
1934+
lindexer = self.categories.get_indexer(rvalue)
1935+
if isinstance(lindexer, np.ndarray) and lindexer.dtype.kind == "i":
1936+
lindexer = lindexer.astype(self._ndarray.dtype)
1937+
1938+
return lindexer
1939+
1940+
def _validate_setitem_key(self, key):
1941+
if lib.is_integer(key):
1942+
# set by position
19391943
pass
19401944

1941-
# tuple of indexers (dataframe)
19421945
elif isinstance(key, tuple):
1946+
# tuple of indexers (dataframe)
19431947
# only allow 1 dimensional slicing, but can
19441948
# in a 2-d case be passed (slice(None),....)
19451949
if len(key) == 2:
@@ -1951,17 +1955,14 @@ def __setitem__(self, key, value):
19511955
else:
19521956
raise AssertionError("invalid slicing for a 1-ndim categorical")
19531957

1954-
# slicing in Series or Categorical
19551958
elif isinstance(key, slice):
1959+
# slicing in Series or Categorical
19561960
pass
19571961

19581962
# else: array of True/False in Series or Categorical
19591963

1960-
lindexer = self.categories.get_indexer(rvalue)
1961-
lindexer = self._maybe_coerce_indexer(lindexer)
1962-
19631964
key = check_array_indexer(self, key)
1964-
self._codes[key] = lindexer
1965+
return key
19651966

19661967
def _reverse_indexer(self) -> Dict[Hashable, np.ndarray]:
19671968
"""

pandas/core/arrays/datetimelike.py

+10 -6
Original file line number | Diff line number | Diff line change
@@ -546,6 +546,15 @@ def __getitem__(self, key):
546546
return self._box_func(result)
547547
return self._simple_new(result, dtype=self.dtype)
548548

549+
key = self._validate_getitem_key(key)
550+
result = self._ndarray[key]
551+
if lib.is_scalar(result):
552+
return self._box_func(result)
553+
554+
freq = self._get_getitem_freq(key)
555+
return self._simple_new(result, dtype=self.dtype, freq=freq)
556+
557+
def _validate_getitem_key(self, key):
549558
if com.is_bool_indexer(key):
550559
# first convert to boolean, because check_array_indexer doesn't
551560
# allow object dtype
@@ -560,12 +569,7 @@ def __getitem__(self, key):
560569
pass
561570
else:
562571
key = check_array_indexer(self, key)
563-
564-
freq = self._get_getitem_freq(key)
565-
result = self._ndarray[key]
566-
if lib.is_scalar(result):
567-
return self._box_func(result)
568-
return self._simple_new(result, dtype=self.dtype, freq=freq)
572+
return key
569573

570574
def _get_getitem_freq(self, key):
571575
"""

0 commit comments

Comments
 (0)