Skip to content

PERF: numpy dtype checks #52582

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pandas/_libs/lib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ NoDefault = Literal[_NoDefault.no_default]
i8max: int
u8max: int

# Fast check that `dtype` is an np.dtype instance; when `kinds` is given,
# additionally requires `dtype.kind in kinds`.
def is_np_dtype(dtype: object, kinds: str | None = ...) -> bool: ...
def item_from_zerodim(val: object) -> object: ...
def infer_dtype(value: object, skipna: bool = ...) -> str: ...
def is_iterator(obj: object) -> bool: ...
Expand Down
27 changes: 27 additions & 0 deletions pandas/_libs/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3070,3 +3070,30 @@ def dtypes_all_equal(list types not None) -> bool:
return False
else:
return True


def is_np_dtype(object dtype, str kinds=None) -> bool:
    """
    Fast equivalent of ``isinstance(dtype, np.dtype)``, optionally combined
    with ``dtype.kind in kinds``.

    Parameters
    ----------
    dtype : object
        Candidate object to test.
    kinds : str, optional
        String of numpy kind codes to accept (e.g. "m" or "M").  When None,
        only the instance check is performed.

    Returns
    -------
    bool
        False if `dtype` is not an np.dtype instance; otherwise True when
        `kinds` is None, else whether ``dtype.kind`` is contained in `kinds`.

    Notes
    -----
    Timings recorded with ``dtype = np.dtype("m8[ns]")``:

    In [7]: %timeit isinstance(dtype, np.dtype)
    117 ns ± 1.91 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

    In [8]: %timeit is_np_dtype(dtype)
    64 ns ± 1.51 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

    In [9]: %timeit is_timedelta64_dtype(dtype)
    209 ns ± 6.96 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

    In [10]: %timeit is_np_dtype(dtype, "m")
    93.4 ns ± 1.11 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
    """
    if cnp.PyArray_DescrCheck(dtype):
        # i.e. isinstance(dtype, np.dtype); the kind filter applies only
        # when the caller supplied one.
        return kinds is None or dtype.kind in kinds
    return False
6 changes: 2 additions & 4 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,9 @@

from pandas.core.dtypes.cast import maybe_cast_to_extension_array
from pandas.core.dtypes.common import (
is_datetime64_dtype,
is_dtype_equal,
is_list_like,
is_scalar,
is_timedelta64_dtype,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import ExtensionDtype
Expand Down Expand Up @@ -582,12 +580,12 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
cls = dtype.construct_array_type()
return cls._from_sequence(self, dtype=dtype, copy=copy)

elif is_datetime64_dtype(dtype):
elif lib.is_np_dtype(dtype, "M"):
from pandas.core.arrays import DatetimeArray

return DatetimeArray._from_sequence(self, dtype=dtype, copy=copy)

elif is_timedelta64_dtype(dtype):
elif lib.is_np_dtype(dtype, "m"):
from pandas.core.arrays import TimedeltaArray

return TimedeltaArray._from_sequence(self, dtype=dtype, copy=copy)
Expand Down
6 changes: 2 additions & 4 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,13 @@
ensure_platform_int,
is_any_real_numeric_dtype,
is_bool_dtype,
is_datetime64_dtype,
is_dict_like,
is_dtype_equal,
is_extension_array_dtype,
is_hashable,
is_integer_dtype,
is_list_like,
is_scalar,
is_timedelta64_dtype,
needs_i8_conversion,
pandas_dtype,
)
Expand Down Expand Up @@ -622,9 +620,9 @@ def _from_inferred_categories(
# Convert to a specialized type with `dtype` if specified.
if is_any_real_numeric_dtype(dtype.categories):
cats = to_numeric(inferred_categories, errors="coerce")
elif is_datetime64_dtype(dtype.categories):
elif lib.is_np_dtype(dtype.categories.dtype, "M"):
cats = to_datetime(inferred_categories, errors="coerce")
elif is_timedelta64_dtype(dtype.categories):
elif lib.is_np_dtype(dtype.categories.dtype, "m"):
cats = to_timedelta(inferred_categories, errors="coerce")
elif is_bool_dtype(dtype.categories):
if true_values is None:
Expand Down
28 changes: 13 additions & 15 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,13 @@
from pandas.core.dtypes.common import (
is_all_strings,
is_datetime64_any_dtype,
is_datetime64_dtype,
is_datetime_or_timedelta_dtype,
is_dtype_equal,
is_float_dtype,
is_integer_dtype,
is_list_like,
is_object_dtype,
is_string_dtype,
is_timedelta64_dtype,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import (
Expand Down Expand Up @@ -993,7 +991,7 @@ def _get_arithmetic_result_freq(self, other) -> BaseOffset | None:

@final
def _add_datetimelike_scalar(self, other) -> DatetimeArray:
if not is_timedelta64_dtype(self.dtype):
if not lib.is_np_dtype(self.dtype, "m"):
raise TypeError(
f"cannot add {type(self).__name__} and {type(other).__name__}"
)
Expand Down Expand Up @@ -1029,7 +1027,7 @@ def _add_datetimelike_scalar(self, other) -> DatetimeArray:

@final
def _add_datetime_arraylike(self, other: DatetimeArray) -> DatetimeArray:
if not is_timedelta64_dtype(self.dtype):
if not lib.is_np_dtype(self.dtype, "m"):
raise TypeError(
f"cannot add {type(self).__name__} and {type(other).__name__}"
)
Expand Down Expand Up @@ -1093,7 +1091,7 @@ def _sub_datetimelike(self, other: Timestamp | DatetimeArray) -> TimedeltaArray:

@final
def _add_period(self, other: Period) -> PeriodArray:
if not is_timedelta64_dtype(self.dtype):
if not lib.is_np_dtype(self.dtype, "m"):
raise TypeError(f"cannot add Period to a {type(self).__name__}")

# We will wrap in a PeriodArray and defer to the reversed operation
Expand Down Expand Up @@ -1294,7 +1292,7 @@ def __add__(self, other):
result = self._add_offset(other)
elif isinstance(other, (datetime, np.datetime64)):
result = self._add_datetimelike_scalar(other)
elif isinstance(other, Period) and is_timedelta64_dtype(self.dtype):
elif isinstance(other, Period) and lib.is_np_dtype(self.dtype, "m"):
result = self._add_period(other)
elif lib.is_integer(other):
# This check must come after the check for np.timedelta64
Expand All @@ -1305,13 +1303,13 @@ def __add__(self, other):
result = obj._addsub_int_array_or_scalar(other * obj.dtype._n, operator.add)

# array-like others
elif is_timedelta64_dtype(other_dtype):
elif lib.is_np_dtype(other_dtype, "m"):
# TimedeltaIndex, ndarray[timedelta64]
result = self._add_timedelta_arraylike(other)
elif is_object_dtype(other_dtype):
# e.g. Array/Index of DateOffset objects
result = self._addsub_object_array(other, operator.add)
elif is_datetime64_dtype(other_dtype) or isinstance(
elif lib.is_np_dtype(other_dtype, "M") or isinstance(
other_dtype, DatetimeTZDtype
):
# DatetimeIndex, ndarray[datetime64]
Expand All @@ -1329,7 +1327,7 @@ def __add__(self, other):
# In remaining cases, this will end up raising TypeError.
return NotImplemented

if isinstance(result, np.ndarray) and is_timedelta64_dtype(result.dtype):
if isinstance(result, np.ndarray) and lib.is_np_dtype(result.dtype, "m"):
from pandas.core.arrays import TimedeltaArray

return TimedeltaArray(result)
Expand Down Expand Up @@ -1366,13 +1364,13 @@ def __sub__(self, other):
result = self._sub_periodlike(other)

# array-like others
elif is_timedelta64_dtype(other_dtype):
elif lib.is_np_dtype(other_dtype, "m"):
# TimedeltaIndex, ndarray[timedelta64]
result = self._add_timedelta_arraylike(-other)
elif is_object_dtype(other_dtype):
# e.g. Array/Index of DateOffset objects
result = self._addsub_object_array(other, operator.sub)
elif is_datetime64_dtype(other_dtype) or isinstance(
elif lib.is_np_dtype(other_dtype, "M") or isinstance(
other_dtype, DatetimeTZDtype
):
# DatetimeIndex, ndarray[datetime64]
Expand All @@ -1389,7 +1387,7 @@ def __sub__(self, other):
# Includes ExtensionArrays, float_dtype
return NotImplemented

if isinstance(result, np.ndarray) and is_timedelta64_dtype(result.dtype):
if isinstance(result, np.ndarray) and lib.is_np_dtype(result.dtype, "m"):
from pandas.core.arrays import TimedeltaArray

return TimedeltaArray(result)
Expand All @@ -1398,7 +1396,7 @@ def __sub__(self, other):
def __rsub__(self, other):
other_dtype = getattr(other, "dtype", None)

if is_datetime64_any_dtype(other_dtype) and is_timedelta64_dtype(self.dtype):
if is_datetime64_any_dtype(other_dtype) and lib.is_np_dtype(self.dtype, "m"):
# ndarray[datetime64] cannot be subtracted from self, so
# we need to wrap in DatetimeArray/Index and flip the operation
if lib.is_scalar(other):
Expand All @@ -1420,10 +1418,10 @@ def __rsub__(self, other):
raise TypeError(
f"cannot subtract {type(self).__name__} from {type(other).__name__}"
)
elif isinstance(self.dtype, PeriodDtype) and is_timedelta64_dtype(other_dtype):
elif isinstance(self.dtype, PeriodDtype) and lib.is_np_dtype(other_dtype, "m"):
# TODO: Can we simplify/generalize these cases at all?
raise TypeError(f"cannot subtract {type(self).__name__} from {other.dtype}")
elif is_timedelta64_dtype(self.dtype):
elif lib.is_np_dtype(self.dtype, "m"):
self = cast("TimedeltaArray", self)
return (-self) + other

Expand Down
13 changes: 6 additions & 7 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
is_object_dtype,
is_sparse,
is_string_dtype,
is_timedelta64_dtype,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import (
Expand Down Expand Up @@ -670,7 +669,7 @@ def astype(self, dtype, copy: bool = True):

elif (
self.tz is None
and is_datetime64_dtype(dtype)
and lib.is_np_dtype(dtype, "M")
and not is_unitless(dtype)
and is_supported_unit(get_unit_from_dtype(dtype))
):
Expand All @@ -679,7 +678,7 @@ def astype(self, dtype, copy: bool = True):
return type(self)._simple_new(res_values, dtype=res_values.dtype)
# TODO: preserve freq?

elif self.tz is not None and is_datetime64_dtype(dtype):
elif self.tz is not None and lib.is_np_dtype(dtype, "M"):
# pre-2.0 behavior for DTA/DTI was
# values.tz_convert("UTC").tz_localize(None), which did not match
# the Series behavior
Expand All @@ -691,7 +690,7 @@ def astype(self, dtype, copy: bool = True):

elif (
self.tz is None
and is_datetime64_dtype(dtype)
and lib.is_np_dtype(dtype, "M")
and dtype != self.dtype
and is_unitless(dtype)
):
Expand Down Expand Up @@ -2083,7 +2082,7 @@ def _sequence_to_dt64ns(
tz = _maybe_infer_tz(tz, data.tz)
result = data._ndarray

elif is_datetime64_dtype(data_dtype):
elif lib.is_np_dtype(data_dtype, "M"):
# tz-naive DatetimeArray or ndarray[datetime64]
data = getattr(data, "_ndarray", data)
new_dtype = data.dtype
Expand Down Expand Up @@ -2242,7 +2241,7 @@ def maybe_convert_dtype(data, copy: bool, tz: tzinfo | None = None):
data = data.astype(DT64NS_DTYPE).view("i8")
copy = False

elif is_timedelta64_dtype(data.dtype) or is_bool_dtype(data.dtype):
elif lib.is_np_dtype(data.dtype, "m") or is_bool_dtype(data.dtype):
# GH#29794 enforcing deprecation introduced in GH#23539
raise TypeError(f"dtype {data.dtype} cannot be converted to datetime64[ns]")
elif isinstance(data.dtype, PeriodDtype):
Expand Down Expand Up @@ -2391,7 +2390,7 @@ def _validate_tz_from_dtype(
raise ValueError("Cannot pass both a timezone-aware dtype and tz=None")
tz = dtz

if tz is not None and is_datetime64_dtype(dtype):
if tz is not None and lib.is_np_dtype(dtype, "M"):
# We also need to check for the case where the user passed a
# tz-naive dtype (i.e. datetime64[ns])
if tz is not None and not timezones.tz_compare(tz, dtz):
Expand Down
12 changes: 6 additions & 6 deletions pandas/core/arrays/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def __mul__(self, other) -> Self:
if not hasattr(other, "dtype"):
# list, tuple
other = np.array(other)
if len(other) != len(self) and not is_timedelta64_dtype(other.dtype):
if len(other) != len(self) and not lib.is_np_dtype(other.dtype, "m"):
# Exclude timedelta64 here so we correctly raise TypeError
# for that instead of ValueError
raise ValueError("Cannot multiply with unequal lengths")
Expand Down Expand Up @@ -585,7 +585,7 @@ def __truediv__(self, other):

other = self._cast_divlike_op(other)
if (
is_timedelta64_dtype(other.dtype)
lib.is_np_dtype(other.dtype, "m")
or is_integer_dtype(other.dtype)
or is_float_dtype(other.dtype)
):
Expand Down Expand Up @@ -613,7 +613,7 @@ def __rtruediv__(self, other):
return self._scalar_divlike_op(other, op)

other = self._cast_divlike_op(other)
if is_timedelta64_dtype(other.dtype):
if lib.is_np_dtype(other.dtype, "m"):
return self._vector_divlike_op(other, op)

elif is_object_dtype(other.dtype):
Expand All @@ -634,7 +634,7 @@ def __floordiv__(self, other):

other = self._cast_divlike_op(other)
if (
is_timedelta64_dtype(other.dtype)
lib.is_np_dtype(other.dtype, "m")
or is_integer_dtype(other.dtype)
or is_float_dtype(other.dtype)
):
Expand Down Expand Up @@ -662,7 +662,7 @@ def __rfloordiv__(self, other):
return self._scalar_divlike_op(other, op)

other = self._cast_divlike_op(other)
if is_timedelta64_dtype(other.dtype):
if lib.is_np_dtype(other.dtype, "m"):
return self._vector_divlike_op(other, op)

elif is_object_dtype(other.dtype):
Expand Down Expand Up @@ -940,7 +940,7 @@ def sequence_to_td64ns(
data[mask] = iNaT
copy = False

elif is_timedelta64_dtype(data.dtype):
elif lib.is_np_dtype(data.dtype, "m"):
data_unit = get_unit_from_dtype(data.dtype)
if not is_supported_unit(data_unit):
# cast to closest supported unit, i.e. s or ns
Expand Down
6 changes: 2 additions & 4 deletions pandas/core/dtypes/astype.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@
from pandas.errors import IntCastingNaNError

from pandas.core.dtypes.common import (
is_datetime64_dtype,
is_dtype_equal,
is_integer_dtype,
is_object_dtype,
is_string_dtype,
is_timedelta64_dtype,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import (
Expand Down Expand Up @@ -108,14 +106,14 @@ def _astype_nansafe(
# if we have a datetime/timedelta array of objects
# then coerce to datetime64[ns] and use DatetimeArray.astype

if is_datetime64_dtype(dtype):
if lib.is_np_dtype(dtype, "M"):
from pandas import to_datetime

dti = to_datetime(arr.ravel())
dta = dti._data.reshape(arr.shape)
return dta.astype(dtype, copy=False)._ndarray

elif is_timedelta64_dtype(dtype):
elif lib.is_np_dtype(dtype, "m"):
from pandas.core.construction import ensure_wrapped_if_datetimelike

# bc we know arr.dtype == object, this is equivalent to
Expand Down
Loading