Skip to content

Commit a975566

Browse files
h-vetinari authored and Pingviinituutti committed
API/ERR: allow iterators in df.set_index & improve errors (pandas-dev#24984)
1 parent 233ca55 commit a975566

File tree

4 files changed

+79
-11
lines changed

4 files changed

+79
-11
lines changed

doc/source/whatsnew/v0.25.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Other Enhancements
2222
- Indexing of ``DataFrame`` and ``Series`` now accepts zerodim ``np.ndarray`` (:issue:`24919`)
2323
- :meth:`Timestamp.replace` now supports the ``fold`` argument to disambiguate DST transition times (:issue:`25017`)
2424
- :meth:`DataFrame.at_time` and :meth:`Series.at_time` now support :meth:`datetime.time` objects with timezones (:issue:`24043`)
25+
- :meth:`DataFrame.set_index` now works for instances of ``abc.Iterator``, provided their output is of the same length as the calling frame (:issue:`22484`, :issue:`24984`)
2526
- :meth:`DatetimeIndex.union` now supports the ``sort`` argument. The behaviour of the sort parameter matches that of :meth:`Index.union` (:issue:`24994`)
2627
-
2728

pandas/compat/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def lfilter(*args, **kwargs):
137137
reload = reload
138138
Hashable = collections.abc.Hashable
139139
Iterable = collections.abc.Iterable
140+
Iterator = collections.abc.Iterator
140141
Mapping = collections.abc.Mapping
141142
MutableMapping = collections.abc.MutableMapping
142143
Sequence = collections.abc.Sequence
@@ -199,6 +200,7 @@ def get_range_parameters(data):
199200

200201
Hashable = collections.Hashable
201202
Iterable = collections.Iterable
203+
Iterator = collections.Iterator
202204
Mapping = collections.Mapping
203205
MutableMapping = collections.MutableMapping
204206
Sequence = collections.Sequence

pandas/core/frame.py

+41-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
from pandas import compat
3535
from pandas.compat import (range, map, zip, lmap, lzip, StringIO, u,
36-
PY36, raise_with_traceback,
36+
PY36, raise_with_traceback, Iterator,
3737
string_and_binary_types)
3838
from pandas.compat.numpy import function as nv
3939
from pandas.core.dtypes.cast import (
@@ -4025,7 +4025,8 @@ def set_index(self, keys, drop=True, append=False, inplace=False,
40254025
This parameter can be either a single column key, a single array of
40264026
the same length as the calling DataFrame, or a list containing an
40274027
arbitrary combination of column keys and arrays. Here, "array"
4028-
encompasses :class:`Series`, :class:`Index` and ``np.ndarray``.
4028+
encompasses :class:`Series`, :class:`Index`, ``np.ndarray``, and
4029+
instances of :class:`abc.Iterator`.
40294030
drop : bool, default True
40304031
Delete columns to be used as the new index.
40314032
append : bool, default False
@@ -4104,6 +4105,32 @@ def set_index(self, keys, drop=True, append=False, inplace=False,
41044105
if not isinstance(keys, list):
41054106
keys = [keys]
41064107

4108+
err_msg = ('The parameter "keys" may be a column key, one-dimensional '
4109+
'array, or a list containing only valid column keys and '
4110+
'one-dimensional arrays.')
4111+
4112+
missing = []
4113+
for col in keys:
4114+
if isinstance(col, (ABCIndexClass, ABCSeries, np.ndarray,
4115+
list, Iterator)):
4116+
# arrays are fine as long as they are one-dimensional
4117+
# iterators get converted to list below
4118+
if getattr(col, 'ndim', 1) != 1:
4119+
raise ValueError(err_msg)
4120+
else:
4121+
# everything else gets tried as a key; see GH 24969
4122+
try:
4123+
found = col in self.columns
4124+
except TypeError:
4125+
raise TypeError(err_msg + ' Received column of '
4126+
'type {}'.format(type(col)))
4127+
else:
4128+
if not found:
4129+
missing.append(col)
4130+
4131+
if missing:
4132+
raise KeyError('None of {} are in the columns'.format(missing))
4133+
41074134
if inplace:
41084135
frame = self
41094136
else:
@@ -4132,13 +4159,25 @@ def set_index(self, keys, drop=True, append=False, inplace=False,
41324159
elif isinstance(col, (list, np.ndarray)):
41334160
arrays.append(col)
41344161
names.append(None)
4162+
elif isinstance(col, Iterator):
4163+
arrays.append(list(col))
4164+
names.append(None)
41354165
# from here, col can only be a column label
41364166
else:
41374167
arrays.append(frame[col]._values)
41384168
names.append(col)
41394169
if drop:
41404170
to_remove.append(col)
41414171

4172+
if len(arrays[-1]) != len(self):
4173+
# check newest element against length of calling frame, since
4174+
# ensure_index_from_sequences would not raise for append=False.
4175+
raise ValueError('Length mismatch: Expected {len_self} rows, '
4176+
'received array of length {len_col}'.format(
4177+
len_self=len(self),
4178+
len_col=len(arrays[-1])
4179+
))
4180+
41424181
index = ensure_index_from_sequences(arrays, names)
41434182

41444183
if verify_integrity and not index.is_unique:

pandas/tests/frame/test_alter_axes.py

+35-9
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,10 @@ def test_set_index_pass_arrays(self, frame_of_index_cols,
178178
# MultiIndex constructor does not work directly on Series -> lambda
179179
# We also emulate a "constructor" for the label -> lambda
180180
# also test index name if append=True (name is duplicate here for A)
181-
@pytest.mark.parametrize('box2', [Series, Index, np.array, list,
181+
@pytest.mark.parametrize('box2', [Series, Index, np.array, list, iter,
182182
lambda x: MultiIndex.from_arrays([x]),
183183
lambda x: x.name])
184-
@pytest.mark.parametrize('box1', [Series, Index, np.array, list,
184+
@pytest.mark.parametrize('box1', [Series, Index, np.array, list, iter,
185185
lambda x: MultiIndex.from_arrays([x]),
186186
lambda x: x.name])
187187
@pytest.mark.parametrize('append, index_name', [(True, None),
@@ -195,6 +195,9 @@ def test_set_index_pass_arrays_duplicate(self, frame_of_index_cols, drop,
195195
keys = [box1(df['A']), box2(df['A'])]
196196
result = df.set_index(keys, drop=drop, append=append)
197197

198+
# if either box is iter, it has been consumed; re-read
199+
keys = [box1(df['A']), box2(df['A'])]
200+
198201
# need to adapt first drop for case that both keys are 'A' --
199202
# cannot drop the same column twice;
200203
# use "is" because == would give ambiguous Boolean error for containers
@@ -253,25 +256,48 @@ def test_set_index_raise_keys(self, frame_of_index_cols, drop, append):
253256
df.set_index(['A', df['A'], tuple(df['A'])],
254257
drop=drop, append=append)
255258

256-
@pytest.mark.xfail(reason='broken due to revert, see GH 25085')
257259
@pytest.mark.parametrize('append', [True, False])
258260
@pytest.mark.parametrize('drop', [True, False])
259-
@pytest.mark.parametrize('box', [set, iter, lambda x: (y for y in x)],
260-
ids=['set', 'iter', 'generator'])
261+
@pytest.mark.parametrize('box', [set], ids=['set'])
261262
def test_set_index_raise_on_type(self, frame_of_index_cols, box,
262263
drop, append):
263264
df = frame_of_index_cols
264265

265266
msg = 'The parameter "keys" may be a column key, .*'
266-
# forbidden type, e.g. set/iter/generator
267+
# forbidden type, e.g. set
267268
with pytest.raises(TypeError, match=msg):
268269
df.set_index(box(df['A']), drop=drop, append=append)
269270

270-
# forbidden type in list, e.g. set/iter/generator
271+
# forbidden type in list, e.g. set
271272
with pytest.raises(TypeError, match=msg):
272273
df.set_index(['A', df['A'], box(df['A'])],
273274
drop=drop, append=append)
274275

276+
# MultiIndex constructor does not work directly on Series -> lambda
277+
@pytest.mark.parametrize('box', [Series, Index, np.array, iter,
278+
lambda x: MultiIndex.from_arrays([x])],
279+
ids=['Series', 'Index', 'np.array',
280+
'iter', 'MultiIndex'])
281+
@pytest.mark.parametrize('length', [4, 6], ids=['too_short', 'too_long'])
282+
@pytest.mark.parametrize('append', [True, False])
283+
@pytest.mark.parametrize('drop', [True, False])
284+
def test_set_index_raise_on_len(self, frame_of_index_cols, box, length,
285+
drop, append):
286+
# GH 24984
287+
df = frame_of_index_cols # has length 5
288+
289+
values = np.random.randint(0, 10, (length,))
290+
291+
msg = 'Length mismatch: Expected 5 rows, received array of length.*'
292+
293+
# wrong length directly
294+
with pytest.raises(ValueError, match=msg):
295+
df.set_index(box(values), drop=drop, append=append)
296+
297+
# wrong length in list
298+
with pytest.raises(ValueError, match=msg):
299+
df.set_index(['A', df.A, box(values)], drop=drop, append=append)
300+
275301
def test_set_index_custom_label_type(self):
276302
# GH 24969
277303

@@ -341,7 +367,7 @@ def __repr__(self):
341367

342368
# missing key
343369
thing3 = Thing(['Three', 'pink'])
344-
msg = '.*' # due to revert, see GH 25085
370+
msg = r"frozenset\(\{'Three', 'pink'\}\)"
345371
with pytest.raises(KeyError, match=msg):
346372
# missing label directly
347373
df.set_index(thing3)
@@ -366,7 +392,7 @@ def __str__(self):
366392
thing2 = Thing('Two', 'blue')
367393
df = DataFrame([[0, 2], [1, 3]], columns=[thing1, thing2])
368394

369-
msg = 'unhashable type.*'
395+
msg = 'The parameter "keys" may be a column key, .*'
370396

371397
with pytest.raises(TypeError, match=msg):
372398
# use custom label directly

0 commit comments

Comments
 (0)