Skip to content

Commit 37ea2cd

Browse files
authored
PERF: select_dtypes (#42611)
1 parent 86baa9f commit 37ea2cd

File tree

2 files changed: +21 -41 lines changed

pandas/core/frame.py

+16 -40
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,7 @@
5555
ColspaceArgType,
5656
CompressionOptions,
5757
Dtype,
58+
DtypeObj,
5859
FilePathOrBuffer,
5960
FillnaOptions,
6061
FloatFormatType,
@@ -4297,50 +4298,25 @@ def check_int_infer_dtype(dtypes):
42974298
if not include.isdisjoint(exclude):
42984299
raise ValueError(f"include and exclude overlap on {(include & exclude)}")
42994300

4300-
# We raise when both include and exclude are empty
4301-
# Hence, we can just shrink the columns we want to keep
4302-
keep_these = np.full(self.shape[1], True)
4303-
4304-
def extract_unique_dtypes_from_dtypes_set(
4305-
dtypes_set: frozenset[Dtype], unique_dtypes: np.ndarray
4306-
) -> list[Dtype]:
4307-
extracted_dtypes = [
4308-
unique_dtype
4309-
for unique_dtype in unique_dtypes
4310-
if (
4311-
issubclass(
4312-
# error: Argument 1 to "tuple" has incompatible type
4313-
# "FrozenSet[Union[ExtensionDtype, Union[str, Any], Type[str],
4314-
# Type[float], Type[int], Type[complex], Type[bool],
4315-
# Type[object]]]"; expected "Iterable[Union[type, Tuple[Any,
4316-
# ...]]]"
4317-
unique_dtype.type,
4318-
tuple(dtypes_set), # type: ignore[arg-type]
4319-
)
4320-
or (
4321-
np.number in dtypes_set
4322-
and getattr(unique_dtype, "_is_numeric", False)
4323-
)
4324-
)
4325-
]
4326-
return extracted_dtypes
4301+
def dtype_predicate(dtype: DtypeObj, dtypes_set) -> bool:
4302+
return issubclass(dtype.type, tuple(dtypes_set)) or (
4303+
np.number in dtypes_set and getattr(dtype, "_is_numeric", False)
4304+
)
43274305

4328-
unique_dtypes = self.dtypes.unique()
4306+
def predicate(arr: ArrayLike) -> bool:
4307+
dtype = arr.dtype
4308+
if include:
4309+
if not dtype_predicate(dtype, include):
4310+
return False
43294311

4330-
if include:
4331-
included_dtypes = extract_unique_dtypes_from_dtypes_set(
4332-
include, unique_dtypes
4333-
)
4334-
keep_these &= self.dtypes.isin(included_dtypes)
4312+
if exclude:
4313+
if dtype_predicate(dtype, exclude):
4314+
return False
43354315

4336-
if exclude:
4337-
excluded_dtypes = extract_unique_dtypes_from_dtypes_set(
4338-
exclude, unique_dtypes
4339-
)
4340-
keep_these &= ~self.dtypes.isin(excluded_dtypes)
4316+
return True
43414317

4342-
# error: "ndarray" has no attribute "values"
4343-
return self.iloc[:, keep_these.values] # type: ignore[attr-defined]
4318+
mgr = self._mgr._get_data_subset(predicate)
4319+
return type(self)(mgr).__finalize__(self)
43444320

43454321
def insert(self, loc, column, value, allow_duplicates: bool = False) -> None:
43464322
"""

pandas/core/internals/array_manager.py

+5 -1
Original file line number | Diff line number | Diff line change
@@ -474,7 +474,11 @@ def _get_data_subset(self: T, predicate: Callable) -> T:
474474
indices = [i for i, arr in enumerate(self.arrays) if predicate(arr)]
475475
arrays = [self.arrays[i] for i in indices]
476476
# TODO copy?
477-
new_axes = [self._axes[0], self._axes[1][np.array(indices, dtype="intp")]]
477+
# Note: using Index.take ensures we can retain e.g. DatetimeIndex.freq,
478+
# see test_describe_datetime_columns
479+
taker = np.array(indices, dtype="intp")
480+
new_cols = self._axes[1].take(taker)
481+
new_axes = [self._axes[0], new_cols]
478482
return type(self)(arrays, new_axes, verify_integrity=False)
479483

480484
def get_bool_data(self: T, copy: bool = False) -> T:

0 commit comments

Comments (0)