Skip to content

Commit 5122a43

Browse files
committed
Clean-up & parametrise tests for df.set_index; better warnings
1 parent 904c4ae commit 5122a43

File tree

7 files changed

+629
-195
lines changed

7 files changed

+629
-195
lines changed

pandas/core/frame.py

+16
Original file line number | Diff line number | Diff line change
@@ -3973,6 +3973,22 @@ def set_index(self, keys, drop=True, append=False, inplace=False,
39733973
3 2013 7 84
39743974
4 2014 10 31
39753975
"""
3976+
from pandas import Series
3977+
3978+
if not isinstance(keys, list):
3979+
keys = [keys]
3980+
3981+
col_labels = [x for x in keys
3982+
if not isinstance(x, (Series, Index, MultiIndex,
3983+
list, np.ndarray))]
3984+
if any(x not in self for x in col_labels):
3985+
missing = [x for x in col_labels if x not in self]
3986+
raise KeyError('{}'.format(missing))
3987+
elif len(set(col_labels)) < len(col_labels):
3988+
dup = Series(col_labels)
3989+
dup = list(dup.loc[dup.duplicated()])
3990+
raise ValueError('Passed duplicate column names '
3991+
'to keys: {dup}'.format(dup=dup))
39763992
vi = verify_integrity
39773993
return super(DataFrame, self).set_index(keys=keys, drop=drop,
39783994
append=append, inplace=inplace,

pandas/core/generic.py

+1 -6
Original file line number | Diff line number | Diff line change
@@ -737,9 +737,6 @@ def set_index(self, keys, drop=True, append=False, inplace=False,
737737
from pandas import Series
738738

739739
inplace = validate_bool_kwarg(inplace, 'inplace')
740-
if not isinstance(keys, list):
741-
keys = [keys]
742-
743740
if inplace:
744741
obj = self
745742
else:
@@ -774,9 +771,7 @@ def set_index(self, keys, drop=True, append=False, inplace=False,
774771
elif isinstance(col, (list, np.ndarray)):
775772
level = col
776773
names.append(None)
777-
elif isinstance(obj, Series):
778-
# col may not be a column label for Series case
779-
raise ValueError('asdf')
774+
# from here on, col must be a column label
780775
else:
781776
level = obj[col]._values
782777
names.append(col)

pandas/core/series.py

+10
Original file line number | Diff line number | Diff line change
@@ -1120,6 +1120,16 @@ def set_index(self, arrays, append=False, inplace=False,
11201120
--------
11211121
>>> ...
11221122
"""
1123+
1124+
if (not isinstance(keys, (Series, Index, MultiIndex, list, np.ndarray))
1125+
or (isinstance(keys, list)
1126+
and any(not isinstance(x, (Series, Index, MultiIndex,
1127+
list, np.ndarray))
1128+
for x in keys))):
1129+
raise ValueError('arrays must be Series, Index, MultiIndex, '
1130+
'np.ndarray or list containing containing only'
1131+
'Series, Index, MultiIndex, list, np.ndarray')
1132+
11231133
return super(Series, self).set_index(keys=arrays, drop=False,
11241134
append=append, inplace=inplace,
11251135
verify_integrity=verify_integrity)

pandas/tests/frame/common.py

+9
Original file line number | Diff line number | Diff line change
@@ -103,6 +103,15 @@ def simple(self):
103103
return pd.DataFrame(arr, columns=['one', 'two', 'three'],
104104
index=['a', 'b', 'c'])
105105

106+
# Test-fixture accessor added by this commit: returns a five-row DataFrame.
# Columns 'A', 'B' and 'C' hold string labels; 'D' and 'E' each hold five
# values drawn with np.random.randn(5).
# NOTE(review): no seed is set in the visible code, so 'D'/'E' presumably
# differ between runs -- tests should not depend on their exact values.
# NOTE(review): cache_readonly presumably caches the frame per instance, so
# callers would share one object -- confirm before mutating it in place.
# NOTE(review): the bare "NNN+" lines below are diff-gutter line numbers
# from the page rendering, not part of the method body.
@cache_readonly
107+
def dummy(self):
108+
df = pd.DataFrame({'A': ['foo', 'foo', 'foo', 'bar', 'bar'],
109+
'B': ['one', 'two', 'three', 'one', 'two'],
110+
'C': ['a', 'b', 'c', 'd', 'e'],
111+
# 'D' and 'E' are independent draws; each call to randn(5) yields 5 floats.
'D': np.random.randn(5),
112+
'E': np.random.randn(5)})
113+
return df
114+
106115
# self.ts3 = tm.makeTimeSeries()[-5:]
107116
# self.ts4 = tm.makeTimeSeries()[1:-1]
108117

0 commit comments

Comments
 (0)