Skip to content

REF: implement Categorical._validate_setitem_value #36180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from pandas._config import get_option

from pandas._libs import NaT, algos as libalgos, hashtable as htable
from pandas._libs import NaT, algos as libalgos, hashtable as htable, lib
from pandas._typing import ArrayLike, Dtype, Ordered, Scalar
from pandas.compat.numpy import function as nv
from pandas.util._decorators import cache_readonly, deprecate_kwarg, doc
Expand Down Expand Up @@ -1868,14 +1868,6 @@ def __repr__(self) -> str:

# ------------------------------------------------------------------

def _maybe_coerce_indexer(self, indexer):
    """
    Return an indexer coerced to the codes dtype.

    Parameters
    ----------
    indexer : object
        Candidate indexer. Only a ``np.ndarray`` whose dtype kind is
        ``"i"`` (signed integer) is converted.

    Returns
    -------
    object
        A copy of ``indexer`` cast to ``self._codes.dtype`` when it is a
        signed-integer ndarray; otherwise ``indexer`` itself, unchanged.
    """
    # Anything that is not a signed-integer ndarray (lists, scalars,
    # float/bool/unsigned arrays) is passed through untouched.
    if isinstance(indexer, np.ndarray) and indexer.dtype.kind == "i":
        indexer = indexer.astype(self._codes.dtype)
    return indexer

def __getitem__(self, key):
"""
Return an item.
Expand Down Expand Up @@ -1905,6 +1897,11 @@ def __setitem__(self, key, value):
If (one or more) Value is not in categories or if an assigned
`Categorical` does not have the same categories
"""
key = self._validate_setitem_key(key)
value = self._validate_setitem_value(value)
self._ndarray[key] = value

def _validate_setitem_value(self, value):
value = extract_array(value, extract_numpy=True)

# require identical categories set
Expand Down Expand Up @@ -1934,12 +1931,19 @@ def __setitem__(self, key, value):
"category, set the categories first"
)

# set by position
if isinstance(key, (int, np.integer)):
lindexer = self.categories.get_indexer(rvalue)
if isinstance(lindexer, np.ndarray) and lindexer.dtype.kind == "i":
lindexer = lindexer.astype(self._ndarray.dtype)

return lindexer

def _validate_setitem_key(self, key):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add type annotations to this method?

if lib.is_integer(key):
# set by position
pass

# tuple of indexers (dataframe)
elif isinstance(key, tuple):
# tuple of indexers (dataframe)
# only allow 1 dimensional slicing, but can
# in a 2-d case be passed (slice(None),....)
if len(key) == 2:
Expand All @@ -1951,17 +1955,14 @@ def __setitem__(self, key, value):
else:
raise AssertionError("invalid slicing for a 1-ndim categorical")

# slicing in Series or Categorical
elif isinstance(key, slice):
# slicing in Series or Categorical
pass

# else: array of True/False in Series or Categorical

lindexer = self.categories.get_indexer(rvalue)
lindexer = self._maybe_coerce_indexer(lindexer)

key = check_array_indexer(self, key)
self._codes[key] = lindexer
return key

def _reverse_indexer(self) -> Dict[Hashable, np.ndarray]:
"""
Expand Down
16 changes: 10 additions & 6 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,15 @@ def __getitem__(self, key):
return self._box_func(result)
return self._simple_new(result, dtype=self.dtype)

key = self._validate_getitem_key(key)
result = self._ndarray[key]
if lib.is_scalar(result):
return self._box_func(result)

freq = self._get_getitem_freq(key)
return self._simple_new(result, dtype=self.dtype, freq=freq)

def _validate_getitem_key(self, key):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here: can you add type annotations to this method?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a docstring, if that is easy to do.

if com.is_bool_indexer(key):
# first convert to boolean, because check_array_indexer doesn't
# allow object dtype
Expand All @@ -560,12 +569,7 @@ def __getitem__(self, key):
pass
else:
key = check_array_indexer(self, key)

freq = self._get_getitem_freq(key)
result = self._ndarray[key]
if lib.is_scalar(result):
return self._box_func(result)
return self._simple_new(result, dtype=self.dtype, freq=freq)
return key

def _get_getitem_freq(self, key):
"""
Expand Down