Skip to content

Commit 882a2c1

Browse files
jbrockmendel authored and JulianWgs committed
REF: simplify indexes.base._maybe_cast_data_without_dtype (pandas-dev#41881)
1 parent ddc5423 commit 882a2c1

File tree

5 files changed

+82
-86
lines changed

5 files changed

+82
-86
lines changed

pandas/_libs/lib.pyx

Lines changed: 47 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1954,6 +1954,21 @@ cpdef bint is_datetime64_array(ndarray values):
19541954
return validator.validate(values)
19551955

19561956

1957+
@cython.internal
1958+
cdef class AnyDatetimeValidator(DatetimeValidator):
1959+
cdef inline bint is_value_typed(self, object value) except -1:
1960+
return util.is_datetime64_object(value) or (
1961+
PyDateTime_Check(value) and value.tzinfo is None
1962+
)
1963+
1964+
1965+
cdef bint is_datetime_or_datetime64_array(ndarray values):
1966+
cdef:
1967+
AnyDatetimeValidator validator = AnyDatetimeValidator(len(values),
1968+
skipna=True)
1969+
return validator.validate(values)
1970+
1971+
19571972
# Note: only python-exposed for tests
19581973
def is_datetime_with_singletz_array(values: ndarray) -> bool:
19591974
"""
@@ -1966,22 +1981,25 @@ def is_datetime_with_singletz_array(values: ndarray) -> bool:
19661981

19671982
if n == 0:
19681983
return False
1984+
19691985
# Get a reference timezone to compare with the rest of the tzs in the array
19701986
for i in range(n):
19711987
base_val = values[i]
1972-
if base_val is not NaT:
1988+
if base_val is not NaT and base_val is not None and not util.is_nan(base_val):
19731989
base_tz = getattr(base_val, 'tzinfo', None)
19741990
break
19751991

19761992
for j in range(i, n):
19771993
# Compare val's timezone with the reference timezone
19781994
# NaT can coexist with tz-aware datetimes, so skip if encountered
19791995
val = values[j]
1980-
if val is not NaT:
1996+
if val is not NaT and val is not None and not util.is_nan(val):
19811997
tz = getattr(val, 'tzinfo', None)
19821998
if not tz_compare(base_tz, tz):
19831999
return False
19842000

2001+
# Note: we should only be called if a tzaware datetime has been seen,
2002+
# so base_tz should always be set at this point.
19852003
return True
19862004

19872005

@@ -2464,6 +2482,7 @@ def maybe_convert_objects(ndarray[object] objects,
24642482
except OutOfBoundsTimedelta:
24652483
seen.object_ = True
24662484
break
2485+
break
24672486
else:
24682487
seen.object_ = True
24692488
break
@@ -2546,6 +2565,32 @@ def maybe_convert_objects(ndarray[object] objects,
25462565
return dti._data
25472566
seen.object_ = True
25482567

2568+
elif seen.datetime_:
2569+
if is_datetime_or_datetime64_array(objects):
2570+
from pandas import DatetimeIndex
2571+
2572+
try:
2573+
dti = DatetimeIndex(objects)
2574+
except OutOfBoundsDatetime:
2575+
pass
2576+
else:
2577+
# unbox to ndarray[datetime64[ns]]
2578+
return dti._data._ndarray
2579+
seen.object_ = True
2580+
2581+
elif seen.timedelta_:
2582+
if is_timedelta_or_timedelta64_array(objects):
2583+
from pandas import TimedeltaIndex
2584+
2585+
try:
2586+
tdi = TimedeltaIndex(objects)
2587+
except OutOfBoundsTimedelta:
2588+
pass
2589+
else:
2590+
# unbox to ndarray[timedelta64[ns]]
2591+
return tdi._data._ndarray
2592+
seen.object_ = True
2593+
25492594
if seen.period_:
25502595
if is_period_array(objects):
25512596
from pandas import PeriodIndex

pandas/core/indexes/base.py

Lines changed: 11 additions & 83 deletions
Original file line number | Diff line number | Diff line change
@@ -6310,91 +6310,19 @@ def _maybe_cast_data_without_dtype(subarr: np.ndarray) -> ArrayLike:
63106310
-------
63116311
np.ndarray or ExtensionArray
63126312
"""
6313-
# Runtime import needed bc IntervalArray imports Index
6314-
from pandas.core.arrays import (
6315-
DatetimeArray,
6316-
IntervalArray,
6317-
PeriodArray,
6318-
TimedeltaArray,
6319-
)
6320-
6321-
assert subarr.dtype == object, subarr.dtype
6322-
inferred = lib.infer_dtype(subarr, skipna=False)
6323-
6324-
if inferred == "integer":
6325-
try:
6326-
data = _try_convert_to_int_array(subarr)
6327-
return data
6328-
except ValueError:
6329-
pass
63306313

6314+
result = lib.maybe_convert_objects(
6315+
subarr,
6316+
convert_datetime=True,
6317+
convert_timedelta=True,
6318+
convert_period=True,
6319+
convert_interval=True,
6320+
dtype_if_all_nat=np.dtype("datetime64[ns]"),
6321+
)
6322+
if result.dtype.kind in ["b", "c"]:
63316323
return subarr
6332-
6333-
elif inferred in ["floating", "mixed-integer-float", "integer-na"]:
6334-
# TODO: Returns IntegerArray for integer-na case in the future
6335-
data = np.asarray(subarr).astype(np.float64, copy=False)
6336-
return data
6337-
6338-
elif inferred == "interval":
6339-
ia_data = IntervalArray._from_sequence(subarr, copy=False)
6340-
return ia_data
6341-
elif inferred == "boolean":
6342-
# don't support boolean explicitly ATM
6343-
pass
6344-
elif inferred != "string":
6345-
if inferred.startswith("datetime"):
6346-
try:
6347-
data = DatetimeArray._from_sequence(subarr, copy=False)
6348-
return data
6349-
except (ValueError, OutOfBoundsDatetime):
6350-
# GH 27011
6351-
# If we have mixed timezones, just send it
6352-
# down the base constructor
6353-
pass
6354-
6355-
elif inferred.startswith("timedelta"):
6356-
tda = TimedeltaArray._from_sequence(subarr, copy=False)
6357-
return tda
6358-
elif inferred == "period":
6359-
parr = PeriodArray._from_sequence(subarr)
6360-
return parr
6361-
6362-
return subarr
6363-
6364-
6365-
def _try_convert_to_int_array(data: np.ndarray) -> np.ndarray:
6366-
"""
6367-
Attempt to convert an array of data into an integer array.
6368-
6369-
Parameters
6370-
----------
6371-
data : np.ndarray[object]
6372-
6373-
Returns
6374-
-------
6375-
int_array : data converted to either an ndarray[int64] or ndarray[uint64]
6376-
6377-
Raises
6378-
------
6379-
ValueError if the conversion was not successful.
6380-
"""
6381-
try:
6382-
res = data.astype("i8", copy=False)
6383-
if (res == data).all():
6384-
return res
6385-
except (OverflowError, TypeError, ValueError):
6386-
pass
6387-
6388-
# Conversion to int64 failed (possibly due to overflow),
6389-
# so let's try now with uint64.
6390-
try:
6391-
res = data.astype("u8", copy=False)
6392-
if (res == data).all():
6393-
return res
6394-
except (OverflowError, TypeError, ValueError):
6395-
pass
6396-
6397-
raise ValueError
6324+
result = ensure_wrapped_if_datetimelike(result)
6325+
return result
63986326

63996327

64006328
def get_unanimous_names(*indexes: Index) -> tuple[Hashable, ...]:

pandas/tests/base/test_value_counts.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -242,6 +242,7 @@ def test_value_counts_datetime64(index_or_series):
242242
expected_s = pd.concat([Series([4], index=DatetimeIndex([pd.NaT])), expected_s])
243243
tm.assert_series_equal(result, expected_s)
244244

245+
assert s.dtype == "datetime64[ns]"
245246
unique = s.unique()
246247
assert unique.dtype == "datetime64[ns]"
247248

pandas/tests/dtypes/test_inference.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
from decimal import Decimal
1515
from fractions import Fraction
1616
from io import StringIO
17+
import itertools
1718
from numbers import Number
1819
import re
1920

@@ -658,8 +659,9 @@ def test_maybe_convert_objects_datetime(self):
658659
)
659660
tm.assert_numpy_array_equal(out, exp)
660661

662+
# with convert_timedelta=True, the nan is a valid NA value for td64
661663
arr = np.array([np.timedelta64(1, "s"), np.nan], dtype=object)
662-
exp = arr.copy()
664+
exp = exp[::-1]
663665
out = lib.maybe_convert_objects(
664666
arr, convert_datetime=True, convert_timedelta=True
665667
)
@@ -716,6 +718,16 @@ def test_maybe_convert_objects_datetime_overflow_safe(self, dtype):
716718
# no OutOfBoundsDatetime/OutOfBoundsTimedeltas
717719
tm.assert_numpy_array_equal(out, arr)
718720

721+
def test_maybe_convert_objects_mixed_datetimes(self):
722+
ts = Timestamp("now")
723+
vals = [ts, ts.to_pydatetime(), ts.to_datetime64(), pd.NaT, np.nan, None]
724+
725+
for data in itertools.permutations(vals):
726+
data = np.array(list(data), dtype=object)
727+
expected = DatetimeIndex(data)._data._ndarray
728+
result = lib.maybe_convert_objects(data, convert_datetime=True)
729+
tm.assert_numpy_array_equal(result, expected)
730+
719731
def test_maybe_convert_objects_timedelta64_nat(self):
720732
obj = np.timedelta64("NaT", "ns")
721733
arr = np.array([obj], dtype=object)

pandas/tests/indexes/test_index_new.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -138,6 +138,16 @@ def test_constructor_mixed_nat_objs_infers_object(self, swap_objs):
138138
tm.assert_index_equal(Index(data), expected)
139139
tm.assert_index_equal(Index(np.array(data, dtype=object)), expected)
140140

141+
@pytest.mark.parametrize("swap_objs", [True, False])
142+
def test_constructor_datetime_and_datetime64(self, swap_objs):
143+
data = [Timestamp(2021, 6, 8, 9, 42), np.datetime64("now")]
144+
if swap_objs:
145+
data = data[::-1]
146+
expected = DatetimeIndex(data)
147+
148+
tm.assert_index_equal(Index(data), expected)
149+
tm.assert_index_equal(Index(np.array(data, dtype=object)), expected)
150+
141151

142152
class TestDtypeEnforced:
143153
# check we don't silently ignore the dtype keyword

0 commit comments

Comments
 (0)