Skip to content

Commit bb50531

Browse files
authored
ENH: improved dtype inference for Index.map (#44609)
1 parent 0bf83d6 commit bb50531

File tree

7 files changed

+48
-48
lines changed

7 files changed

+48
-48
lines changed

doc/source/whatsnew/v1.4.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ Other enhancements
226226
``USFederalHolidayCalendar``. See also `Other API changes`_.
227227
- :meth:`.Rolling.var`, :meth:`.Expanding.var`, :meth:`.Rolling.std`, :meth:`.Expanding.std` now support `Numba <http://numba.pydata.org/>`_ execution with the ``engine`` keyword (:issue:`44461`)
228228
- :meth:`Series.info` has been added, for compatibility with :meth:`DataFrame.info` (:issue:`5167`)
229+
- :meth:`UInt64Index.map` now retains ``dtype`` where possible (:issue:`44609`)
230+
-
229231

230232

231233
.. ---------------------------------------------------------------------------

pandas/core/indexes/base.py

+10
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
can_hold_element,
6969
find_common_type,
7070
infer_dtype_from,
71+
maybe_cast_pointwise_result,
7172
validate_numeric_casting,
7273
)
7374
from pandas.core.dtypes.common import (
@@ -5977,6 +5978,15 @@ def map(self, mapper, na_action=None):
59775978
# empty
59785979
dtype = self.dtype
59795980

5981+
# e.g. if we are floating and new_values is all ints, then we
5982+
# don't want to cast back to floating. But if we are UInt64
5983+
# and new_values is all ints, we want to try.
5984+
same_dtype = lib.infer_dtype(new_values, skipna=False) == self.inferred_type
5985+
if same_dtype:
5986+
new_values = maybe_cast_pointwise_result(
5987+
new_values, self.dtype, same_dtype=same_dtype
5988+
)
5989+
59805990
if self._is_backward_compat_public_numeric_index and is_numeric_dtype(
59815991
new_values.dtype
59825992
):

pandas/tests/indexes/common.py

+8-22
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010

1111
from pandas.core.dtypes.common import (
1212
is_datetime64tz_dtype,
13-
is_float_dtype,
1413
is_integer_dtype,
15-
is_unsigned_integer_dtype,
1614
)
1715
from pandas.core.dtypes.dtypes import CategoricalDtype
1816

@@ -557,20 +555,9 @@ def test_map(self, simple_index):
557555
# callable
558556
idx = simple_index
559557

560-
# we don't infer UInt64
561-
if is_integer_dtype(idx.dtype):
562-
expected = idx.astype("int64")
563-
elif is_float_dtype(idx.dtype):
564-
expected = idx.astype("float64")
565-
if idx._is_backward_compat_public_numeric_index:
566-
# We get a NumericIndex back, not Float64Index
567-
expected = type(idx)(expected)
568-
else:
569-
expected = idx
570-
571558
result = idx.map(lambda x: x)
572559
# For RangeIndex we convert to Int64Index
573-
tm.assert_index_equal(result, expected, exact="equiv")
560+
tm.assert_index_equal(result, idx, exact="equiv")
574561

575562
@pytest.mark.parametrize(
576563
"mapper",
@@ -583,27 +570,26 @@ def test_map_dictlike(self, mapper, simple_index):
583570

584571
idx = simple_index
585572
if isinstance(idx, CategoricalIndex):
573+
# TODO(2.0): see if we can avoid skipping once
574+
# CategoricalIndex.reindex is removed.
586575
pytest.skip(f"skipping tests for {type(idx)}")
587576

588577
identity = mapper(idx.values, idx)
589578

590-
# we don't infer to UInt64 for a dict
591-
if is_unsigned_integer_dtype(idx.dtype) and isinstance(identity, dict):
592-
expected = idx.astype("int64")
593-
else:
594-
expected = idx
595-
596579
result = idx.map(identity)
597580
# For RangeIndex we convert to Int64Index
598-
tm.assert_index_equal(result, expected, exact="equiv")
581+
tm.assert_index_equal(result, idx, exact="equiv")
599582

600583
# empty mappable
584+
dtype = None
601585
if idx._is_backward_compat_public_numeric_index:
602586
new_index_cls = NumericIndex
587+
if idx.dtype.kind == "f":
588+
dtype = idx.dtype
603589
else:
604590
new_index_cls = Float64Index
605591

606-
expected = new_index_cls([np.nan] * len(idx))
592+
expected = new_index_cls([np.nan] * len(idx), dtype=dtype)
607593
result = idx.map(mapper(expected, idx))
608594
tm.assert_index_equal(result, expected)
609595

pandas/tests/indexes/multi/test_analytics.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -174,14 +174,8 @@ def test_map(idx):
174174
# callable
175175
index = idx
176176

177-
# we don't infer UInt64
178-
if isinstance(index, UInt64Index):
179-
expected = index.astype("int64")
180-
else:
181-
expected = index
182-
183177
result = index.map(lambda x: x)
184-
tm.assert_index_equal(result, expected)
178+
tm.assert_index_equal(result, index)
185179

186180

187181
@pytest.mark.parametrize(

pandas/tests/indexes/numeric/test_numeric.py

+17
Original file line numberDiff line numberDiff line change
@@ -670,3 +670,20 @@ def test_float64_index_equals():
670670

671671
result = string_index.equals(float_index)
672672
assert result is False
673+
674+
675+
def test_map_dtype_inference_unsigned_to_signed():
676+
# GH#44609 cases where we don't retain dtype
677+
idx = UInt64Index([1, 2, 3])
678+
result = idx.map(lambda x: -x)
679+
expected = Int64Index([-1, -2, -3])
680+
tm.assert_index_equal(result, expected)
681+
682+
683+
def test_map_dtype_inference_overflows():
684+
# GH#44609 case where we have to upcast
685+
idx = NumericIndex(np.array([1, 2, 3], dtype=np.int8))
686+
result = idx.map(lambda x: x * 1000)
687+
# TODO: we could plausibly try to infer down to int16 here
688+
expected = NumericIndex([1000, 2000, 3000], dtype=np.int64)
689+
tm.assert_index_equal(result, expected)

pandas/tests/indexes/test_any_index.py

+1-13
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@
33
"""
44
import re
55

6-
import numpy as np
76
import pytest
87

9-
from pandas.core.dtypes.common import is_float_dtype
10-
118
import pandas._testing as tm
129

1310

@@ -49,16 +46,7 @@ def test_mutability(index):
4946
def test_map_identity_mapping(index):
5047
# GH#12766
5148
result = index.map(lambda x: x)
52-
if index._is_backward_compat_public_numeric_index:
53-
if is_float_dtype(index.dtype):
54-
expected = index.astype(np.float64)
55-
elif index.dtype == np.uint64:
56-
expected = index.astype(np.uint64)
57-
else:
58-
expected = index.astype(np.int64)
59-
else:
60-
expected = index
61-
tm.assert_index_equal(result, expected, exact="equiv")
49+
tm.assert_index_equal(result, index, exact="equiv")
6250

6351

6452
def test_wrong_number_names(index):

pandas/tests/indexes/test_base.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
period_range,
2626
)
2727
import pandas._testing as tm
28-
from pandas.api.types import is_float_dtype
2928
from pandas.core.api import (
3029
Float64Index,
3130
Int64Index,
@@ -535,11 +534,15 @@ def test_map_dictlike(self, index, mapper):
535534
# to match proper result coercion for uints
536535
expected = Index([])
537536
elif index._is_backward_compat_public_numeric_index:
538-
if is_float_dtype(index.dtype):
539-
exp_dtype = np.float64
540-
else:
541-
exp_dtype = np.int64
542-
expected = index._constructor(np.arange(len(index), 0, -1), dtype=exp_dtype)
537+
expected = index._constructor(
538+
np.arange(len(index), 0, -1), dtype=index.dtype
539+
)
540+
elif type(index) is Index and index.dtype != object:
541+
# i.e. EA-backed, for now just Nullable
542+
expected = Index(np.arange(len(index), 0, -1), dtype=index.dtype)
543+
elif index.dtype.kind == "u":
544+
# TODO: case where e.g. we cannot hold result in UInt8?
545+
expected = Index(np.arange(len(index), 0, -1), dtype=index.dtype)
543546
else:
544547
expected = Index(np.arange(len(index), 0, -1))
545548

0 commit comments

Comments
 (0)