Skip to content

Commit 86ec487

Browse files
datajanko authored and proost committed
[PERF] Vectorize select_dtypes (pandas-dev#28447)
1 parent b4943cf commit 86ec487

File tree

3 files changed

+40
-36
lines changed

3 files changed

+40
-36
lines changed

asv_bench/benchmarks/frame_methods.py

+11
Original file line number | Diff line number | Diff line change
@@ -609,4 +609,15 @@ def time_dataframe_describe(self):
609609
self.df.describe()
610610

611611

612+
class SelectDtypes:
    """Benchmark ``DataFrame.select_dtypes`` on frames with many columns.

    The benchmark is parametrized over ``n``, the number of columns in the
    frame that is built in ``setup``.
    """

    params = [100, 1000]
    param_names = ["n"]

    def setup(self, n):
        """Build a 10-row, ``n``-column frame filled by ``np.random.randn``."""
        self.df = DataFrame(np.random.randn(10, n))

    def time_select_dtypes(self, n):
        """Time ``select_dtypes(include="int")`` on the prepared frame."""
        self.df.select_dtypes(include="int")
621+
622+
612623
from .pandas_vb_common import setup # noqa: F401 isort:skip

doc/source/whatsnew/v1.0.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -151,6 +151,7 @@ Performance improvements
151151
- Performance improvement in :func:`cut` when ``bins`` is an :class:`IntervalIndex` (:issue:`27668`)
152152
- Performance improvement in :meth:`DataFrame.corr` when ``method`` is ``"spearman"`` (:issue:`28139`)
153153
- Performance improvement in :meth:`DataFrame.replace` when provided a list of values to replace (:issue:`28099`)
154+
- Performance improvement in :meth:`DataFrame.select_dtypes` by using vectorization instead of iterating over a loop (:issue:`28317`)
154155

155156
.. _whatsnew_1000.bug_fixes:
156157

pandas/core/frame.py

+28 -36
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
"""
1111
import collections
1212
from collections import OrderedDict, abc
13-
import functools
1413
from io import StringIO
1514
import itertools
1615
import sys
@@ -3395,15 +3394,6 @@ def select_dtypes(self, include=None, exclude=None):
33953394
5 False 2.0
33963395
"""
33973396

3398-
def _get_info_slice(obj, indexer):
3399-
"""Slice the info axis of `obj` with `indexer`."""
3400-
if not hasattr(obj, "_info_axis_number"):
3401-
msg = "object of type {typ!r} has no info axis"
3402-
raise TypeError(msg.format(typ=type(obj).__name__))
3403-
slices = [slice(None)] * obj.ndim
3404-
slices[obj._info_axis_number] = indexer
3405-
return tuple(slices)
3406-
34073397
if not is_list_like(include):
34083398
include = (include,) if include is not None else ()
34093399
if not is_list_like(exclude):
@@ -3428,33 +3418,35 @@ def _get_info_slice(obj, indexer):
34283418
)
34293419
)
34303420

3431-
# empty include/exclude -> defaults to True
3432-
# three cases (we've already raised if both are empty)
3433-
# case 1: empty include, nonempty exclude
3434-
# we have True, True, ... True for include, same for exclude
3435-
# in the loop below we get the excluded
3436-
# and when we call '&' below we get only the excluded
3437-
# case 2: nonempty include, empty exclude
3438-
# same as case 1, but with include
3439-
# case 3: both nonempty
3440-
# the "union" of the logic of case 1 and case 2:
3441-
# we get the included and excluded, and return their logical and
3442-
include_these = Series(not bool(include), index=self.columns)
3443-
exclude_these = Series(not bool(exclude), index=self.columns)
3444-
3445-
def is_dtype_instance_mapper(idx, dtype):
3446-
return idx, functools.partial(issubclass, dtype.type)
3447-
3448-
for idx, f in itertools.starmap(
3449-
is_dtype_instance_mapper, enumerate(self.dtypes)
3450-
):
3451-
if include: # checks for the case of empty include or exclude
3452-
include_these.iloc[idx] = any(map(f, include))
3453-
if exclude:
3454-
exclude_these.iloc[idx] = not any(map(f, exclude))
3421+
# We raise when both include and exclude are empty
3422+
# Hence, we can just shrink the columns we want to keep
3423+
keep_these = np.full(self.shape[1], True)
3424+
3425+
def extract_unique_dtypes_from_dtypes_set(
    dtypes_set: FrozenSet[Dtype], unique_dtypes: np.ndarray
) -> List[Dtype]:
    """Return the entries of ``unique_dtypes`` matched by ``dtypes_set``.

    A dtype matches when its ``.type`` attribute is a subclass of at least
    one class in ``dtypes_set``.

    Parameters
    ----------
    dtypes_set : frozenset
        Classes to test each dtype's ``.type`` against.
    unique_dtypes : iterable of dtype objects
        Candidate dtypes; each must expose a ``.type`` attribute.

    Returns
    -------
    list
        The matching dtypes, in the order they appear in ``unique_dtypes``.
        Empty when ``dtypes_set`` is empty or nothing matches.
    """
    # ``issubclass`` needs a tuple of classes; build it once instead of
    # once per candidate dtype.
    dtype_classes = tuple(dtypes_set)
    return [
        unique_dtype
        for unique_dtype in unique_dtypes
        if issubclass(unique_dtype.type, dtype_classes)  # type: ignore
    ]
3434+
3435+
unique_dtypes = self.dtypes.unique()
3436+
3437+
if include:
3438+
included_dtypes = extract_unique_dtypes_from_dtypes_set(
3439+
include, unique_dtypes
3440+
)
3441+
keep_these &= self.dtypes.isin(included_dtypes)
3442+
3443+
if exclude:
3444+
excluded_dtypes = extract_unique_dtypes_from_dtypes_set(
3445+
exclude, unique_dtypes
3446+
)
3447+
keep_these &= ~self.dtypes.isin(excluded_dtypes)
34553448

3456-
dtype_indexer = include_these & exclude_these
3457-
return self.loc[_get_info_slice(self, dtype_indexer)]
3449+
return self.iloc[:, keep_these.values]
34583450

34593451
def insert(self, loc, column, value, allow_duplicates=False):
34603452
"""

0 commit comments

Comments
 (0)