Skip to content

Commit 1307353

Browse files
authored
PERF: numpy dtype checks (#52582)
1 parent a6d5db7 commit 1307353

20 files changed

+92
-89
lines changed

pandas/_libs/lib.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ NoDefault = Literal[_NoDefault.no_default]
3636
i8max: int
3737
u8max: int
3838

39+
def is_np_dtype(dtype: object, kinds: str | None = ...) -> bool: ...
3940
def item_from_zerodim(val: object) -> object: ...
4041
def infer_dtype(value: object, skipna: bool = ...) -> str: ...
4142
def is_iterator(obj: object) -> bool: ...

pandas/_libs/lib.pyx

+27
Original file line numberDiff line numberDiff line change
@@ -3070,3 +3070,30 @@ def dtypes_all_equal(list types not None) -> bool:
30703070
return False
30713071
else:
30723072
return True
3073+
3074+
3075+
def is_np_dtype(object dtype, str kinds=None) -> bool:
3076+
"""
3077+
Optimized check for `isinstance(dtype, np.dtype)` with
3078+
optional `and dtype.kind in kinds`.
3079+
3080+
dtype = np.dtype("m8[ns]")
3081+
3082+
In [7]: %timeit isinstance(dtype, np.dtype)
3083+
117 ns ± 1.91 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
3084+
3085+
In [8]: %timeit is_np_dtype(dtype)
3086+
64 ns ± 1.51 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
3087+
3088+
In [9]: %timeit is_timedelta64_dtype(dtype)
3089+
209 ns ± 6.96 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
3090+
3091+
In [10]: %timeit is_np_dtype(dtype, "m")
3092+
93.4 ns ± 1.11 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
3093+
"""
3094+
if not cnp.PyArray_DescrCheck(dtype):
3095+
# i.e. not isinstance(dtype, np.dtype)
3096+
return False
3097+
if kinds is None:
3098+
return True
3099+
return dtype.kind in kinds

pandas/core/arrays/base.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,9 @@
4040

4141
from pandas.core.dtypes.cast import maybe_cast_to_extension_array
4242
from pandas.core.dtypes.common import (
43-
is_datetime64_dtype,
4443
is_dtype_equal,
4544
is_list_like,
4645
is_scalar,
47-
is_timedelta64_dtype,
4846
pandas_dtype,
4947
)
5048
from pandas.core.dtypes.dtypes import ExtensionDtype
@@ -582,12 +580,12 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
582580
cls = dtype.construct_array_type()
583581
return cls._from_sequence(self, dtype=dtype, copy=copy)
584582

585-
elif is_datetime64_dtype(dtype):
583+
elif lib.is_np_dtype(dtype, "M"):
586584
from pandas.core.arrays import DatetimeArray
587585

588586
return DatetimeArray._from_sequence(self, dtype=dtype, copy=copy)
589587

590-
elif is_timedelta64_dtype(dtype):
588+
elif lib.is_np_dtype(dtype, "m"):
591589
from pandas.core.arrays import TimedeltaArray
592590

593591
return TimedeltaArray._from_sequence(self, dtype=dtype, copy=copy)

pandas/core/arrays/categorical.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,13 @@
3838
ensure_platform_int,
3939
is_any_real_numeric_dtype,
4040
is_bool_dtype,
41-
is_datetime64_dtype,
4241
is_dict_like,
4342
is_dtype_equal,
4443
is_extension_array_dtype,
4544
is_hashable,
4645
is_integer_dtype,
4746
is_list_like,
4847
is_scalar,
49-
is_timedelta64_dtype,
5048
needs_i8_conversion,
5149
pandas_dtype,
5250
)
@@ -622,9 +620,9 @@ def _from_inferred_categories(
622620
# Convert to a specialized type with `dtype` if specified.
623621
if is_any_real_numeric_dtype(dtype.categories):
624622
cats = to_numeric(inferred_categories, errors="coerce")
625-
elif is_datetime64_dtype(dtype.categories):
623+
elif lib.is_np_dtype(dtype.categories.dtype, "M"):
626624
cats = to_datetime(inferred_categories, errors="coerce")
627-
elif is_timedelta64_dtype(dtype.categories):
625+
elif lib.is_np_dtype(dtype.categories.dtype, "m"):
628626
cats = to_timedelta(inferred_categories, errors="coerce")
629627
elif is_bool_dtype(dtype.categories):
630628
if true_values is None:

pandas/core/arrays/datetimelike.py

+13-15
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,13 @@
8484
from pandas.core.dtypes.common import (
8585
is_all_strings,
8686
is_datetime64_any_dtype,
87-
is_datetime64_dtype,
8887
is_datetime_or_timedelta_dtype,
8988
is_dtype_equal,
9089
is_float_dtype,
9190
is_integer_dtype,
9291
is_list_like,
9392
is_object_dtype,
9493
is_string_dtype,
95-
is_timedelta64_dtype,
9694
pandas_dtype,
9795
)
9896
from pandas.core.dtypes.dtypes import (
@@ -993,7 +991,7 @@ def _get_arithmetic_result_freq(self, other) -> BaseOffset | None:
993991

994992
@final
995993
def _add_datetimelike_scalar(self, other) -> DatetimeArray:
996-
if not is_timedelta64_dtype(self.dtype):
994+
if not lib.is_np_dtype(self.dtype, "m"):
997995
raise TypeError(
998996
f"cannot add {type(self).__name__} and {type(other).__name__}"
999997
)
@@ -1029,7 +1027,7 @@ def _add_datetimelike_scalar(self, other) -> DatetimeArray:
10291027

10301028
@final
10311029
def _add_datetime_arraylike(self, other: DatetimeArray) -> DatetimeArray:
1032-
if not is_timedelta64_dtype(self.dtype):
1030+
if not lib.is_np_dtype(self.dtype, "m"):
10331031
raise TypeError(
10341032
f"cannot add {type(self).__name__} and {type(other).__name__}"
10351033
)
@@ -1093,7 +1091,7 @@ def _sub_datetimelike(self, other: Timestamp | DatetimeArray) -> TimedeltaArray:
10931091

10941092
@final
10951093
def _add_period(self, other: Period) -> PeriodArray:
1096-
if not is_timedelta64_dtype(self.dtype):
1094+
if not lib.is_np_dtype(self.dtype, "m"):
10971095
raise TypeError(f"cannot add Period to a {type(self).__name__}")
10981096

10991097
# We will wrap in a PeriodArray and defer to the reversed operation
@@ -1294,7 +1292,7 @@ def __add__(self, other):
12941292
result = self._add_offset(other)
12951293
elif isinstance(other, (datetime, np.datetime64)):
12961294
result = self._add_datetimelike_scalar(other)
1297-
elif isinstance(other, Period) and is_timedelta64_dtype(self.dtype):
1295+
elif isinstance(other, Period) and lib.is_np_dtype(self.dtype, "m"):
12981296
result = self._add_period(other)
12991297
elif lib.is_integer(other):
13001298
# This check must come after the check for np.timedelta64
@@ -1305,13 +1303,13 @@ def __add__(self, other):
13051303
result = obj._addsub_int_array_or_scalar(other * obj.dtype._n, operator.add)
13061304

13071305
# array-like others
1308-
elif is_timedelta64_dtype(other_dtype):
1306+
elif lib.is_np_dtype(other_dtype, "m"):
13091307
# TimedeltaIndex, ndarray[timedelta64]
13101308
result = self._add_timedelta_arraylike(other)
13111309
elif is_object_dtype(other_dtype):
13121310
# e.g. Array/Index of DateOffset objects
13131311
result = self._addsub_object_array(other, operator.add)
1314-
elif is_datetime64_dtype(other_dtype) or isinstance(
1312+
elif lib.is_np_dtype(other_dtype, "M") or isinstance(
13151313
other_dtype, DatetimeTZDtype
13161314
):
13171315
# DatetimeIndex, ndarray[datetime64]
@@ -1329,7 +1327,7 @@ def __add__(self, other):
13291327
# In remaining cases, this will end up raising TypeError.
13301328
return NotImplemented
13311329

1332-
if isinstance(result, np.ndarray) and is_timedelta64_dtype(result.dtype):
1330+
if isinstance(result, np.ndarray) and lib.is_np_dtype(result.dtype, "m"):
13331331
from pandas.core.arrays import TimedeltaArray
13341332

13351333
return TimedeltaArray(result)
@@ -1366,13 +1364,13 @@ def __sub__(self, other):
13661364
result = self._sub_periodlike(other)
13671365

13681366
# array-like others
1369-
elif is_timedelta64_dtype(other_dtype):
1367+
elif lib.is_np_dtype(other_dtype, "m"):
13701368
# TimedeltaIndex, ndarray[timedelta64]
13711369
result = self._add_timedelta_arraylike(-other)
13721370
elif is_object_dtype(other_dtype):
13731371
# e.g. Array/Index of DateOffset objects
13741372
result = self._addsub_object_array(other, operator.sub)
1375-
elif is_datetime64_dtype(other_dtype) or isinstance(
1373+
elif lib.is_np_dtype(other_dtype, "M") or isinstance(
13761374
other_dtype, DatetimeTZDtype
13771375
):
13781376
# DatetimeIndex, ndarray[datetime64]
@@ -1389,7 +1387,7 @@ def __sub__(self, other):
13891387
# Includes ExtensionArrays, float_dtype
13901388
return NotImplemented
13911389

1392-
if isinstance(result, np.ndarray) and is_timedelta64_dtype(result.dtype):
1390+
if isinstance(result, np.ndarray) and lib.is_np_dtype(result.dtype, "m"):
13931391
from pandas.core.arrays import TimedeltaArray
13941392

13951393
return TimedeltaArray(result)
@@ -1398,7 +1396,7 @@ def __sub__(self, other):
13981396
def __rsub__(self, other):
13991397
other_dtype = getattr(other, "dtype", None)
14001398

1401-
if is_datetime64_any_dtype(other_dtype) and is_timedelta64_dtype(self.dtype):
1399+
if is_datetime64_any_dtype(other_dtype) and lib.is_np_dtype(self.dtype, "m"):
14021400
# ndarray[datetime64] cannot be subtracted from self, so
14031401
# we need to wrap in DatetimeArray/Index and flip the operation
14041402
if lib.is_scalar(other):
@@ -1420,10 +1418,10 @@ def __rsub__(self, other):
14201418
raise TypeError(
14211419
f"cannot subtract {type(self).__name__} from {type(other).__name__}"
14221420
)
1423-
elif isinstance(self.dtype, PeriodDtype) and is_timedelta64_dtype(other_dtype):
1421+
elif isinstance(self.dtype, PeriodDtype) and lib.is_np_dtype(other_dtype, "m"):
14241422
# TODO: Can we simplify/generalize these cases at all?
14251423
raise TypeError(f"cannot subtract {type(self).__name__} from {other.dtype}")
1426-
elif is_timedelta64_dtype(self.dtype):
1424+
elif lib.is_np_dtype(self.dtype, "m"):
14271425
self = cast("TimedeltaArray", self)
14281426
return (-self) + other
14291427

pandas/core/arrays/datetimes.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
is_object_dtype,
5858
is_sparse,
5959
is_string_dtype,
60-
is_timedelta64_dtype,
6160
pandas_dtype,
6261
)
6362
from pandas.core.dtypes.dtypes import (
@@ -670,7 +669,7 @@ def astype(self, dtype, copy: bool = True):
670669

671670
elif (
672671
self.tz is None
673-
and is_datetime64_dtype(dtype)
672+
and lib.is_np_dtype(dtype, "M")
674673
and not is_unitless(dtype)
675674
and is_supported_unit(get_unit_from_dtype(dtype))
676675
):
@@ -679,7 +678,7 @@ def astype(self, dtype, copy: bool = True):
679678
return type(self)._simple_new(res_values, dtype=res_values.dtype)
680679
# TODO: preserve freq?
681680

682-
elif self.tz is not None and is_datetime64_dtype(dtype):
681+
elif self.tz is not None and lib.is_np_dtype(dtype, "M"):
683682
# pre-2.0 behavior for DTA/DTI was
684683
# values.tz_convert("UTC").tz_localize(None), which did not match
685684
# the Series behavior
@@ -691,7 +690,7 @@ def astype(self, dtype, copy: bool = True):
691690

692691
elif (
693692
self.tz is None
694-
and is_datetime64_dtype(dtype)
693+
and lib.is_np_dtype(dtype, "M")
695694
and dtype != self.dtype
696695
and is_unitless(dtype)
697696
):
@@ -2083,7 +2082,7 @@ def _sequence_to_dt64ns(
20832082
tz = _maybe_infer_tz(tz, data.tz)
20842083
result = data._ndarray
20852084

2086-
elif is_datetime64_dtype(data_dtype):
2085+
elif lib.is_np_dtype(data_dtype, "M"):
20872086
# tz-naive DatetimeArray or ndarray[datetime64]
20882087
data = getattr(data, "_ndarray", data)
20892088
new_dtype = data.dtype
@@ -2242,7 +2241,7 @@ def maybe_convert_dtype(data, copy: bool, tz: tzinfo | None = None):
22422241
data = data.astype(DT64NS_DTYPE).view("i8")
22432242
copy = False
22442243

2245-
elif is_timedelta64_dtype(data.dtype) or is_bool_dtype(data.dtype):
2244+
elif lib.is_np_dtype(data.dtype, "m") or is_bool_dtype(data.dtype):
22462245
# GH#29794 enforcing deprecation introduced in GH#23539
22472246
raise TypeError(f"dtype {data.dtype} cannot be converted to datetime64[ns]")
22482247
elif isinstance(data.dtype, PeriodDtype):
@@ -2391,7 +2390,7 @@ def _validate_tz_from_dtype(
23912390
raise ValueError("Cannot pass both a timezone-aware dtype and tz=None")
23922391
tz = dtz
23932392

2394-
if tz is not None and is_datetime64_dtype(dtype):
2393+
if tz is not None and lib.is_np_dtype(dtype, "M"):
23952394
# We also need to check for the case where the user passed a
23962395
# tz-naive dtype (i.e. datetime64[ns])
23972396
if tz is not None and not timezones.tz_compare(tz, dtz):

pandas/core/arrays/timedeltas.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def __mul__(self, other) -> Self:
484484
if not hasattr(other, "dtype"):
485485
# list, tuple
486486
other = np.array(other)
487-
if len(other) != len(self) and not is_timedelta64_dtype(other.dtype):
487+
if len(other) != len(self) and not lib.is_np_dtype(other.dtype, "m"):
488488
# Exclude timedelta64 here so we correctly raise TypeError
489489
# for that instead of ValueError
490490
raise ValueError("Cannot multiply with unequal lengths")
@@ -585,7 +585,7 @@ def __truediv__(self, other):
585585

586586
other = self._cast_divlike_op(other)
587587
if (
588-
is_timedelta64_dtype(other.dtype)
588+
lib.is_np_dtype(other.dtype, "m")
589589
or is_integer_dtype(other.dtype)
590590
or is_float_dtype(other.dtype)
591591
):
@@ -613,7 +613,7 @@ def __rtruediv__(self, other):
613613
return self._scalar_divlike_op(other, op)
614614

615615
other = self._cast_divlike_op(other)
616-
if is_timedelta64_dtype(other.dtype):
616+
if lib.is_np_dtype(other.dtype, "m"):
617617
return self._vector_divlike_op(other, op)
618618

619619
elif is_object_dtype(other.dtype):
@@ -634,7 +634,7 @@ def __floordiv__(self, other):
634634

635635
other = self._cast_divlike_op(other)
636636
if (
637-
is_timedelta64_dtype(other.dtype)
637+
lib.is_np_dtype(other.dtype, "m")
638638
or is_integer_dtype(other.dtype)
639639
or is_float_dtype(other.dtype)
640640
):
@@ -662,7 +662,7 @@ def __rfloordiv__(self, other):
662662
return self._scalar_divlike_op(other, op)
663663

664664
other = self._cast_divlike_op(other)
665-
if is_timedelta64_dtype(other.dtype):
665+
if lib.is_np_dtype(other.dtype, "m"):
666666
return self._vector_divlike_op(other, op)
667667

668668
elif is_object_dtype(other.dtype):
@@ -940,7 +940,7 @@ def sequence_to_td64ns(
940940
data[mask] = iNaT
941941
copy = False
942942

943-
elif is_timedelta64_dtype(data.dtype):
943+
elif lib.is_np_dtype(data.dtype, "m"):
944944
data_unit = get_unit_from_dtype(data.dtype)
945945
if not is_supported_unit(data_unit):
946946
# cast to closest supported unit, i.e. s or ns

pandas/core/dtypes/astype.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,10 @@
1818
from pandas.errors import IntCastingNaNError
1919

2020
from pandas.core.dtypes.common import (
21-
is_datetime64_dtype,
2221
is_dtype_equal,
2322
is_integer_dtype,
2423
is_object_dtype,
2524
is_string_dtype,
26-
is_timedelta64_dtype,
2725
pandas_dtype,
2826
)
2927
from pandas.core.dtypes.dtypes import (
@@ -108,14 +106,14 @@ def _astype_nansafe(
108106
# if we have a datetime/timedelta array of objects
109107
# then coerce to datetime64[ns] and use DatetimeArray.astype
110108

111-
if is_datetime64_dtype(dtype):
109+
if lib.is_np_dtype(dtype, "M"):
112110
from pandas import to_datetime
113111

114112
dti = to_datetime(arr.ravel())
115113
dta = dti._data.reshape(arr.shape)
116114
return dta.astype(dtype, copy=False)._ndarray
117115

118-
elif is_timedelta64_dtype(dtype):
116+
elif lib.is_np_dtype(dtype, "m"):
119117
from pandas.core.construction import ensure_wrapped_if_datetimelike
120118

121119
# bc we know arr.dtype == object, this is equivalent to

0 commit comments

Comments
 (0)