PERF: dtype checks #52506

Merged: 1 commit, merged Apr 7, 2023
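
The diff below repeatedly replaces helper calls from pandas.core.dtypes.common (needs_i8_conversion, is_float_dtype, is_integer_dtype, is_bool_dtype, is_numeric_dtype, ...) with direct np.dtype.kind checks, presumably to avoid per-call overhead in these hot paths. As a minimal sketch of the substitution pattern, assuming plain numpy dtypes (illustration only, not part of the PR):

    import numpy as np

    dt = np.dtype("datetime64[ns]")

    # Direct kind checks used throughout the diff (valid only for numpy dtypes):
    assert dt.kind in "mM"                      # datetime64 ("M") or timedelta64 ("m")
    assert np.dtype("float64").kind == "f"      # float
    assert np.dtype("int64").kind in "iu"       # signed or unsigned integer
    assert np.dtype("bool").kind == "b"         # boolean
    assert np.dtype("complex128").kind in "fc"  # float or complex
    assert np.dtype("float64").kind in "iufcb"  # the numeric family used in groupby/ops.py
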
29 changes: 13 additions & 16 deletions pandas/core/dtypes/missing.py
@@ -26,15 +26,12 @@
TD64NS_DTYPE,
ensure_object,
is_bool_dtype,
is_complex_dtype,
is_dtype_equal,
is_extension_array_dtype,
is_float_dtype,
is_integer_dtype,
is_object_dtype,
is_scalar,
is_string_or_object_np_dtype,
needs_i8_conversion,
)
from pandas.core.dtypes.dtypes import (
CategoricalDtype,
@@ -291,7 +288,7 @@ def _isna_array(values: ArrayLike, inf_as_na: bool = False):
result = values.isna() # type: ignore[assignment]
elif is_string_or_object_np_dtype(values.dtype):
result = _isna_string_dtype(values, inf_as_na=inf_as_na)
elif needs_i8_conversion(dtype):
elif dtype.kind in "mM":
# this is the NaT pattern
result = values.view("i8") == iNaT
else:
@@ -502,7 +499,7 @@ def array_equivalent(
# fastpath when we require that the dtypes match (Block.equals)
if left.dtype.kind in "fc":
return _array_equivalent_float(left, right)
elif needs_i8_conversion(left.dtype):
elif left.dtype.kind in "mM":
return _array_equivalent_datetimelike(left, right)
elif is_string_or_object_np_dtype(left.dtype):
# TODO: fastpath for pandas' StringDtype
@@ -519,14 +516,14 @@
return _array_equivalent_object(left, right, strict_nan)

# NaNs can occur in float and complex arrays.
if is_float_dtype(left.dtype) or is_complex_dtype(left.dtype):
if left.dtype.kind in "fc":
if not (left.size and right.size):
return True
return ((left == right) | (isna(left) & isna(right))).all()

elif needs_i8_conversion(left.dtype) or needs_i8_conversion(right.dtype):
elif left.dtype.kind in "mM" or right.dtype.kind in "mM":
# datetime64, timedelta64, Period
if not is_dtype_equal(left.dtype, right.dtype):
if left.dtype != right.dtype:

Member:

I thought we couldn't do this for np dtypes until we bumped our min version?

Member Author:

We can't do it when comparing numpy dtypes to EA dtypes, but this is safe in places where we have both numpy dtypes.
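
For illustration only (not from the PR), a minimal sketch of the distinction discussed above, assuming both operands are plain numpy dtypes as in this branch of array_equivalent; the mixed numpy/extension-dtype case is the one that still needs the general helper:

    import numpy as np
    from pandas.api.types import is_dtype_equal

    left = np.dtype("datetime64[ns]")
    right = np.dtype("timedelta64[ns]")

    # Comparing two numpy dtypes directly is well defined and cheap.
    assert left != right
    assert left == np.dtype("datetime64[ns]")

    # Comparing a numpy dtype with an extension dtype (e.g. tz-aware datetime64)
    # is the case where is_dtype_equal remains the safer spelling.
    assert not is_dtype_equal(left, "datetime64[ns, UTC]")
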

return False

left = left.view("i8")
@@ -541,11 +538,11 @@
return np.array_equal(left, right)


def _array_equivalent_float(left, right) -> bool:
def _array_equivalent_float(left: np.ndarray, right: np.ndarray) -> bool:
return bool(((left == right) | (np.isnan(left) & np.isnan(right))).all())


def _array_equivalent_datetimelike(left, right):
def _array_equivalent_datetimelike(left: np.ndarray, right: np.ndarray):
return np.array_equal(left.view("i8"), right.view("i8"))


@@ -601,7 +598,7 @@ def infer_fill_value(val):
if not is_list_like(val):
val = [val]
val = np.array(val, copy=False)
if needs_i8_conversion(val.dtype):
if val.dtype.kind in "mM":
return np.array("NaT", dtype=val.dtype)
elif is_object_dtype(val.dtype):
dtype = lib.infer_dtype(ensure_object(val), skipna=False)
@@ -616,7 +613,7 @@ def maybe_fill(arr: np.ndarray) -> np.ndarray:
"""
Fill numpy.ndarray with NaN, unless we have a integer or boolean dtype.
"""
if arr.dtype.kind not in ("u", "i", "b"):
if arr.dtype.kind not in "iub":
arr.fill(np.nan)
return arr

@@ -650,15 +647,15 @@ def na_value_for_dtype(dtype: DtypeObj, compat: bool = True):

if isinstance(dtype, ExtensionDtype):
return dtype.na_value
elif needs_i8_conversion(dtype):
elif dtype.kind in "mM":
return dtype.type("NaT", "ns")
elif is_float_dtype(dtype):
elif dtype.kind == "f":
return np.nan
elif is_integer_dtype(dtype):
elif dtype.kind in "iu":
if compat:
return 0
return np.nan
elif is_bool_dtype(dtype):
elif dtype.kind == "b":
if compat:
return False
return np.nan
7 changes: 4 additions & 3 deletions pandas/core/frame.py
@@ -94,7 +94,6 @@
is_dataclass,
is_dict_like,
is_dtype_equal,
is_extension_array_dtype,
is_float,
is_float_dtype,
is_hashable,
@@ -3603,7 +3602,9 @@ def transpose(self, *args, copy: bool = False) -> DataFrame:
result._mgr.add_references(self._mgr) # type: ignore[arg-type]

elif (
self._is_homogeneous_type and dtypes and is_extension_array_dtype(dtypes[0])
self._is_homogeneous_type
and dtypes
and isinstance(dtypes[0], ExtensionDtype)
):
# We have EAs with the same dtype. We can preserve that dtype in transpose.
dtype = dtypes[0]
@@ -4184,7 +4185,7 @@ def _set_item(self, key, value) -> None:
if (
key in self.columns
and value.ndim == 1
and not is_extension_array_dtype(value)
and not isinstance(value.dtype, ExtensionDtype)
):
# broadcast across multiple columns if necessary
if not self.columns.is_unique or isinstance(self.columns, MultiIndex):
9 changes: 6 additions & 3 deletions pandas/core/generic.py
@@ -123,7 +123,10 @@
is_timedelta64_dtype,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
from pandas.core.dtypes.dtypes import (
DatetimeTZDtype,
ExtensionDtype,
)
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCSeries,
@@ -4670,7 +4673,7 @@ def _drop_axis(
if errors == "raise" and labels_missing:
raise KeyError(f"{labels} not found in axis")

if is_extension_array_dtype(mask.dtype):
if isinstance(mask.dtype, ExtensionDtype):
# GH#45860
mask = mask.to_numpy(dtype=bool)

@@ -5457,7 +5460,7 @@ def _needs_reindex_multi(self, axes, method, level: Level | None) -> bool_t:
and not (
self.ndim == 2
and len(self.dtypes) == 1
and is_extension_array_dtype(self.dtypes.iloc[0])
and isinstance(self.dtypes.iloc[0], ExtensionDtype)
)
)

20 changes: 7 additions & 13 deletions pandas/core/groupby/ops.py
@@ -47,12 +47,6 @@
ensure_platform_int,
ensure_uint64,
is_1d_only_ea_dtype,
is_bool_dtype,
is_complex_dtype,
is_float_dtype,
is_integer_dtype,
is_numeric_dtype,
needs_i8_conversion,
)
from pandas.core.dtypes.missing import (
isna,
@@ -248,7 +242,7 @@ def _get_out_dtype(self, dtype: np.dtype) -> np.dtype:
if how == "rank":
out_dtype = "float64"
else:
if is_numeric_dtype(dtype):
if dtype.kind in "iufcb":
out_dtype = f"{dtype.kind}{dtype.itemsize}"
else:
out_dtype = "object"
@@ -274,9 +268,9 @@ def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
if dtype == np.dtype(bool):
return np.dtype(np.int64)
elif how in ["mean", "median", "var", "std", "sem"]:
if is_float_dtype(dtype) or is_complex_dtype(dtype):
if dtype.kind in "fc":
return dtype
elif is_numeric_dtype(dtype):
elif dtype.kind in "iub":
return np.dtype(np.float64)
return dtype

@@ -339,14 +333,14 @@ def _call_cython_op(
orig_values = values

dtype = values.dtype
is_numeric = is_numeric_dtype(dtype)
is_numeric = dtype.kind in "iufcb"

is_datetimelike = needs_i8_conversion(dtype)
is_datetimelike = dtype.kind in "mM"

if is_datetimelike:
values = values.view("int64")
is_numeric = True
elif is_bool_dtype(dtype):
elif dtype.kind == "b":
values = values.view("uint8")
if values.dtype == "float16":
values = values.astype(np.float32)
@@ -446,7 +440,7 @@ def _call_cython_op(
# i.e. counts is defined. Locations where count<min_count
# need to have the result set to np.nan, which may require casting,
# see GH#40767
if is_integer_dtype(result.dtype) and not is_datetimelike:
if result.dtype.kind in "iu" and not is_datetimelike:
# if the op keeps the int dtypes, we have to use 0
cutoff = max(0 if self.how in ["sum", "prod"] else 1, min_count)
empty_groups = counts < cutoff
8 changes: 4 additions & 4 deletions pandas/core/indexing.py
@@ -35,7 +35,6 @@
from pandas.core.dtypes.common import (
is_array_like,
is_bool_dtype,
is_extension_array_dtype,
is_hashable,
is_integer,
is_iterator,
@@ -46,6 +45,7 @@
is_sequence,
)
from pandas.core.dtypes.concat import concat_compat
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCSeries,
@@ -1128,10 +1128,10 @@ def _validate_key(self, key, axis: Axis):
# boolean not in slice and with boolean index
ax = self.obj._get_axis(axis)
if isinstance(key, bool) and not (
is_bool_dtype(ax)
is_bool_dtype(ax.dtype)
or ax.dtype.name == "boolean"
or isinstance(ax, MultiIndex)
and is_bool_dtype(ax.get_level_values(0))
and is_bool_dtype(ax.get_level_values(0).dtype)
):
raise KeyError(
f"{key}: boolean label can not be used without a boolean index"
@@ -2490,7 +2490,7 @@ def check_bool_indexer(index: Index, key) -> np.ndarray:
result = result.take(indexer)

# fall through for boolean
if not is_extension_array_dtype(result.dtype):
if not isinstance(result.dtype, ExtensionDtype):
return result.astype(bool)._values

if is_object_dtype(key):
2 changes: 1 addition & 1 deletion pandas/core/missing.py
@@ -455,7 +455,7 @@ def _interpolate_1d(
# sort preserve_nans and convert to list
preserve_nans = sorted(preserve_nans)

is_datetimelike = needs_i8_conversion(yvalues.dtype)
is_datetimelike = yvalues.dtype.kind in "mM"

if is_datetimelike:
yvalues = yvalues.view("i8")
17 changes: 8 additions & 9 deletions pandas/core/nanops.py
@@ -34,7 +34,6 @@
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.common import (
is_any_int_dtype,
is_complex,
is_float,
is_float_dtype,
@@ -247,7 +246,7 @@ def _maybe_get_mask(
# Boolean data cannot contain nulls, so signal via mask being None
return None

if skipna or needs_i8_conversion(values.dtype):
if skipna or values.dtype.kind in "mM":
mask = isna(values)

return mask
@@ -300,7 +299,7 @@ def _get_values(
dtype = values.dtype

datetimelike = False
if needs_i8_conversion(values.dtype):
if values.dtype.kind in "mM":
# changing timedelta64/datetime64 to int64 needs to happen after
# finding `mask` above
values = np.asarray(values.view("i8"))
@@ -433,7 +432,7 @@ def _na_for_min_count(values: np.ndarray, axis: AxisInt | None) -> Scalar | np.ndarray:
For 2-D values, returns a 1-D array where each element is missing.
"""
# we either return np.nan or pd.NaT
if is_numeric_dtype(values.dtype):
if values.dtype.kind in "iufcb":
values = values.astype("float64")
fill_value = na_value_for_dtype(values.dtype)

@@ -521,7 +520,7 @@ def nanany(
# expected "bool")
return values.any(axis) # type: ignore[return-value]

if needs_i8_conversion(values.dtype) and values.dtype.kind != "m":
if values.dtype.kind == "M":
# GH#34479
warnings.warn(
"'any' with datetime64 dtypes is deprecated and will raise in a "
@@ -582,7 +581,7 @@ def nanall(
# expected "bool")
return values.all(axis) # type: ignore[return-value]

if needs_i8_conversion(values.dtype) and values.dtype.kind != "m":
if values.dtype.kind == "M":
# GH#34479
warnings.warn(
"'all' with datetime64 dtypes is deprecated and will raise in a "
@@ -976,12 +975,12 @@ def nanvar(
"""
dtype = values.dtype
mask = _maybe_get_mask(values, skipna, mask)
if is_any_int_dtype(dtype):
if dtype.kind in "iu":
values = values.astype("f8")
if mask is not None:
values[mask] = np.nan

if is_float_dtype(values.dtype):
if values.dtype.kind == "f":
count, d = _get_counts_nanvar(values.shape, mask, axis, ddof, values.dtype)
else:
count, d = _get_counts_nanvar(values.shape, mask, axis, ddof)
@@ -1007,7 +1006,7 @@ def nanvar(
# Return variance as np.float64 (the datatype used in the accumulator),
# unless we were dealing with a float array, in which case use the same
# precision as the original values array.
if is_float_dtype(dtype):
if dtype.kind == "f":
result = result.astype(dtype, copy=False)
return result

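Aside on the nanany/nanall change above (my illustration, not part of the diff): for numpy dtypes, needs_i8_conversion(dtype) is true exactly for kinds "m" and "M", so the old check needs_i8_conversion(values.dtype) and values.dtype.kind != "m" reduces to values.dtype.kind == "M", i.e. datetime64 only:

    import numpy as np

    # The deprecation warning should fire for datetime64 but not timedelta64 or float.
    for dtype, is_datetime64 in [
        (np.dtype("datetime64[ns]"), True),
        (np.dtype("timedelta64[ns]"), False),
        (np.dtype("float64"), False),
    ]:
        assert (dtype.kind == "M") is is_datetime64
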
8 changes: 4 additions & 4 deletions pandas/core/series.py
@@ -64,7 +64,6 @@
)
from pandas.core.dtypes.common import (
is_dict_like,
is_extension_array_dtype,
is_integer,
is_iterator,
is_list_like,
@@ -73,6 +72,7 @@
pandas_dtype,
validate_all_hashable,
)
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes.generic import ABCDataFrame
from pandas.core.dtypes.inference import is_hashable
from pandas.core.dtypes.missing import (
@@ -1872,7 +1872,7 @@ def to_dict(self, into: type[dict] = dict) -> dict:
# GH16122
into_c = com.standardize_mapping(into)

if is_object_dtype(self) or is_extension_array_dtype(self):
if is_object_dtype(self.dtype) or isinstance(self.dtype, ExtensionDtype):
return into_c((k, maybe_box_native(v)) for k, v in self.items())
else:
# Not an object dtype => all types will be the same so let the default
@@ -4175,7 +4175,7 @@ def explode(self, ignore_index: bool = False) -> Series:
3 4
dtype: object
"""
if not len(self) or not is_object_dtype(self):
if not len(self) or not is_object_dtype(self.dtype):
result = self.copy()
return result.reset_index(drop=True) if ignore_index else result

@@ -5376,7 +5376,7 @@ def _convert_dtypes(
input_series = self
if infer_objects:
input_series = input_series.infer_objects()
if is_object_dtype(input_series):
if is_object_dtype(input_series.dtype):
input_series = input_series.copy(deep=None)

if convert_string or convert_integer or convert_boolean or convert_floating: