Skip to content

Commit 4f9bc8a

Browse files
authored
BUG: is_scalar_indexer (#32850)
* BUG: is_scalar_indexer * update docstring * add copy=False
1 parent 3b406a3 commit 4f9bc8a

File tree

3 files changed

+34
-11
lines changed

3 files changed

+34
-11
lines changed

pandas/core/indexers.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -65,18 +65,26 @@ def is_list_like_indexer(key) -> bool:
6565
return is_list_like(key) and not (isinstance(key, tuple) and type(key) is not tuple)
6666

6767

68-
def is_scalar_indexer(indexer, arr_value) -> bool:
68+
def is_scalar_indexer(indexer, ndim: int) -> bool:
6969
"""
7070
Return True if we are all scalar indexers.
7171
72+
Parameters
73+
----------
74+
indexer : object
75+
ndim : int
76+
Number of dimensions in the object being indexed.
77+
7278
Returns
7379
-------
7480
bool
7581
"""
76-
if arr_value.ndim == 1:
77-
if not isinstance(indexer, tuple):
78-
indexer = tuple([indexer])
79-
return any(isinstance(idx, np.ndarray) and len(idx) == 0 for idx in indexer)
82+
if isinstance(indexer, tuple):
83+
if len(indexer) == ndim:
84+
return all(
85+
is_integer(x) or (isinstance(x, np.ndarray) and x.ndim == len(x) == 1)
86+
for x in indexer
87+
)
8088
return False
8189

8290

pandas/core/internals/blocks.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -874,7 +874,7 @@ def setitem(self, indexer, value):
874874
# GH#8669 empty indexers
875875
pass
876876

877-
elif is_scalar_indexer(indexer, arr_value):
877+
elif is_scalar_indexer(indexer, self.ndim):
878878
# setting a single element for each dim and with a rhs that could
879879
# be e.g. a list; see GH#6043
880880
values[indexer] = value
@@ -892,12 +892,10 @@ def setitem(self, indexer, value):
892892
# if we are an exact match (ex-broadcasting),
893893
# then use the resultant dtype
894894
elif exact_match:
895+
# We are setting _all_ of the array's values, so can cast to new dtype
895896
values[indexer] = value
896897

897-
try:
898-
values = values.astype(arr_value.dtype)
899-
except ValueError:
900-
pass
898+
values = values.astype(arr_value.dtype, copy=False)
901899

902900
# set
903901
else:
+18-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,28 @@
11
# Tests aimed at pandas.core.indexers
22
import numpy as np
33

4-
from pandas.core.indexers import length_of_indexer
4+
from pandas.core.indexers import is_scalar_indexer, length_of_indexer
55

66

77
def test_length_of_indexer():
88
arr = np.zeros(4, dtype=bool)
99
arr[0] = 1
1010
result = length_of_indexer(arr)
1111
assert result == 1
12+
13+
14+
def test_is_scalar_indexer():
15+
indexer = (0, 1)
16+
assert is_scalar_indexer(indexer, 2)
17+
assert not is_scalar_indexer(indexer[0], 2)
18+
19+
indexer = (np.array([2]), 1)
20+
assert is_scalar_indexer(indexer, 2)
21+
22+
indexer = (np.array([2]), np.array([3]))
23+
assert is_scalar_indexer(indexer, 2)
24+
25+
indexer = (np.array([2]), np.array([3, 4]))
26+
assert not is_scalar_indexer(indexer, 2)
27+
28+
assert not is_scalar_indexer(slice(None), 1)

0 commit comments

Comments
 (0)