diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py index 52606cd7a914e..d7b1741687441 100644 --- a/pandas/core/dtypes/common.py +++ b/pandas/core/dtypes/common.py @@ -1090,22 +1090,20 @@ def is_numeric_v_string_like(a: ArrayLike, b) -> bool: ) -def needs_i8_conversion(arr_or_dtype) -> bool: +def needs_i8_conversion(dtype: DtypeObj | None) -> bool: """ - Check whether the array or dtype should be converted to int64. + Check whether the dtype should be converted to int64. - An array-like or dtype "needs" such a conversion if the array-like - or dtype is of a datetime-like dtype + Dtype "needs" such a conversion if the dtype is of a datetime-like dtype Parameters ---------- - arr_or_dtype : array-like or dtype - The array or dtype to check. + dtype : np.dtype, ExtensionDtype, or None Returns ------- boolean - Whether or not the array or dtype should be converted to int64. + Whether or not the dtype should be converted to int64. Examples -------- @@ -1114,30 +1112,27 @@ def needs_i8_conversion(arr_or_dtype) -> bool: >>> needs_i8_conversion(np.int64) False >>> needs_i8_conversion(np.datetime64) + False + >>> needs_i8_conversion(np.dtype(np.datetime64)) True >>> needs_i8_conversion(np.array(['a', 'b'])) False >>> needs_i8_conversion(pd.Series([1, 2])) False >>> needs_i8_conversion(pd.Series([], dtype="timedelta64[ns]")) - True + False >>> needs_i8_conversion(pd.DatetimeIndex([1, 2, 3], tz="US/Eastern")) + False + >>> needs_i8_conversion(pd.DatetimeIndex([1, 2, 3], tz="US/Eastern").dtype) True """ - if arr_or_dtype is None: - return False - if isinstance(arr_or_dtype, np.dtype): - return arr_or_dtype.kind in ["m", "M"] - elif isinstance(arr_or_dtype, ExtensionDtype): - return isinstance(arr_or_dtype, (PeriodDtype, DatetimeTZDtype)) - - try: - dtype = get_dtype(arr_or_dtype) - except (TypeError, ValueError): + if dtype is None: return False if isinstance(dtype, np.dtype): - return dtype.kind in ["m", "M"] - return isinstance(dtype, (PeriodDtype, DatetimeTZDtype)) + return dtype.kind in "mM" + elif isinstance(dtype, ExtensionDtype): + return isinstance(dtype, (PeriodDtype, DatetimeTZDtype)) + return False def is_numeric_dtype(arr_or_dtype) -> bool: diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 6d9b2327ff72e..f9ce0fa1e6ee4 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -960,9 +960,7 @@ def view(self, cls=None): if isinstance(cls, str): dtype = pandas_dtype(cls) - if isinstance(dtype, (np.dtype, ExtensionDtype)) and needs_i8_conversion( - dtype - ): + if needs_i8_conversion(dtype): if dtype.kind == "m" and dtype != "m8[ns]": # e.g. m8[s] return self._data.view(cls) diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py index 14cf5f317ed5a..38aa0d97f9c8a 100644 --- a/pandas/core/reshape/merge.py +++ b/pandas/core/reshape/merge.py @@ -2015,7 +2015,7 @@ def _get_merge_keys( f"with type {repr(lt.dtype)}" ) - if needs_i8_conversion(lt): + if needs_i8_conversion(getattr(lt, "dtype", None)): if not isinstance(self.tolerance, datetime.timedelta): raise MergeError(msg) if self.tolerance < Timedelta(0): @@ -2101,7 +2101,7 @@ def injection(obj): raise ValueError(f"{side} keys must be sorted") # initial type conversion as needed - if needs_i8_conversion(left_values): + if needs_i8_conversion(getattr(left_values, "dtype", None)): if tolerance is not None: tolerance = Timedelta(tolerance) diff --git a/pandas/tests/dtypes/test_common.py b/pandas/tests/dtypes/test_common.py index 9c11bff8862c1..0fe8376baeb19 100644 --- a/pandas/tests/dtypes/test_common.py +++ b/pandas/tests/dtypes/test_common.py @@ -518,9 +518,12 @@ def test_needs_i8_conversion(): assert not com.needs_i8_conversion(pd.Series([1, 2])) assert not com.needs_i8_conversion(np.array(["a", "b"])) - assert com.needs_i8_conversion(np.datetime64) - assert com.needs_i8_conversion(pd.Series([], dtype="timedelta64[ns]")) - assert com.needs_i8_conversion(pd.DatetimeIndex(["2000"], tz="US/Eastern")) + assert not com.needs_i8_conversion(np.datetime64) + assert com.needs_i8_conversion(np.dtype(np.datetime64)) + assert not com.needs_i8_conversion(pd.Series([], dtype="timedelta64[ns]")) + assert com.needs_i8_conversion(pd.Series([], dtype="timedelta64[ns]").dtype) + assert not com.needs_i8_conversion(pd.DatetimeIndex(["2000"], tz="US/Eastern")) + assert com.needs_i8_conversion(pd.DatetimeIndex(["2000"], tz="US/Eastern").dtype) def test_is_numeric_dtype():