Skip to content

Commit c449afd

Browse files
committed
Bounds checking
1 parent 449983b commit c449afd

File tree

7 files changed

+132
-20
lines changed

7 files changed

+132
-20
lines changed

pandas/core/algorithms.py

+7-15
Original file line numberDiff line numberDiff line change
@@ -1488,24 +1488,16 @@ def take(arr, indexer, allow_fill=False, fill_value=None):
14881488
--------
14891489
numpy.take
14901490
"""
1491-
indexer = np.asarray(indexer)
1491+
from pandas.core.indexing import validate_indices
1492+
1493+
# Do we require int64 here?
1494+
indexer = np.asarray(indexer, dtype='int')
14921495

14931496
if allow_fill:
14941497
# Pandas style, -1 means NA
1495-
# bounds checking
1496-
if (indexer < -1).any():
1497-
raise ValueError("Invalid value in 'indexer'. All values "
1498-
"must be non-negative or -1. When "
1499-
"'fill_value' is specified.")
1500-
if (indexer > len(arr)).any():
1501-
# TODO: reuse with logic elsewhere.
1502-
raise IndexError
1503-
1504-
# # take on empty array not handled as desired by numpy
1505-
# # in case of -1 (all missing take)
1506-
# if not len(arr) and mask.all():
1507-
# return arr._from_sequence([fill_value] * len(indexer))
1508-
result = take_1d(arr, indexer, fill_value=fill_value)
1498+
# Use for bounds checking, we don't actually want to convert.
1499+
validate_indices(indexer, len(arr))
1500+
result = take_1d(arr, indexer, allow_fill=True, fill_value=fill_value)
15091501
else:
15101502
# NumPy style
15111503
result = arr.take(indexer)

pandas/core/indexing.py

+41
Original file line numberDiff line numberDiff line change
@@ -2417,12 +2417,53 @@ def maybe_convert_indices(indices, n):
24172417
mask = indices < 0
24182418
if mask.any():
24192419
indices[mask] += n
2420+
24202421
mask = (indices >= n) | (indices < 0)
24212422
if mask.any():
24222423
raise IndexError("indices are out-of-bounds")
24232424
return indices
24242425

24252426

2427+
def validate_indices(indices, n):
    """Perform bounds-checking for an indexer.

    -1 is allowed for indicating missing values.

    Parameters
    ----------
    indices : ndarray or list-like of int
        Positions to check. The input is viewed through ``np.asarray``
        for the check only; nothing is returned.
    n : int
        length of the array being indexed

    Raises
    ------
    ValueError
        If any value in `indices` is less than -1.
    IndexError
        If any value in `indices` is greater than or equal to `n`.

    Examples
    --------
    >>> validate_indices([1, 2], 3)
    # OK
    >>> validate_indices([1, -2], 3)
    ValueError
    >>> validate_indices([1, 2, 3], 3)
    IndexError
    >>> validate_indices([-1, -1], 0)
    # OK
    >>> validate_indices([0, 1], 0)
    IndexError
    """
    # Local import keeps this function self-contained; the surrounding
    # module's import block is not visible from here.
    import numpy as np

    # Plain lists (as used in the examples above) have no .min()/.max(),
    # so coerce before inspecting the values.
    indices = np.asarray(indices)

    # min()/max() raise on zero-size arrays, so an empty indexer is
    # accepted without further checks.
    if len(indices):
        min_idx = indices.min()
        if min_idx < -1:
            msg = ("'indices' contains values less than allowed ({} < {})"
                   .format(min_idx, -1))
            raise ValueError(msg)

        max_idx = indices.max()
        if max_idx >= n:
            raise IndexError("indices are out-of-bounds")
2465+
2466+
24262467
def maybe_convert_ix(*args):
24272468
"""
24282469
We likely want to take the cross-product

pandas/tests/extension/base/getitem.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,12 @@ def test_take(self, data, na_value, na_cmp):
134134

135135
def test_take_empty(self, data, na_value, na_cmp):
136136
empty = data[:0]
137-
# result = empty.take([-1])
138-
# na_cmp(result[0], na_value)
137+
138+
result = empty.take([-1], allow_fill=True)
139+
na_cmp(result[0], na_value)
140+
141+
with pytest.raises(IndexError):
142+
empty.take([-1])
139143

140144
with tm.assert_raises_regex(IndexError, "cannot do a non-empty take"):
141145
empty.take([0, 1])
@@ -160,6 +164,12 @@ def test_take_pandas_style_negative_raises(self, data, na_value):
160164
with pytest.raises(ValueError):
161165
data.take([0, -2], fill_value=na_value, allow_fill=True)
162166

167+
@pytest.mark.parametrize('allow_fill', [True, False])
def test_take_out_of_bounds_raises(self, data, allow_fill):
    """Taking position 3 from ``data[:3]`` raises for either fill mode."""
    head = data[:3]
    # Position 3 is past the end of a slice holding at most three items.
    positions = np.asarray([0, 3])
    with pytest.raises(IndexError):
        head.take(positions, allow_fill=allow_fill)
172+
163173
@pytest.mark.xfail(reason="Series.take with extension array buggy for -1")
164174
def test_take_series(self, data):
165175
s = pd.Series(data)

pandas/tests/extension/category/test_categorical.py

+4
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ def test_take_pandas_style_negative_raises(self):
107107
def test_take_non_na_fill_value(self):
108108
pass
109109

110+
@skip_take
111+
def test_take_out_of_bounds_raises(self):
112+
pass
113+
110114
@pytest.mark.xfail(reason="Categorical.take buggy")
111115
def test_take_empty(self):
112116
pass

pandas/tests/extension/json/array.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
class JSONDtype(ExtensionDtype):
1515
type = collections.Mapping
1616
name = 'json'
17-
na_value = collections.UserDict()
17+
try:
18+
na_value = collections.UserDict()
19+
except AttributeError:
20+
# source compatibility with Py2.
21+
na_value = {}
1822

1923
@classmethod
2024
def construct_from_string(cls, string):
@@ -112,7 +116,7 @@ def take(self, indexer, allow_fill=False, fill_value=None):
112116
output = [self.data[loc] if loc != -1 else fill_value
113117
for loc in indexer]
114118
except IndexError:
115-
raise msg
119+
raise IndexError(msg)
116120
else:
117121
try:
118122
output = [self.data[loc] for loc in indexer]

pandas/tests/indexing/test_indexing.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
import numpy as np
1717

1818
import pandas as pd
19-
from pandas.core.indexing import _non_reducing_slice, _maybe_numeric_slice
19+
from pandas.core.indexing import (_non_reducing_slice, _maybe_numeric_slice,
20+
validate_indices)
2021
from pandas import NaT, DataFrame, Index, Series, MultiIndex
2122
import pandas.util.testing as tm
2223

@@ -994,3 +995,27 @@ def test_none_coercion_mixed_dtypes(self):
994995
datetime(2000, 1, 3)],
995996
'd': [None, 'b', 'c']})
996997
tm.assert_frame_equal(start_dataframe, exp)
998+
999+
1000+
def test_validate_indices_ok():
    """In-bounds, empty, and all -1 indexers are accepted without error."""
    in_bounds = np.asarray([0, 1])
    accepted = [
        (in_bounds, 2),
        (in_bounds[:0], 0),
        (np.array([-1, -1]), 0),
    ]
    for indexer, length in accepted:
        validate_indices(indexer, length)
1005+
1006+
1007+
def test_validate_indices_low():
    """A value below -1 is reported as a ValueError."""
    too_low = np.asarray([0, -2])
    expected_message = "'indices' contains"
    with tm.assert_raises_regex(ValueError, expected_message):
        validate_indices(too_low, 2)
1011+
1012+
1013+
def test_validate_indices_high():
    """A value equal to the length is reported as an IndexError."""
    too_high = np.asarray([0, 1, 2])
    expected_message = "indices are out"
    with tm.assert_raises_regex(IndexError, expected_message):
        validate_indices(too_high, 2)
1017+
1018+
1019+
def test_validate_indices_empty():
    """Non-missing positions against a length of zero raise IndexError."""
    positions = np.array([0, 1])
    with tm.assert_raises_regex(IndexError, "indices are out"):
        validate_indices(positions, 0)

pandas/tests/test_algos.py

+36
Original file line numberDiff line numberDiff line change
@@ -1564,3 +1564,39 @@ def test_index(self):
15641564
idx = Index(['1 day', '1 day', '-1 day', '-1 day 2 min',
15651565
'2 min', '2 min'], dtype='timedelta64[ns]')
15661566
tm.assert_series_equal(algos.mode(idx), exp)
1567+
1568+
1569+
class TestTake(object):
    """Bounds-checking and empty-input cases for ``algos.take``."""

    def test_bounds_check_large(self):
        # Positions at or past the end raise in both fill modes.
        values = np.array([1, 2])
        for allow_fill in (True, False):
            with pytest.raises(IndexError):
                algos.take(values, [2, 3], allow_fill=allow_fill)

    def test_bounds_check_small(self):
        values = np.array([1, 2, 3], dtype=np.int64)
        positions = [0, -1, -2]

        # With allow_fill=True, -2 is rejected.
        with pytest.raises(ValueError):
            algos.take(values, positions, allow_fill=True)

        # Without allow_fill, negative positions count from the end.
        expected = np.array([1, 3, 2], dtype=np.int64)
        result = algos.take(values, positions)
        tm.assert_numpy_array_equal(result, expected)

    @pytest.mark.parametrize('allow_fill', [True, False])
    def test_take_empty(self, allow_fill):
        empty = np.array([], dtype=np.int64)

        # empty take is ok
        taken = algos.take(empty, [], allow_fill=allow_fill)
        tm.assert_numpy_array_equal(empty, taken)

        # A real position cannot be taken from an empty array.
        with pytest.raises(IndexError):
            algos.take(empty, [0], allow_fill=allow_fill)

    def test_take_na_empty(self):
        # All -1 positions on empty input produce only the fill value.
        expected = np.array([0, 0], dtype=np.int64)
        result = algos.take([], [-1, -1], allow_fill=True, fill_value=0)
        tm.assert_numpy_array_equal(result, expected)

0 commit comments

Comments
 (0)