Skip to content

Commit 8e19396

Browse files
authored
PERF: dtype checks (#52506)
1 parent 92f837f commit 8e19396

File tree

9 files changed

+49
-54
lines changed

9 files changed

+49
-54
lines changed

pandas/core/dtypes/missing.py

+13 -16
Original file line number | Diff line number | Diff line change
@@ -26,15 +26,12 @@
2626
TD64NS_DTYPE,
2727
ensure_object,
2828
is_bool_dtype,
29-
is_complex_dtype,
3029
is_dtype_equal,
3130
is_extension_array_dtype,
32-
is_float_dtype,
3331
is_integer_dtype,
3432
is_object_dtype,
3533
is_scalar,
3634
is_string_or_object_np_dtype,
37-
needs_i8_conversion,
3835
)
3936
from pandas.core.dtypes.dtypes import (
4037
CategoricalDtype,
@@ -291,7 +288,7 @@ def _isna_array(values: ArrayLike, inf_as_na: bool = False):
291288
result = values.isna() # type: ignore[assignment]
292289
elif is_string_or_object_np_dtype(values.dtype):
293290
result = _isna_string_dtype(values, inf_as_na=inf_as_na)
294-
elif needs_i8_conversion(dtype):
291+
elif dtype.kind in "mM":
295292
# this is the NaT pattern
296293
result = values.view("i8") == iNaT
297294
else:
@@ -502,7 +499,7 @@ def array_equivalent(
502499
# fastpath when we require that the dtypes match (Block.equals)
503500
if left.dtype.kind in "fc":
504501
return _array_equivalent_float(left, right)
505-
elif needs_i8_conversion(left.dtype):
502+
elif left.dtype.kind in "mM":
506503
return _array_equivalent_datetimelike(left, right)
507504
elif is_string_or_object_np_dtype(left.dtype):
508505
# TODO: fastpath for pandas' StringDtype
@@ -519,14 +516,14 @@ def array_equivalent(
519516
return _array_equivalent_object(left, right, strict_nan)
520517

521518
# NaNs can occur in float and complex arrays.
522-
if is_float_dtype(left.dtype) or is_complex_dtype(left.dtype):
519+
if left.dtype.kind in "fc":
523520
if not (left.size and right.size):
524521
return True
525522
return ((left == right) | (isna(left) & isna(right))).all()
526523

527-
elif needs_i8_conversion(left.dtype) or needs_i8_conversion(right.dtype):
524+
elif left.dtype.kind in "mM" or right.dtype.kind in "mM":
528525
# datetime64, timedelta64, Period
529-
if not is_dtype_equal(left.dtype, right.dtype):
526+
if left.dtype != right.dtype:
530527
return False
531528

532529
left = left.view("i8")
@@ -541,11 +538,11 @@ def array_equivalent(
541538
return np.array_equal(left, right)
542539

543540

544-
def _array_equivalent_float(left, right) -> bool:
541+
def _array_equivalent_float(left: np.ndarray, right: np.ndarray) -> bool:
545542
return bool(((left == right) | (np.isnan(left) & np.isnan(right))).all())
546543

547544

548-
def _array_equivalent_datetimelike(left, right):
545+
def _array_equivalent_datetimelike(left: np.ndarray, right: np.ndarray):
549546
return np.array_equal(left.view("i8"), right.view("i8"))
550547

551548

@@ -601,7 +598,7 @@ def infer_fill_value(val):
601598
if not is_list_like(val):
602599
val = [val]
603600
val = np.array(val, copy=False)
604-
if needs_i8_conversion(val.dtype):
601+
if val.dtype.kind in "mM":
605602
return np.array("NaT", dtype=val.dtype)
606603
elif is_object_dtype(val.dtype):
607604
dtype = lib.infer_dtype(ensure_object(val), skipna=False)
@@ -616,7 +613,7 @@ def maybe_fill(arr: np.ndarray) -> np.ndarray:
616613
"""
617614
Fill numpy.ndarray with NaN, unless we have a integer or boolean dtype.
618615
"""
619-
if arr.dtype.kind not in ("u", "i", "b"):
616+
if arr.dtype.kind not in "iub":
620617
arr.fill(np.nan)
621618
return arr
622619

@@ -650,15 +647,15 @@ def na_value_for_dtype(dtype: DtypeObj, compat: bool = True):
650647

651648
if isinstance(dtype, ExtensionDtype):
652649
return dtype.na_value
653-
elif needs_i8_conversion(dtype):
650+
elif dtype.kind in "mM":
654651
return dtype.type("NaT", "ns")
655-
elif is_float_dtype(dtype):
652+
elif dtype.kind == "f":
656653
return np.nan
657-
elif is_integer_dtype(dtype):
654+
elif dtype.kind in "iu":
658655
if compat:
659656
return 0
660657
return np.nan
661-
elif is_bool_dtype(dtype):
658+
elif dtype.kind == "b":
662659
if compat:
663660
return False
664661
return np.nan

pandas/core/frame.py

+4 -3
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,6 @@
9494
is_dataclass,
9595
is_dict_like,
9696
is_dtype_equal,
97-
is_extension_array_dtype,
9897
is_float,
9998
is_float_dtype,
10099
is_hashable,
@@ -3597,7 +3596,9 @@ def transpose(self, *args, copy: bool = False) -> DataFrame:
35973596
result._mgr.add_references(self._mgr) # type: ignore[arg-type]
35983597

35993598
elif (
3600-
self._is_homogeneous_type and dtypes and is_extension_array_dtype(dtypes[0])
3599+
self._is_homogeneous_type
3600+
and dtypes
3601+
and isinstance(dtypes[0], ExtensionDtype)
36013602
):
36023603
# We have EAs with the same dtype. We can preserve that dtype in transpose.
36033604
dtype = dtypes[0]
@@ -4178,7 +4179,7 @@ def _set_item(self, key, value) -> None:
41784179
if (
41794180
key in self.columns
41804181
and value.ndim == 1
4181-
and not is_extension_array_dtype(value)
4182+
and not isinstance(value.dtype, ExtensionDtype)
41824183
):
41834184
# broadcast across multiple columns if necessary
41844185
if not self.columns.is_unique or isinstance(self.columns, MultiIndex):

pandas/core/generic.py

+6 -3
Original file line number | Diff line number | Diff line change
@@ -123,7 +123,10 @@
123123
is_timedelta64_dtype,
124124
pandas_dtype,
125125
)
126-
from pandas.core.dtypes.dtypes import DatetimeTZDtype
126+
from pandas.core.dtypes.dtypes import (
127+
DatetimeTZDtype,
128+
ExtensionDtype,
129+
)
127130
from pandas.core.dtypes.generic import (
128131
ABCDataFrame,
129132
ABCSeries,
@@ -4670,7 +4673,7 @@ def _drop_axis(
46704673
if errors == "raise" and labels_missing:
46714674
raise KeyError(f"{labels} not found in axis")
46724675

4673-
if is_extension_array_dtype(mask.dtype):
4676+
if isinstance(mask.dtype, ExtensionDtype):
46744677
# GH#45860
46754678
mask = mask.to_numpy(dtype=bool)
46764679

@@ -5458,7 +5461,7 @@ def _needs_reindex_multi(self, axes, method, level: Level | None) -> bool_t:
54585461
and not (
54595462
self.ndim == 2
54605463
and len(self.dtypes) == 1
5461-
and is_extension_array_dtype(self.dtypes.iloc[0])
5464+
and isinstance(self.dtypes.iloc[0], ExtensionDtype)
54625465
)
54635466
)
54645467

pandas/core/groupby/ops.py

+7 -13
Original file line number | Diff line number | Diff line change
@@ -47,12 +47,6 @@
4747
ensure_platform_int,
4848
ensure_uint64,
4949
is_1d_only_ea_dtype,
50-
is_bool_dtype,
51-
is_complex_dtype,
52-
is_float_dtype,
53-
is_integer_dtype,
54-
is_numeric_dtype,
55-
needs_i8_conversion,
5650
)
5751
from pandas.core.dtypes.missing import (
5852
isna,
@@ -248,7 +242,7 @@ def _get_out_dtype(self, dtype: np.dtype) -> np.dtype:
248242
if how == "rank":
249243
out_dtype = "float64"
250244
else:
251-
if is_numeric_dtype(dtype):
245+
if dtype.kind in "iufcb":
252246
out_dtype = f"{dtype.kind}{dtype.itemsize}"
253247
else:
254248
out_dtype = "object"
@@ -274,9 +268,9 @@ def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
274268
if dtype == np.dtype(bool):
275269
return np.dtype(np.int64)
276270
elif how in ["mean", "median", "var", "std", "sem"]:
277-
if is_float_dtype(dtype) or is_complex_dtype(dtype):
271+
if dtype.kind in "fc":
278272
return dtype
279-
elif is_numeric_dtype(dtype):
273+
elif dtype.kind in "iub":
280274
return np.dtype(np.float64)
281275
return dtype
282276

@@ -339,14 +333,14 @@ def _call_cython_op(
339333
orig_values = values
340334

341335
dtype = values.dtype
342-
is_numeric = is_numeric_dtype(dtype)
336+
is_numeric = dtype.kind in "iufcb"
343337

344-
is_datetimelike = needs_i8_conversion(dtype)
338+
is_datetimelike = dtype.kind in "mM"
345339

346340
if is_datetimelike:
347341
values = values.view("int64")
348342
is_numeric = True
349-
elif is_bool_dtype(dtype):
343+
elif dtype.kind == "b":
350344
values = values.view("uint8")
351345
if values.dtype == "float16":
352346
values = values.astype(np.float32)
@@ -446,7 +440,7 @@ def _call_cython_op(
446440
# i.e. counts is defined. Locations where count<min_count
447441
# need to have the result set to np.nan, which may require casting,
448442
# see GH#40767
449-
if is_integer_dtype(result.dtype) and not is_datetimelike:
443+
if result.dtype.kind in "iu" and not is_datetimelike:
450444
# if the op keeps the int dtypes, we have to use 0
451445
cutoff = max(0 if self.how in ["sum", "prod"] else 1, min_count)
452446
empty_groups = counts < cutoff

pandas/core/indexing.py

+4 -4
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,6 @@
3535
from pandas.core.dtypes.common import (
3636
is_array_like,
3737
is_bool_dtype,
38-
is_extension_array_dtype,
3938
is_hashable,
4039
is_integer,
4140
is_iterator,
@@ -46,6 +45,7 @@
4645
is_sequence,
4746
)
4847
from pandas.core.dtypes.concat import concat_compat
48+
from pandas.core.dtypes.dtypes import ExtensionDtype
4949
from pandas.core.dtypes.generic import (
5050
ABCDataFrame,
5151
ABCSeries,
@@ -1128,10 +1128,10 @@ def _validate_key(self, key, axis: Axis):
11281128
# boolean not in slice and with boolean index
11291129
ax = self.obj._get_axis(axis)
11301130
if isinstance(key, bool) and not (
1131-
is_bool_dtype(ax)
1131+
is_bool_dtype(ax.dtype)
11321132
or ax.dtype.name == "boolean"
11331133
or isinstance(ax, MultiIndex)
1134-
and is_bool_dtype(ax.get_level_values(0))
1134+
and is_bool_dtype(ax.get_level_values(0).dtype)
11351135
):
11361136
raise KeyError(
11371137
f"{key}: boolean label can not be used without a boolean index"
@@ -2490,7 +2490,7 @@ def check_bool_indexer(index: Index, key) -> np.ndarray:
24902490
result = result.take(indexer)
24912491

24922492
# fall through for boolean
2493-
if not is_extension_array_dtype(result.dtype):
2493+
if not isinstance(result.dtype, ExtensionDtype):
24942494
return result.astype(bool)._values
24952495

24962496
if is_object_dtype(key):

pandas/core/missing.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -455,7 +455,7 @@ def _interpolate_1d(
455455
# sort preserve_nans and convert to list
456456
preserve_nans = sorted(preserve_nans)
457457

458-
is_datetimelike = needs_i8_conversion(yvalues.dtype)
458+
is_datetimelike = yvalues.dtype.kind in "mM"
459459

460460
if is_datetimelike:
461461
yvalues = yvalues.view("i8")

pandas/core/nanops.py

+8 -9
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,6 @@
3434
from pandas.util._exceptions import find_stack_level
3535

3636
from pandas.core.dtypes.common import (
37-
is_any_int_dtype,
3837
is_complex,
3938
is_float,
4039
is_float_dtype,
@@ -247,7 +246,7 @@ def _maybe_get_mask(
247246
# Boolean data cannot contain nulls, so signal via mask being None
248247
return None
249248

250-
if skipna or needs_i8_conversion(values.dtype):
249+
if skipna or values.dtype.kind in "mM":
251250
mask = isna(values)
252251

253252
return mask
@@ -300,7 +299,7 @@ def _get_values(
300299
dtype = values.dtype
301300

302301
datetimelike = False
303-
if needs_i8_conversion(values.dtype):
302+
if values.dtype.kind in "mM":
304303
# changing timedelta64/datetime64 to int64 needs to happen after
305304
# finding `mask` above
306305
values = np.asarray(values.view("i8"))
@@ -433,7 +432,7 @@ def _na_for_min_count(values: np.ndarray, axis: AxisInt | None) -> Scalar | np.n
433432
For 2-D values, returns a 1-D array where each element is missing.
434433
"""
435434
# we either return np.nan or pd.NaT
436-
if is_numeric_dtype(values.dtype):
435+
if values.dtype.kind in "iufcb":
437436
values = values.astype("float64")
438437
fill_value = na_value_for_dtype(values.dtype)
439438

@@ -521,7 +520,7 @@ def nanany(
521520
# expected "bool")
522521
return values.any(axis) # type: ignore[return-value]
523522

524-
if needs_i8_conversion(values.dtype) and values.dtype.kind != "m":
523+
if values.dtype.kind == "M":
525524
# GH#34479
526525
warnings.warn(
527526
"'any' with datetime64 dtypes is deprecated and will raise in a "
@@ -582,7 +581,7 @@ def nanall(
582581
# expected "bool")
583582
return values.all(axis) # type: ignore[return-value]
584583

585-
if needs_i8_conversion(values.dtype) and values.dtype.kind != "m":
584+
if values.dtype.kind == "M":
586585
# GH#34479
587586
warnings.warn(
588587
"'all' with datetime64 dtypes is deprecated and will raise in a "
@@ -976,12 +975,12 @@ def nanvar(
976975
"""
977976
dtype = values.dtype
978977
mask = _maybe_get_mask(values, skipna, mask)
979-
if is_any_int_dtype(dtype):
978+
if dtype.kind in "iu":
980979
values = values.astype("f8")
981980
if mask is not None:
982981
values[mask] = np.nan
983982

984-
if is_float_dtype(values.dtype):
983+
if values.dtype.kind == "f":
985984
count, d = _get_counts_nanvar(values.shape, mask, axis, ddof, values.dtype)
986985
else:
987986
count, d = _get_counts_nanvar(values.shape, mask, axis, ddof)
@@ -1007,7 +1006,7 @@ def nanvar(
10071006
# Return variance as np.float64 (the datatype used in the accumulator),
10081007
# unless we were dealing with a float array, in which case use the same
10091008
# precision as the original values array.
1010-
if is_float_dtype(dtype):
1009+
if dtype.kind == "f":
10111010
result = result.astype(dtype, copy=False)
10121011
return result
10131012

pandas/core/series.py

+4 -4
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,6 @@
6464
)
6565
from pandas.core.dtypes.common import (
6666
is_dict_like,
67-
is_extension_array_dtype,
6867
is_integer,
6968
is_iterator,
7069
is_list_like,
@@ -73,6 +72,7 @@
7372
pandas_dtype,
7473
validate_all_hashable,
7574
)
75+
from pandas.core.dtypes.dtypes import ExtensionDtype
7676
from pandas.core.dtypes.generic import ABCDataFrame
7777
from pandas.core.dtypes.inference import is_hashable
7878
from pandas.core.dtypes.missing import (
@@ -1861,7 +1861,7 @@ def to_dict(self, into: type[dict] = dict) -> dict:
18611861
# GH16122
18621862
into_c = com.standardize_mapping(into)
18631863

1864-
if is_object_dtype(self) or is_extension_array_dtype(self):
1864+
if is_object_dtype(self.dtype) or isinstance(self.dtype, ExtensionDtype):
18651865
return into_c((k, maybe_box_native(v)) for k, v in self.items())
18661866
else:
18671867
# Not an object dtype => all types will be the same so let the default
@@ -4164,7 +4164,7 @@ def explode(self, ignore_index: bool = False) -> Series:
41644164
3 4
41654165
dtype: object
41664166
"""
4167-
if not len(self) or not is_object_dtype(self):
4167+
if not len(self) or not is_object_dtype(self.dtype):
41684168
result = self.copy()
41694169
return result.reset_index(drop=True) if ignore_index else result
41704170

@@ -5220,7 +5220,7 @@ def _convert_dtypes(
52205220
input_series = self
52215221
if infer_objects:
52225222
input_series = input_series.infer_objects()
5223-
if is_object_dtype(input_series):
5223+
if is_object_dtype(input_series.dtype):
52245224
input_series = input_series.copy(deep=None)
52255225

52265226
if convert_string or convert_integer or convert_boolean or convert_floating:

0 commit comments

Comments
 (0)