Skip to content

REF: simplify indexes.base._maybe_cast_data_without_dtype #41881

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions pandas/_libs/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1954,6 +1954,21 @@ cpdef bint is_datetime64_array(ndarray values):
return validator.validate(values)


@cython.internal
cdef class AnyDatetimeValidator(DatetimeValidator):
cdef inline bint is_value_typed(self, object value) except -1:
return util.is_datetime64_object(value) or (
PyDateTime_Check(value) and value.tzinfo is None
)


cdef bint is_datetime_or_datetime64_array(ndarray values):
cdef:
AnyDatetimeValidator validator = AnyDatetimeValidator(len(values),
skipna=True)
return validator.validate(values)


# Note: only python-exposed for tests
def is_datetime_with_singletz_array(values: ndarray) -> bool:
"""
Expand All @@ -1966,22 +1981,25 @@ def is_datetime_with_singletz_array(values: ndarray) -> bool:

if n == 0:
return False

# Get a reference timezone to compare with the rest of the tzs in the array
for i in range(n):
base_val = values[i]
if base_val is not NaT:
if base_val is not NaT and base_val is not None and not util.is_nan(base_val):
base_tz = getattr(base_val, 'tzinfo', None)
break

for j in range(i, n):
# Compare val's timezone with the reference timezone
# NaT can coexist with tz-aware datetimes, so skip if encountered
val = values[j]
if val is not NaT:
if val is not NaT and val is not None and not util.is_nan(val):
tz = getattr(val, 'tzinfo', None)
if not tz_compare(base_tz, tz):
return False

# Note: we should only be called if a tzaware datetime has been seen,
# so base_tz should always be set at this point.
return True


Expand Down Expand Up @@ -2464,6 +2482,7 @@ def maybe_convert_objects(ndarray[object] objects,
except OutOfBoundsTimedelta:
seen.object_ = True
break
break
else:
seen.object_ = True
break
Expand Down Expand Up @@ -2546,6 +2565,32 @@ def maybe_convert_objects(ndarray[object] objects,
return dti._data
seen.object_ = True

elif seen.datetime_:
if is_datetime_or_datetime64_array(objects):
from pandas import DatetimeIndex

try:
dti = DatetimeIndex(objects)
except OutOfBoundsDatetime:
pass
else:
# unbox to ndarray[datetime64[ns]]
return dti._data._ndarray
seen.object_ = True

elif seen.timedelta_:
if is_timedelta_or_timedelta64_array(objects):
from pandas import TimedeltaIndex

try:
tdi = TimedeltaIndex(objects)
except OutOfBoundsTimedelta:
pass
else:
# unbox to ndarray[timedelta64[ns]]
return tdi._data._ndarray
seen.object_ = True

if seen.period_:
if is_period_array(objects):
from pandas import PeriodIndex
Expand Down
94 changes: 11 additions & 83 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6310,91 +6310,19 @@ def _maybe_cast_data_without_dtype(subarr: np.ndarray) -> ArrayLike:
-------
np.ndarray or ExtensionArray
"""
# Runtime import needed bc IntervalArray imports Index
from pandas.core.arrays import (
DatetimeArray,
IntervalArray,
PeriodArray,
TimedeltaArray,
)

assert subarr.dtype == object, subarr.dtype
inferred = lib.infer_dtype(subarr, skipna=False)

if inferred == "integer":
try:
data = _try_convert_to_int_array(subarr)
return data
except ValueError:
pass

result = lib.maybe_convert_objects(
subarr,
convert_datetime=True,
convert_timedelta=True,
convert_period=True,
convert_interval=True,
dtype_if_all_nat=np.dtype("datetime64[ns]"),
)
if result.dtype.kind in ["b", "c"]:
return subarr

elif inferred in ["floating", "mixed-integer-float", "integer-na"]:
# TODO: Returns IntegerArray for integer-na case in the future
data = np.asarray(subarr).astype(np.float64, copy=False)
return data

elif inferred == "interval":
ia_data = IntervalArray._from_sequence(subarr, copy=False)
return ia_data
elif inferred == "boolean":
# don't support boolean explicitly ATM
pass
elif inferred != "string":
if inferred.startswith("datetime"):
try:
data = DatetimeArray._from_sequence(subarr, copy=False)
return data
except (ValueError, OutOfBoundsDatetime):
# GH 27011
# If we have mixed timezones, just send it
# down the base constructor
pass

elif inferred.startswith("timedelta"):
tda = TimedeltaArray._from_sequence(subarr, copy=False)
return tda
elif inferred == "period":
parr = PeriodArray._from_sequence(subarr)
return parr

return subarr


def _try_convert_to_int_array(data: np.ndarray) -> np.ndarray:
"""
Attempt to convert an array of data into an integer array.

Parameters
----------
data : np.ndarray[object]

Returns
-------
int_array : data converted to either an ndarray[int64] or ndarray[uint64]

Raises
------
ValueError if the conversion was not successful.
"""
try:
res = data.astype("i8", copy=False)
if (res == data).all():
return res
except (OverflowError, TypeError, ValueError):
pass

# Conversion to int64 failed (possibly due to overflow),
# so let's try now with uint64.
try:
res = data.astype("u8", copy=False)
if (res == data).all():
return res
except (OverflowError, TypeError, ValueError):
pass

raise ValueError
result = ensure_wrapped_if_datetimelike(result)
return result


def get_unanimous_names(*indexes: Index) -> tuple[Hashable, ...]:
Expand Down
1 change: 1 addition & 0 deletions pandas/tests/base/test_value_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def test_value_counts_datetime64(index_or_series):
expected_s = pd.concat([Series([4], index=DatetimeIndex([pd.NaT])), expected_s])
tm.assert_series_equal(result, expected_s)

assert s.dtype == "datetime64[ns]"
unique = s.unique()
assert unique.dtype == "datetime64[ns]"

Expand Down
14 changes: 13 additions & 1 deletion pandas/tests/dtypes/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from decimal import Decimal
from fractions import Fraction
from io import StringIO
import itertools
from numbers import Number
import re

Expand Down Expand Up @@ -658,8 +659,9 @@ def test_maybe_convert_objects_datetime(self):
)
tm.assert_numpy_array_equal(out, exp)

# with convert_timedelta=True, the nan is a valid NA value for td64
arr = np.array([np.timedelta64(1, "s"), np.nan], dtype=object)
exp = arr.copy()
exp = exp[::-1]
out = lib.maybe_convert_objects(
arr, convert_datetime=True, convert_timedelta=True
)
Expand Down Expand Up @@ -716,6 +718,16 @@ def test_maybe_convert_objects_datetime_overflow_safe(self, dtype):
# no OutOfBoundsDatetime/OutOfBoundsTimedeltas
tm.assert_numpy_array_equal(out, arr)

def test_maybe_convert_objects_mixed_datetimes(self):
ts = Timestamp("now")
vals = [ts, ts.to_pydatetime(), ts.to_datetime64(), pd.NaT, np.nan, None]

for data in itertools.permutations(vals):
data = np.array(list(data), dtype=object)
expected = DatetimeIndex(data)._data._ndarray
result = lib.maybe_convert_objects(data, convert_datetime=True)
tm.assert_numpy_array_equal(result, expected)

def test_maybe_convert_objects_timedelta64_nat(self):
obj = np.timedelta64("NaT", "ns")
arr = np.array([obj], dtype=object)
Expand Down
10 changes: 10 additions & 0 deletions pandas/tests/indexes/test_index_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,16 @@ def test_constructor_mixed_nat_objs_infers_object(self, swap_objs):
tm.assert_index_equal(Index(data), expected)
tm.assert_index_equal(Index(np.array(data, dtype=object)), expected)

@pytest.mark.parametrize("swap_objs", [True, False])
def test_constructor_datetime_and_datetime64(self, swap_objs):
data = [Timestamp(2021, 6, 8, 9, 42), np.datetime64("now")]
if swap_objs:
data = data[::-1]
expected = DatetimeIndex(data)

tm.assert_index_equal(Index(data), expected)
tm.assert_index_equal(Index(np.array(data, dtype=object)), expected)


class TestDtypeEnforced:
# check we don't silently ignore the dtype keyword
Expand Down