Skip to content

Commit 1203236

Browse files
jbrockmendel authored and luckyvs1 committed
BUG: silently ignoring dtype kwarg in Index.__new__ (pandas-dev#38879)
1 parent c4e9572 commit 1203236

File tree

5 files changed

+146
-94
lines changed

5 files changed

+146
-94
lines changed

doc/source/whatsnew/v1.3.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ ExtensionArray
313313

314314
Other
315315
^^^^^
316-
316+
- Bug in :class:`Index` constructor sometimes silently ignoring a specified ``dtype`` (:issue:`38879`)
317317
-
318318
-
319319

pandas/core/dtypes/common.py

+13
Original file line numberDiff line numberDiff line change
@@ -1529,6 +1529,19 @@ def is_extension_array_dtype(arr_or_dtype) -> bool:
15291529
return isinstance(dtype, ExtensionDtype) or registry.find(dtype) is not None
15301530

15311531

1532+
def is_ea_or_datetimelike_dtype(dtype: Optional[DtypeObj]) -> bool:
1533+
"""
1534+
Check for ExtensionDtype, datetime64 dtype, or timedelta64 dtype.
1535+
1536+
Notes
1537+
-----
1538+
Checks only for dtype objects, not dtype-castable strings or types.
1539+
"""
1540+
return isinstance(dtype, ExtensionDtype) or (
1541+
isinstance(dtype, np.dtype) and dtype.kind in ["m", "M"]
1542+
)
1543+
1544+
15321545
def is_complex_dtype(arr_or_dtype) -> bool:
15331546
"""
15341547
Check whether the provided array or dtype is of a complex dtype.

pandas/core/indexes/base.py

+54-88
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@
4444
ensure_platform_int,
4545
is_bool_dtype,
4646
is_categorical_dtype,
47-
is_datetime64_any_dtype,
4847
is_dtype_equal,
48+
is_ea_or_datetimelike_dtype,
4949
is_extension_array_dtype,
5050
is_float,
5151
is_float_dtype,
@@ -56,10 +56,8 @@
5656
is_iterator,
5757
is_list_like,
5858
is_object_dtype,
59-
is_period_dtype,
6059
is_scalar,
6160
is_signed_integer_dtype,
62-
is_timedelta64_dtype,
6361
is_unsigned_integer_dtype,
6462
needs_i8_conversion,
6563
pandas_dtype,
@@ -69,6 +67,7 @@
6967
from pandas.core.dtypes.dtypes import (
7068
CategoricalDtype,
7169
DatetimeTZDtype,
70+
ExtensionDtype,
7271
IntervalDtype,
7372
PeriodDtype,
7473
)
@@ -87,6 +86,7 @@
8786
import pandas.core.algorithms as algos
8887
from pandas.core.arrays import Categorical, ExtensionArray
8988
from pandas.core.arrays.datetimes import tz_to_dtype, validate_tz_from_dtype
89+
from pandas.core.arrays.sparse import SparseDtype
9090
from pandas.core.base import IndexOpsMixin, PandasObject
9191
import pandas.core.common as com
9292
from pandas.core.construction import extract_array
@@ -286,44 +286,32 @@ def __new__(
286286

287287
# range
288288
if isinstance(data, RangeIndex):
289-
return RangeIndex(start=data, copy=copy, dtype=dtype, name=name)
289+
result = RangeIndex(start=data, copy=copy, name=name)
290+
if dtype is not None:
291+
return result.astype(dtype, copy=False)
292+
return result
290293
elif isinstance(data, range):
291-
return RangeIndex.from_range(data, dtype=dtype, name=name)
292-
293-
# categorical
294-
elif is_categorical_dtype(data_dtype) or is_categorical_dtype(dtype):
295-
# Delay import for perf. https://github.com/pandas-dev/pandas/pull/31423
296-
from pandas.core.indexes.category import CategoricalIndex
297-
298-
return _maybe_asobject(dtype, CategoricalIndex, data, copy, name, **kwargs)
299-
300-
# interval
301-
elif is_interval_dtype(data_dtype) or is_interval_dtype(dtype):
302-
# Delay import for perf. https://github.com/pandas-dev/pandas/pull/31423
303-
from pandas.core.indexes.interval import IntervalIndex
304-
305-
return _maybe_asobject(dtype, IntervalIndex, data, copy, name, **kwargs)
306-
307-
elif is_datetime64_any_dtype(data_dtype) or is_datetime64_any_dtype(dtype):
308-
# Delay import for perf. https://github.com/pandas-dev/pandas/pull/31423
309-
from pandas import DatetimeIndex
310-
311-
return _maybe_asobject(dtype, DatetimeIndex, data, copy, name, **kwargs)
312-
313-
elif is_timedelta64_dtype(data_dtype) or is_timedelta64_dtype(dtype):
314-
# Delay import for perf. https://github.com/pandas-dev/pandas/pull/31423
315-
from pandas import TimedeltaIndex
316-
317-
return _maybe_asobject(dtype, TimedeltaIndex, data, copy, name, **kwargs)
318-
319-
elif is_period_dtype(data_dtype) or is_period_dtype(dtype):
320-
# Delay import for perf. https://github.com/pandas-dev/pandas/pull/31423
321-
from pandas import PeriodIndex
294+
result = RangeIndex.from_range(data, name=name)
295+
if dtype is not None:
296+
return result.astype(dtype, copy=False)
297+
return result
322298

323-
return _maybe_asobject(dtype, PeriodIndex, data, copy, name, **kwargs)
299+
if is_ea_or_datetimelike_dtype(dtype):
300+
# non-EA dtype indexes have special casting logic, so we punt here
301+
klass = cls._dtype_to_subclass(dtype)
302+
if klass is not Index:
303+
return klass(data, dtype=dtype, copy=copy, name=name, **kwargs)
304+
305+
if is_ea_or_datetimelike_dtype(data_dtype):
306+
klass = cls._dtype_to_subclass(data_dtype)
307+
if klass is not Index:
308+
result = klass(data, copy=copy, name=name, **kwargs)
309+
if dtype is not None:
310+
return result.astype(dtype, copy=False)
311+
return result
324312

325313
# extension dtype
326-
elif is_extension_array_dtype(data_dtype) or is_extension_array_dtype(dtype):
314+
if is_extension_array_dtype(data_dtype) or is_extension_array_dtype(dtype):
327315
if not (dtype is None or is_object_dtype(dtype)):
328316
# coerce to the provided dtype
329317
ea_cls = dtype.construct_array_type()
@@ -407,26 +395,38 @@ def _ensure_array(cls, data, dtype, copy: bool):
407395
def _dtype_to_subclass(cls, dtype: DtypeObj):
408396
# Delay import for perf. https://github.com/pandas-dev/pandas/pull/31423
409397

410-
if isinstance(dtype, DatetimeTZDtype) or dtype == np.dtype("M8[ns]"):
398+
if isinstance(dtype, ExtensionDtype):
399+
if isinstance(dtype, DatetimeTZDtype):
400+
from pandas import DatetimeIndex
401+
402+
return DatetimeIndex
403+
elif isinstance(dtype, CategoricalDtype):
404+
from pandas import CategoricalIndex
405+
406+
return CategoricalIndex
407+
elif isinstance(dtype, IntervalDtype):
408+
from pandas import IntervalIndex
409+
410+
return IntervalIndex
411+
elif isinstance(dtype, PeriodDtype):
412+
from pandas import PeriodIndex
413+
414+
return PeriodIndex
415+
416+
elif isinstance(dtype, SparseDtype):
417+
return cls._dtype_to_subclass(dtype.subtype)
418+
419+
return Index
420+
421+
if dtype.kind == "M":
411422
from pandas import DatetimeIndex
412423

413424
return DatetimeIndex
414-
elif dtype == "m8[ns]":
425+
426+
elif dtype.kind == "m":
415427
from pandas import TimedeltaIndex
416428

417429
return TimedeltaIndex
418-
elif isinstance(dtype, CategoricalDtype):
419-
from pandas import CategoricalIndex
420-
421-
return CategoricalIndex
422-
elif isinstance(dtype, IntervalDtype):
423-
from pandas import IntervalIndex
424-
425-
return IntervalIndex
426-
elif isinstance(dtype, PeriodDtype):
427-
from pandas import PeriodIndex
428-
429-
return PeriodIndex
430430

431431
elif is_float_dtype(dtype):
432432
from pandas import Float64Index
@@ -445,6 +445,9 @@ def _dtype_to_subclass(cls, dtype: DtypeObj):
445445
# NB: assuming away MultiIndex
446446
return Index
447447

448+
elif issubclass(dtype.type, (str, bool, np.bool_)):
449+
return Index
450+
448451
raise NotImplementedError(dtype)
449452

450453
"""
@@ -6253,43 +6256,6 @@ def _try_convert_to_int_array(
62536256
raise ValueError
62546257

62556258

6256-
def _maybe_asobject(dtype, klass, data, copy: bool, name: Label, **kwargs):
6257-
"""
6258-
If an object dtype was specified, create the non-object Index
6259-
and then convert it to object.
6260-
6261-
Parameters
6262-
----------
6263-
dtype : np.dtype, ExtensionDtype, str
6264-
klass : Index subclass
6265-
data : list-like
6266-
copy : bool
6267-
name : hashable
6268-
**kwargs
6269-
6270-
Returns
6271-
-------
6272-
Index
6273-
6274-
Notes
6275-
-----
6276-
We assume that calling .astype(object) on this klass will make a copy.
6277-
"""
6278-
6279-
# GH#23524 passing `dtype=object` to DatetimeIndex is invalid,
6280-
# will raise in the where `data` is already tz-aware. So
6281-
# we leave it out of this step and cast to object-dtype after
6282-
# the DatetimeIndex construction.
6283-
6284-
if is_dtype_equal(_o_dtype, dtype):
6285-
# Note we can pass copy=False because the .astype below
6286-
# will always make a copy
6287-
index = klass(data, copy=False, name=name, **kwargs)
6288-
return index.astype(object)
6289-
6290-
return klass(data, dtype=dtype, copy=copy, name=name, **kwargs)
6291-
6292-
62936259
def get_unanimous_names(*indexes: Index) -> Tuple[Label, ...]:
62946260
"""
62956261
Return common name if all indices agree, otherwise None (level-by-level).

pandas/tests/indexes/ranges/test_constructors.py

-5
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,6 @@ def test_constructor_range(self):
114114
expected = RangeIndex(1, 5, 2)
115115
tm.assert_index_equal(result, expected, exact=True)
116116

117-
with pytest.raises(
118-
ValueError,
119-
match="Incorrect `dtype` passed: expected signed integer, received float64",
120-
):
121-
Index(range(1, 5, 2), dtype="float64")
122117
msg = r"^from_range\(\) got an unexpected keyword argument"
123118
with pytest.raises(TypeError, match=msg):
124119
RangeIndex.from_range(range(10), copy=True)

pandas/tests/indexes/test_index_new.py

+78
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,22 @@
88

99
from pandas import (
1010
NA,
11+
Categorical,
1112
CategoricalIndex,
1213
DatetimeIndex,
1314
Index,
1415
Int64Index,
16+
IntervalIndex,
1517
MultiIndex,
1618
NaT,
1719
PeriodIndex,
1820
Series,
1921
TimedeltaIndex,
2022
Timestamp,
2123
UInt64Index,
24+
date_range,
2225
period_range,
26+
timedelta_range,
2327
)
2428
import pandas._testing as tm
2529

@@ -122,6 +126,80 @@ def test_constructor_mixed_nat_objs_infers_object(self, swap_objs):
122126
tm.assert_index_equal(Index(np.array(data, dtype=object)), expected)
123127

124128

129+
class TestDtypeEnforced:
130+
# check we don't silently ignore the dtype keyword
131+
132+
@pytest.mark.parametrize("dtype", [object, "float64", "uint64", "category"])
133+
def test_constructor_range_values_mismatched_dtype(self, dtype):
134+
rng = Index(range(5))
135+
136+
result = Index(rng, dtype=dtype)
137+
assert result.dtype == dtype
138+
139+
result = Index(range(5), dtype=dtype)
140+
assert result.dtype == dtype
141+
142+
@pytest.mark.parametrize("dtype", [object, "float64", "uint64", "category"])
143+
def test_constructor_categorical_values_mismatched_non_ea_dtype(self, dtype):
144+
cat = Categorical([1, 2, 3])
145+
146+
result = Index(cat, dtype=dtype)
147+
assert result.dtype == dtype
148+
149+
def test_constructor_categorical_values_mismatched_dtype(self):
150+
dti = date_range("2016-01-01", periods=3)
151+
cat = Categorical(dti)
152+
result = Index(cat, dti.dtype)
153+
tm.assert_index_equal(result, dti)
154+
155+
dti2 = dti.tz_localize("Asia/Tokyo")
156+
cat2 = Categorical(dti2)
157+
result = Index(cat2, dti2.dtype)
158+
tm.assert_index_equal(result, dti2)
159+
160+
ii = IntervalIndex.from_breaks(range(5))
161+
cat3 = Categorical(ii)
162+
result = Index(cat3, dtype=ii.dtype)
163+
tm.assert_index_equal(result, ii)
164+
165+
def test_constructor_ea_values_mismatched_categorical_dtype(self):
166+
dti = date_range("2016-01-01", periods=3)
167+
result = Index(dti, dtype="category")
168+
expected = CategoricalIndex(dti)
169+
tm.assert_index_equal(result, expected)
170+
171+
dti2 = date_range("2016-01-01", periods=3, tz="US/Pacific")
172+
result = Index(dti2, dtype="category")
173+
expected = CategoricalIndex(dti2)
174+
tm.assert_index_equal(result, expected)
175+
176+
def test_constructor_period_values_mismatched_dtype(self):
177+
pi = period_range("2016-01-01", periods=3, freq="D")
178+
result = Index(pi, dtype="category")
179+
expected = CategoricalIndex(pi)
180+
tm.assert_index_equal(result, expected)
181+
182+
def test_constructor_timedelta64_values_mismatched_dtype(self):
183+
# check we don't silently ignore the dtype keyword
184+
tdi = timedelta_range("4 Days", periods=5)
185+
result = Index(tdi, dtype="category")
186+
expected = CategoricalIndex(tdi)
187+
tm.assert_index_equal(result, expected)
188+
189+
def test_constructor_interval_values_mismatched_dtype(self):
190+
dti = date_range("2016-01-01", periods=3)
191+
ii = IntervalIndex.from_breaks(dti)
192+
result = Index(ii, dtype="category")
193+
expected = CategoricalIndex(ii)
194+
tm.assert_index_equal(result, expected)
195+
196+
def test_constructor_datetime64_values_mismatched_period_dtype(self):
197+
dti = date_range("2016-01-01", periods=3)
198+
result = Index(dti, dtype="Period[D]")
199+
expected = dti.to_period("D")
200+
tm.assert_index_equal(result, expected)
201+
202+
125203
class TestIndexConstructorUnwrapping:
126204
# Test passing different arraylike values to pd.Index
127205

0 commit comments

Comments
 (0)