-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
PERF: reducing dtype checking overhead in groupby #44738
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ | |
|
||
from pandas._libs import ( | ||
NaT, | ||
Period, | ||
lib, | ||
) | ||
import pandas._libs.groupby as libgroupby | ||
|
@@ -46,12 +47,10 @@ | |
ensure_int64, | ||
ensure_platform_int, | ||
is_1d_only_ea_obj, | ||
is_bool_dtype, | ||
is_categorical_dtype, | ||
is_complex_dtype, | ||
is_datetime64_any_dtype, | ||
is_float_dtype, | ||
is_integer_dtype, | ||
is_numeric_dtype, | ||
is_sparse, | ||
is_timedelta64_dtype, | ||
|
@@ -262,27 +261,35 @@ def _get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape: | |
out_shape = (ngroups,) + values.shape[1:] | ||
return out_shape | ||
|
||
def get_out_dtype(self, dtype: np.dtype) -> np.dtype: | ||
how = self.how | ||
|
||
# Note: we make this a classmethod and pass kind+how so that caching | ||
# works at the class level and not the instance level | ||
@classmethod | ||
@functools.lru_cache(maxsize=None) | ||
def get_out_dtype(cls, how: str, dtype: np.dtype) -> np.dtype: | ||
if how == "rank": | ||
out_dtype = "float64" | ||
else: | ||
if is_numeric_dtype(dtype): | ||
if dtype.kind in "uifcb": | ||
out_dtype = f"{dtype.kind}{dtype.itemsize}" | ||
else: | ||
out_dtype = "object" | ||
return np.dtype(out_dtype) | ||
|
||
@overload | ||
def _get_result_dtype(self, dtype: np.dtype) -> np.dtype: | ||
@classmethod | ||
def _get_result_dtype(cls, dtype: np.dtype) -> np.dtype: | ||
... # pragma: no cover | ||
|
||
@overload | ||
def _get_result_dtype(self, dtype: ExtensionDtype) -> ExtensionDtype: | ||
@classmethod | ||
def _get_result_dtype(cls, dtype: ExtensionDtype) -> ExtensionDtype: | ||
... # pragma: no cover | ||
|
||
def _get_result_dtype(self, dtype: DtypeObj) -> DtypeObj: | ||
# Note: we make this a classmethod and pass kind+how so that caching | ||
# works at the class level and not the instance level | ||
@classmethod | ||
@functools.lru_cache(maxsize=None) | ||
def _get_result_dtype(cls, how: str, dtype: DtypeObj) -> DtypeObj: | ||
""" | ||
Get the desired dtype of a result based on the | ||
input dtype and how it was computed. | ||
|
@@ -297,8 +304,6 @@ def _get_result_dtype(self, dtype: DtypeObj) -> DtypeObj: | |
np.dtype or ExtensionDtype | ||
The desired dtype of the result. | ||
""" | ||
how = self.how | ||
|
||
if how in ["add", "cumsum", "sum", "prod"]: | ||
if dtype == np.dtype(bool): | ||
return np.dtype(np.int64) | ||
|
@@ -382,7 +387,7 @@ def _reconstruct_ea_result(self, values, res_values): | |
if isinstance( | ||
values.dtype, (BooleanDtype, _IntegerDtype, FloatingDtype, StringDtype) | ||
): | ||
dtype = self._get_result_dtype(values.dtype) | ||
dtype = self._get_result_dtype(self.how, values.dtype) | ||
cls = dtype.construct_array_type() | ||
return cls._from_sequence(res_values, dtype=dtype) | ||
|
||
|
@@ -422,7 +427,7 @@ def _masked_ea_wrap_cython_operation( | |
**kwargs, | ||
) | ||
|
||
dtype = self._get_result_dtype(orig_values.dtype) | ||
dtype = self._get_result_dtype(self.how, orig_values.dtype) | ||
assert isinstance(dtype, BaseMaskedDtype) | ||
cls = dtype.construct_array_type() | ||
|
||
|
@@ -490,21 +495,26 @@ def _call_cython_op( | |
orig_values = values | ||
|
||
dtype = values.dtype | ||
is_numeric = is_numeric_dtype(dtype) | ||
dtype_kind = dtype.kind | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for my own edification, how big a difference does this make? i.e. should i get into the habit of doing this? |
||
# is_numeric_dtype | ||
is_numeric = dtype_kind in "uifcb" | ||
|
||
is_datetimelike = needs_i8_conversion(dtype) | ||
# is_datetimelike = needs_i8_conversion(dtype) | ||
is_datetimelike = dtype_kind in ["m", "M"] or ( | ||
dtype_kind == "O" and dtype.type is Period | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. at this point we have an ndarray, so PeriodDtype shouldn't be possible i think? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jorisvandenbossche can you respond to a few small comments here? this should be pretty easy to merge and will be a nice perf bump |
||
) | ||
|
||
if is_datetimelike: | ||
values = values.view("int64") | ||
is_numeric = True | ||
elif is_bool_dtype(dtype): | ||
elif dtype_kind == "b": | ||
values = values.astype("int64") | ||
elif is_integer_dtype(dtype): | ||
elif dtype_kind in "ui": | ||
# e.g. uint8 -> uint64, int16 -> int64 | ||
dtype_str = dtype.kind + "8" | ||
values = values.astype(dtype_str, copy=False) | ||
elif is_numeric: | ||
if not is_complex_dtype(dtype): | ||
if not dtype_kind == "c": | ||
values = ensure_float64(values) | ||
|
||
values = values.T | ||
|
@@ -515,7 +525,7 @@ def _call_cython_op( | |
|
||
out_shape = self._get_output_shape(ngroups, values) | ||
func, values = self.get_cython_func_and_vals(values, is_numeric) | ||
out_dtype = self.get_out_dtype(values.dtype) | ||
out_dtype = self.get_out_dtype(self.how, values.dtype) | ||
|
||
result = maybe_fill(np.empty(out_shape, dtype=out_dtype)) | ||
if self.kind == "aggregate": | ||
|
@@ -562,7 +572,7 @@ def _call_cython_op( | |
# i.e. counts is defined. Locations where count<min_count | ||
# need to have the result set to np.nan, which may require casting, | ||
# see GH#40767 | ||
if is_integer_dtype(result.dtype) and not is_datetimelike: | ||
if result.dtype.kind in "ui" and not is_datetimelike: | ||
cutoff = max(1, min_count) | ||
empty_groups = counts < cutoff | ||
if empty_groups.any(): | ||
|
@@ -575,7 +585,7 @@ def _call_cython_op( | |
if self.how not in self.cast_blocklist: | ||
# e.g. if we are int64 and need to restore to datetime64/timedelta64 | ||
# "rank" is the only member of cast_blocklist we get here | ||
res_dtype = self._get_result_dtype(orig_values.dtype) | ||
res_dtype = self._get_result_dtype(self.how, orig_values.dtype) | ||
op_result = maybe_downcast_to_dtype(result, res_dtype) | ||
else: | ||
op_result = result | ||
|
@@ -608,7 +618,8 @@ def cython_operation( | |
assert axis == 0 | ||
|
||
dtype = values.dtype | ||
is_numeric = is_numeric_dtype(dtype) | ||
# is_numeric_dtype | ||
is_numeric = dtype.kind in "uifcb" | ||
|
||
# can we do this operation with our cython functions | ||
# if not raise NotImplementedError | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comment to the effect of "fastpath/not a dtype object"?
will this change how is_numeric_dtype treats EA dtypes?