Skip to content

Commit cf6c831

Browse files
committed
Allow index.map() to accept series and dictionary inputs in addition to functional inputs
1 parent 2fa33fb commit cf6c831

File tree

3 files changed

+62
-3
lines changed

3 files changed

+62
-3
lines changed

pandas/indexes/base.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from pandas.compat.numpy import function as nv
1515
from pandas import compat
1616

17-
1817
from pandas.types.generic import ABCSeries, ABCMultiIndex, ABCPeriodIndex
1918
from pandas.types.missing import isnull, array_equivalent
2019
from pandas.types.common import (_ensure_int64,
@@ -2492,7 +2491,7 @@ def map(self, mapper):
24922491
24932492
Parameters
24942493
----------
2495-
mapper : callable
2494+
mapper : function, dict, or Series
24962495
Function to be applied.
24972496
24982497
Returns
@@ -2504,7 +2503,15 @@ def map(self, mapper):
25042503
25052504
"""
25062505
from .multi import MultiIndex
2507-
mapped_values = self._arrmap(self.values, mapper)
2506+
2507+
if isinstance(mapper, ABCSeries):
2508+
indexer = mapper.index.get_indexer(self._values)
2509+
mapped_values = algos.take_1d(mapper.values, indexer)
2510+
else:
2511+
if isinstance(mapper, dict):
2512+
mapper = mapper.get
2513+
mapped_values = self._arrmap(self._values, mapper)
2514+
25082515
attributes = self._get_attributes_dict()
25092516
if mapped_values.size and isinstance(mapped_values[0], tuple):
25102517
return MultiIndex.from_tuples(mapped_values,

pandas/tests/indexes/test_base.py

+46
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,52 @@ def test_map_tseries_indices_return_index(self):
811811
exp = Index(range(24), name='hourly')
812812
tm.assert_index_equal(exp, date_index.map(lambda x: x.hour))
813813

814+
def test_map_with_series_all_indices(self):
815+
expected = Index(['foo', 'bar', 'baz'])
816+
mapper = Series(expected.values, index=[0, 1, 2])
817+
self.assert_index_equal(tm.makeIntIndex(3).map(mapper), expected)
818+
819+
# GH 12766
820+
# special = []
821+
special = ['catIndex']
822+
823+
for name in special:
824+
orig_values = ['a', 'B', 1, 'a']
825+
new_values = ['one', 2, 3.0, 'one']
826+
cur_index = CategoricalIndex(orig_values, name='XXX')
827+
mapper = pd.Series(new_values[:-1], index=orig_values[:-1])
828+
expected = CategoricalIndex(new_values, name='XXX')
829+
output = cur_index.map(mapper)
830+
self.assert_numpy_array_equal(expected.values.get_values(), output.values.get_values())
831+
self.assert_equal(expected.name, output.name)
832+
833+
834+
for name in list(set(self.indices.keys()) - set(special)):
835+
cur_index = self.indices[name]
836+
expected = Index(np.arange(len(cur_index), 0, -1))
837+
mapper = pd.Series(expected.values, index=cur_index)
838+
print(name)
839+
output = cur_index.map(mapper)
840+
self.assert_index_equal(expected, cur_index.map(mapper))
841+
842+
def test_map_with_categorical_series(self):
843+
# GH 12756
844+
a = Index([1, 2, 3, 4])
845+
b = Series(["even", "odd", "even", "odd"], dtype="category")
846+
c = Series(["even", "odd", "even", "odd"])
847+
848+
exp = CategoricalIndex(["odd", "even", "odd", np.nan])
849+
self.assert_index_equal(a.map(b), exp)
850+
exp = Index(["odd", "even", "odd", np.nan])
851+
self.assert_index_equal(a.map(c), exp)
852+
853+
def test_map_with_series_missing_values(self):
854+
# GH 12756
855+
expected = Index([2., np.nan, 'foo'])
856+
mapper = Series(['foo', 2., 'baz'], index=[0, 2, -1])
857+
output = Index([2, 1, 0]).map(mapper)
858+
self.assert_index_equal(output, expected)
859+
814860
def test_append_multiple(self):
815861
index = Index(['a', 'b', 'c', 'd', 'e', 'f'])
816862

pandas/tests/indexes/test_category.py

+6
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,12 @@ def f(x):
233233
ordered=False)
234234
tm.assert_index_equal(result, exp)
235235

236+
result = ci.map(pd.Series([10, 20, 30], index=['A', 'B', 'C']))
237+
tm.assert_index_equal(result, exp)
238+
239+
result = ci.map({'A': 10, 'B': 20, 'C': 30})
240+
tm.assert_index_equal(result, exp)
241+
236242
def test_where(self):
237243
i = self.create_index()
238244
result = i.where(notnull(i))

0 commit comments

Comments
 (0)