Skip to content

Commit 8660e7e

Browse files
topper-123feefladder
authored andcommitted
ENH: NumericIndex for any numpy int/uint/float dtype (pandas-dev#41153)
1 parent 67b61c9 commit 8660e7e

16 files changed

+321
-103
lines changed

pandas/_libs/join.pyx

+3
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,9 @@ ctypedef fused join_t:
265265
int16_t
266266
int32_t
267267
int64_t
268+
uint8_t
269+
uint16_t
270+
uint32_t
268271
uint64_t
269272

270273

pandas/_testing/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
use_numexpr,
106106
with_csv_dialect,
107107
)
108+
from pandas.core.api import NumericIndex
108109
from pandas.core.arrays import (
109110
DatetimeArray,
110111
PandasArray,
@@ -314,7 +315,7 @@ def makeNumericIndex(k=10, name=None, *, dtype):
314315
else:
315316
raise NotImplementedError(f"wrong dtype {dtype}")
316317

317-
return Index(values, dtype=dtype, name=name)
318+
return NumericIndex(values, dtype=dtype, name=name)
318319

319320

320321
def makeIntIndex(k=10, name=None):

pandas/conftest.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,16 @@ def _create_mi_with_dt64tz_level():
460460
"uint": tm.makeUIntIndex(100),
461461
"range": tm.makeRangeIndex(100),
462462
"float": tm.makeFloatIndex(100),
463+
"num_int64": tm.makeNumericIndex(100, dtype="int64"),
464+
"num_int32": tm.makeNumericIndex(100, dtype="int32"),
465+
"num_int16": tm.makeNumericIndex(100, dtype="int16"),
466+
"num_int8": tm.makeNumericIndex(100, dtype="int8"),
467+
"num_uint64": tm.makeNumericIndex(100, dtype="uint64"),
468+
"num_uint32": tm.makeNumericIndex(100, dtype="uint32"),
469+
"num_uint16": tm.makeNumericIndex(100, dtype="uint16"),
470+
"num_uint8": tm.makeNumericIndex(100, dtype="uint8"),
471+
"num_float64": tm.makeNumericIndex(100, dtype="float64"),
472+
"num_float32": tm.makeNumericIndex(100, dtype="float32"),
463473
"bool": tm.makeBoolIndex(10),
464474
"categorical": tm.makeCategoricalIndex(100),
465475
"interval": tm.makeIntervalIndex(100),
@@ -511,7 +521,10 @@ def index_flat(request):
511521
params=[
512522
key
513523
for key in indices_dict
514-
if key not in ["int", "uint", "range", "empty", "repeats"]
524+
if not (
525+
key in ["int", "uint", "range", "empty", "repeats"]
526+
or key.startswith("num_")
527+
)
515528
and not isinstance(indices_dict[key], MultiIndex)
516529
]
517530
)

pandas/core/api.py

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
Int64Index,
5858
IntervalIndex,
5959
MultiIndex,
60+
NumericIndex,
6061
PeriodIndex,
6162
RangeIndex,
6263
TimedeltaIndex,

pandas/core/dtypes/generic.py

+1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def _check(cls, inst) -> bool:
100100
"rangeindex",
101101
"float64index",
102102
"uint64index",
103+
"numericindex",
103104
"multiindex",
104105
"datetimeindex",
105106
"timedeltaindex",

pandas/core/indexes/base.py

+17
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
is_interval_dtype,
8282
is_iterator,
8383
is_list_like,
84+
is_numeric_dtype,
8485
is_object_dtype,
8586
is_scalar,
8687
is_signed_integer_dtype,
@@ -360,6 +361,11 @@ def _outer_indexer(
360361
_can_hold_na: bool = True
361362
_can_hold_strings: bool = True
362363

364+
# Whether this index is a NumericIndex, but not a Int64Index, Float64Index,
365+
# UInt64Index or RangeIndex. Needed for backwards compat. Remove this attribute and
366+
# associated code in pandas 2.0.
367+
_is_backward_compat_public_numeric_index: bool = False
368+
363369
_engine_type: type[libindex.IndexEngine] = libindex.ObjectEngine
364370
# whether we support partial string indexing. Overridden
365371
# in DatetimeIndex and PeriodIndex
@@ -437,6 +443,12 @@ def __new__(
437443
return Index._simple_new(data, name=name)
438444

439445
# index-like
446+
elif (
447+
isinstance(data, Index)
448+
and data._is_backward_compat_public_numeric_index
449+
and dtype is None
450+
):
451+
return data._constructor(data, name=name, copy=copy)
440452
elif isinstance(data, (np.ndarray, Index, ABCSeries)):
441453

442454
if isinstance(data, ABCMultiIndex):
@@ -5726,6 +5738,11 @@ def map(self, mapper, na_action=None):
57265738
# empty
57275739
attributes["dtype"] = self.dtype
57285740

5741+
if self._is_backward_compat_public_numeric_index and is_numeric_dtype(
5742+
new_values.dtype
5743+
):
5744+
return self._constructor(new_values, **attributes)
5745+
57295746
return Index(new_values, **attributes)
57305747

57315748
# TODO: De-duplicate with map, xref GH#32349

pandas/core/indexes/category.py

+25
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from pandas.core.dtypes.common import (
2222
is_categorical_dtype,
2323
is_scalar,
24+
pandas_dtype,
2425
)
2526
from pandas.core.dtypes.missing import (
2627
is_valid_na_for_dtype,
@@ -280,6 +281,30 @@ def _is_dtype_compat(self, other) -> Categorical:
280281

281282
return other
282283

284+
@doc(Index.astype)
285+
def astype(self, dtype: Dtype, copy: bool = True) -> Index:
286+
from pandas.core.api import NumericIndex
287+
288+
dtype = pandas_dtype(dtype)
289+
290+
categories = self.categories
291+
# the super method always returns Int64Index, UInt64Index and Float64Index
292+
# but if the categories are a NumericIndex with dtype float32, we want to
293+
# return an index with the same dtype as self.categories.
294+
if categories._is_backward_compat_public_numeric_index:
295+
assert isinstance(categories, NumericIndex) # mypy complaint fix
296+
try:
297+
categories._validate_dtype(dtype)
298+
except ValueError:
299+
pass
300+
else:
301+
new_values = self._data.astype(dtype, copy=copy)
302+
# pass copy=False because any copying has been done in the
303+
# _data.astype call above
304+
return categories._constructor(new_values, name=self.name, copy=False)
305+
306+
return super().astype(dtype, copy=copy)
307+
283308
def equals(self, other: object) -> bool:
284309
"""
285310
Determine if two CategoricalIndex objects contain the same elements.

pandas/core/indexes/numeric.py

+42-9
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ class NumericIndex(Index):
9797
)
9898
_is_numeric_dtype = True
9999
_can_hold_strings = False
100+
_is_backward_compat_public_numeric_index: bool = True
100101

101102
@cache_readonly
102103
def _can_hold_na(self) -> bool:
@@ -165,7 +166,15 @@ def _ensure_array(cls, data, dtype, copy: bool):
165166
dtype = cls._ensure_dtype(dtype)
166167

167168
if copy or not is_dtype_equal(data.dtype, dtype):
168-
subarr = np.array(data, dtype=dtype, copy=copy)
169+
# TODO: the try/except below is because it's difficult to predict the error
170+
# and/or error message from different combinations of data and dtype.
171+
# Efforts to avoid this try/except welcome.
172+
# See https://github.com/pandas-dev/pandas/pull/41153#discussion_r676206222
173+
try:
174+
subarr = np.array(data, dtype=dtype, copy=copy)
175+
cls._validate_dtype(subarr.dtype)
176+
except (TypeError, ValueError):
177+
raise ValueError(f"data is not compatible with {cls.__name__}")
169178
cls._assert_safe_casting(data, subarr)
170179
else:
171180
subarr = data
@@ -189,12 +198,24 @@ def _validate_dtype(cls, dtype: Dtype | None) -> None:
189198
)
190199

191200
@classmethod
192-
def _ensure_dtype(
193-
cls,
194-
dtype: Dtype | None,
195-
) -> np.dtype | None:
196-
"""Ensure int64 dtype for Int64Index, etc. Assumed dtype is validated."""
197-
return cls._default_dtype
201+
def _ensure_dtype(cls, dtype: Dtype | None) -> np.dtype | None:
202+
"""
203+
Ensure int64 dtype for Int64Index etc. but allow int32 etc. for NumericIndex.
204+
205+
Assumes dtype has already been validated.
206+
"""
207+
if dtype is None:
208+
return cls._default_dtype
209+
210+
dtype = pandas_dtype(dtype)
211+
assert isinstance(dtype, np.dtype)
212+
213+
if cls._is_backward_compat_public_numeric_index:
214+
# dtype for NumericIndex
215+
return dtype
216+
else:
217+
# dtype for Int64Index, UInt64Index etc. Needed for backwards compat.
218+
return cls._default_dtype
198219

199220
def __contains__(self, key) -> bool:
200221
"""
@@ -214,8 +235,8 @@ def __contains__(self, key) -> bool:
214235

215236
@doc(Index.astype)
216237
def astype(self, dtype, copy=True):
238+
dtype = pandas_dtype(dtype)
217239
if is_float_dtype(self.dtype):
218-
dtype = pandas_dtype(dtype)
219240
if needs_i8_conversion(dtype):
220241
raise TypeError(
221242
f"Cannot convert Float64Index to dtype {dtype}; integer "
@@ -225,7 +246,16 @@ def astype(self, dtype, copy=True):
225246
# TODO(jreback); this can change once we have an EA Index type
226247
# GH 13149
227248
arr = astype_nansafe(self._values, dtype=dtype)
228-
return Int64Index(arr, name=self.name)
249+
if isinstance(self, Float64Index):
250+
return Int64Index(arr, name=self.name)
251+
else:
252+
return NumericIndex(arr, name=self.name, dtype=dtype)
253+
elif self._is_backward_compat_public_numeric_index:
254+
# this block is needed so e.g. NumericIndex[int8].astype("int32") returns
255+
# NumericIndex[int32] and not Int64Index with dtype int64.
256+
# When Int64Index etc. are removed from the code base, removed this also.
257+
if not is_extension_array_dtype(dtype) and is_numeric_dtype(dtype):
258+
return self._constructor(self, dtype=dtype, copy=copy)
229259

230260
return super().astype(dtype, copy=copy)
231261

@@ -335,6 +365,8 @@ class IntegerIndex(NumericIndex):
335365
This is an abstract class for Int64Index, UInt64Index.
336366
"""
337367

368+
_is_backward_compat_public_numeric_index: bool = False
369+
338370
@property
339371
def asi8(self) -> np.ndarray:
340372
# do not cache or you'll create a memory leak
@@ -399,3 +431,4 @@ class Float64Index(NumericIndex):
399431
_engine_type = libindex.Float64Engine
400432
_default_dtype = np.dtype(np.float64)
401433
_dtype_validation_metadata = (is_float_dtype, "float")
434+
_is_backward_compat_public_numeric_index: bool = False

pandas/core/indexes/range.py

+1
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ class RangeIndex(NumericIndex):
101101
_engine_type = libindex.Int64Engine
102102
_dtype_validation_metadata = (is_signed_integer_dtype, "signed integer")
103103
_range: range
104+
_is_backward_compat_public_numeric_index: bool = False
104105

105106
# --------------------------------------------------------------------
106107
# Constructors

pandas/tests/base/test_unique.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import pandas as pd
1212
import pandas._testing as tm
13+
from pandas.core.api import NumericIndex
1314
from pandas.tests.base.common import allow_na_ops
1415

1516

@@ -24,6 +25,9 @@ def test_unique(index_or_series_obj):
2425
expected = pd.MultiIndex.from_tuples(unique_values)
2526
expected.names = obj.names
2627
tm.assert_index_equal(result, expected, exact=True)
28+
elif isinstance(obj, pd.Index) and obj._is_backward_compat_public_numeric_index:
29+
expected = NumericIndex(unique_values, dtype=obj.dtype)
30+
tm.assert_index_equal(result, expected, exact=True)
2731
elif isinstance(obj, pd.Index):
2832
expected = pd.Index(unique_values, dtype=obj.dtype)
2933
if is_datetime64tz_dtype(obj.dtype):
@@ -62,7 +66,10 @@ def test_unique_null(null_obj, index_or_series_obj):
6266
unique_values_not_null = [val for val in unique_values_raw if not pd.isnull(val)]
6367
unique_values = [null_obj] + unique_values_not_null
6468

65-
if isinstance(obj, pd.Index):
69+
if isinstance(obj, pd.Index) and obj._is_backward_compat_public_numeric_index:
70+
expected = NumericIndex(unique_values, dtype=obj.dtype)
71+
tm.assert_index_equal(result, expected, exact=True)
72+
elif isinstance(obj, pd.Index):
6673
expected = pd.Index(unique_values, dtype=obj.dtype)
6774
if is_datetime64tz_dtype(obj.dtype):
6875
result = result.normalize()

0 commit comments

Comments
 (0)