diff --git a/pandas/core/indexing.py b/pandas/core/indexing.py index bbb3cb3391dfa..074c9ca7b483c 100644 --- a/pandas/core/indexing.py +++ b/pandas/core/indexing.py @@ -1239,8 +1239,7 @@ def _convert_to_indexer(self, key, axis: int): if com.is_bool_indexer(key): key = check_bool_indexer(labels, key) - (inds,) = key.nonzero() - return inds + return key else: return self._get_listlike_indexer(key, axis)[1] else: @@ -1804,6 +1803,7 @@ def _setitem_single_column(self, loc: int, value, plane_indexer): The indexer we use for setitem along axis=0. """ pi = plane_indexer + pi, value = mask_setitem_value(pi, value, len(self.obj)) ser = self.obj._ixs(loc, axis=1) @@ -2395,3 +2395,27 @@ def need_slice(obj: slice) -> bool: or obj.stop is not None or (obj.step is not None and obj.step != 1) ) + + +def mask_setitem_value(indexer, value, length: int): + """ + Convert a boolean indexer to a positional indexer, masking `value` if necessary. + """ + if com.is_bool_indexer(indexer): + indexer = np.asarray(indexer).nonzero()[0] + if is_list_like(value) and len(value) == length: + if not is_array_like(value): + value = [value[n] for n in indexer] + else: + value = value[indexer] + + elif isinstance(indexer, tuple): + indexer = list(indexer) + for i, key in enumerate(indexer): + if com.is_bool_indexer(key): + new_key = np.asarray(key).nonzero()[0] + indexer[i] = new_key + # TODO: sometimes need to do take on the value? + + indexer = tuple(indexer) + return indexer, value diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index da7ffbf08c34b..1d6639b062e1d 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -112,6 +112,7 @@ is_exact_shape_match, is_scalar_indexer, ) +from pandas.core.indexing import mask_setitem_value import pandas.core.missing as missing if TYPE_CHECKING: @@ -936,6 +937,8 @@ def setitem(self, indexer, value): if transpose: values = values.T + indexer, value = mask_setitem_value(indexer, value, len(values)) + # length checking check_setitem_lengths(indexer, value, values) exact_match = is_exact_shape_match(values, arr_value)