Skip to content

Commit 87803d0

Browse files
authored
PERF: avoid object dtype cast for Categorical in _ensure_data (#46208)
1 parent c2cf93f commit 87803d0

File tree

1 file changed

+13
-20
lines changed

1 file changed

+13
-20
lines changed

pandas/core/algorithms.py

+13-20
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,6 @@ def _ensure_data(values: ArrayLike) -> np.ndarray:
136136
# extract_array would raise
137137
values = extract_array(values, extract_numpy=True)
138138

139-
# we check some simple dtypes first
140139
if is_object_dtype(values.dtype):
141140
return ensure_object(np.asarray(values))
142141

@@ -149,17 +148,19 @@ def _ensure_data(values: ArrayLike) -> np.ndarray:
149148
return _ensure_data(values._data)
150149
return np.asarray(values)
151150

151+
elif is_categorical_dtype(values.dtype):
152+
# NB: cases that go through here should NOT be using _reconstruct_data
153+
# on the back-end.
154+
values = cast("Categorical", values)
155+
return values.codes
156+
152157
elif is_bool_dtype(values.dtype):
153158
if isinstance(values, np.ndarray):
154159
# i.e. actually dtype == np.dtype("bool")
155160
return np.asarray(values).view("uint8")
156161
else:
157-
# i.e. all-bool Categorical, BooleanArray
158-
try:
159-
return np.asarray(values).astype("uint8", copy=False)
160-
except (TypeError, ValueError):
161-
# GH#42107 we have pd.NAs present
162-
return np.asarray(values)
162+
# e.g. Sparse[bool, False] # TODO: no test cases get here
163+
return np.asarray(values).astype("uint8", copy=False)
163164

164165
elif is_integer_dtype(values.dtype):
165166
return np.asarray(values)
@@ -174,10 +175,7 @@ def _ensure_data(values: ArrayLike) -> np.ndarray:
174175
return np.asarray(values)
175176

176177
elif is_complex_dtype(values.dtype):
177-
# Incompatible return value type (got "Tuple[Union[Any, ExtensionArray,
178-
# ndarray[Any, Any]], Union[Any, ExtensionDtype]]", expected
179-
# "Tuple[ndarray[Any, Any], Union[dtype[Any], ExtensionDtype]]")
180-
return values # type: ignore[return-value]
178+
return cast(np.ndarray, values)
181179

182180
# datetimelike
183181
elif needs_i8_conversion(values.dtype):
@@ -187,11 +185,6 @@ def _ensure_data(values: ArrayLike) -> np.ndarray:
187185
npvalues = cast(np.ndarray, npvalues)
188186
return npvalues
189187

190-
elif is_categorical_dtype(values.dtype):
191-
values = cast("Categorical", values)
192-
values = values.codes
193-
return values
194-
195188
# we have failed, return object
196189
values = np.asarray(values, dtype=object)
197190
return ensure_object(values)
@@ -218,7 +211,8 @@ def _reconstruct_data(
218211
return values
219212

220213
if not isinstance(dtype, np.dtype):
221-
# i.e. ExtensionDtype
214+
# i.e. ExtensionDtype; note we have ruled out above the possibility
215+
# that values.dtype == dtype
222216
cls = dtype.construct_array_type()
223217

224218
values = cls._from_sequence(values, dtype=dtype)
@@ -938,9 +932,8 @@ def mode(values: ArrayLike, dropna: bool = True) -> ArrayLike:
938932
if needs_i8_conversion(values.dtype):
939933
# Got here with ndarray; dispatch to DatetimeArray/TimedeltaArray.
940934
values = ensure_wrapped_if_datetimelike(values)
941-
# error: Item "ndarray[Any, Any]" of "Union[ExtensionArray,
942-
# ndarray[Any, Any]]" has no attribute "_mode"
943-
return values._mode(dropna=dropna) # type: ignore[union-attr]
935+
values = cast("ExtensionArray", values)
936+
return values._mode(dropna=dropna)
944937

945938
values = _ensure_data(values)
946939

0 commit comments

Comments
 (0)