Skip to content

Commit b4375a4

Browse files
authored
REF: avoid unnecessary casting in algorithms (#41256)
1 parent feac84c commit b4375a4

File tree

3 files changed

+55
-70
lines changed

3 files changed

+55
-70
lines changed

pandas/core/algorithms.py

+44-60
Original file line numberDiff line numberDiff line change
@@ -37,31 +37,26 @@
3737
from pandas.core.dtypes.cast import (
3838
construct_1d_object_array_from_listlike,
3939
infer_dtype_from_array,
40+
sanitize_to_nanoseconds,
4041
)
4142
from pandas.core.dtypes.common import (
4243
ensure_float64,
43-
ensure_int64,
4444
ensure_object,
4545
ensure_platform_int,
46-
ensure_uint64,
4746
is_array_like,
4847
is_bool_dtype,
4948
is_categorical_dtype,
5049
is_complex_dtype,
5150
is_datetime64_dtype,
52-
is_datetime64_ns_dtype,
5351
is_extension_array_dtype,
5452
is_float_dtype,
5553
is_integer,
5654
is_integer_dtype,
5755
is_list_like,
5856
is_numeric_dtype,
5957
is_object_dtype,
60-
is_period_dtype,
6158
is_scalar,
62-
is_signed_integer_dtype,
6359
is_timedelta64_dtype,
64-
is_unsigned_integer_dtype,
6560
needs_i8_conversion,
6661
pandas_dtype,
6762
)
@@ -134,71 +129,49 @@ def _ensure_data(values: ArrayLike) -> tuple[np.ndarray, DtypeObj]:
134129
values = extract_array(values, extract_numpy=True)
135130

136131
# we check some simple dtypes first
137-
if is_object_dtype(values):
132+
if is_object_dtype(values.dtype):
138133
return ensure_object(np.asarray(values)), np.dtype("object")
139134

140-
try:
141-
if is_bool_dtype(values):
142-
# we are actually coercing to uint64
143-
# until our algos support uint8 directly (see TODO)
144-
return np.asarray(values).astype("uint64"), np.dtype("bool")
145-
elif is_signed_integer_dtype(values):
146-
return ensure_int64(values), values.dtype
147-
elif is_unsigned_integer_dtype(values):
148-
return ensure_uint64(values), values.dtype
149-
elif is_float_dtype(values):
135+
elif is_bool_dtype(values.dtype):
136+
if isinstance(values, np.ndarray):
137+
# i.e. actually dtype == np.dtype("bool")
138+
return np.asarray(values).view("uint8"), values.dtype
139+
else:
140+
# i.e. all-bool Categorical, BooleanArray
141+
return np.asarray(values).astype("uint8", copy=False), values.dtype
142+
143+
elif is_integer_dtype(values.dtype):
144+
return np.asarray(values), values.dtype
145+
146+
elif is_float_dtype(values.dtype):
147+
# Note: checking `values.dtype == "float128"` raises on Windows and 32bit
148+
# error: Item "ExtensionDtype" of "Union[Any, ExtensionDtype, dtype[Any]]"
149+
# has no attribute "itemsize"
150+
if values.dtype.itemsize in [2, 12, 16]: # type: ignore[union-attr]
151+
# we dont (yet) have float128 hashtable support
150152
return ensure_float64(values), values.dtype
151-
elif is_complex_dtype(values):
152-
153-
# ignore the fact that we are casting to float
154-
# which discards complex parts
155-
with catch_warnings():
156-
simplefilter("ignore", np.ComplexWarning)
157-
values = ensure_float64(values)
158-
return values, np.dtype("float64")
153+
return np.asarray(values), values.dtype
159154

160-
except (TypeError, ValueError, OverflowError):
161-
# if we are trying to coerce to a dtype
162-
# and it is incompatible this will fall through to here
163-
return ensure_object(values), np.dtype("object")
155+
elif is_complex_dtype(values.dtype):
156+
# ignore the fact that we are casting to float
157+
# which discards complex parts
158+
with catch_warnings():
159+
simplefilter("ignore", np.ComplexWarning)
160+
values = ensure_float64(values)
161+
return values, np.dtype("float64")
164162

165163
# datetimelike
166-
if needs_i8_conversion(values.dtype):
167-
if is_period_dtype(values.dtype):
168-
from pandas import PeriodIndex
169-
170-
values = PeriodIndex(values)._data
171-
elif is_timedelta64_dtype(values.dtype):
172-
from pandas import TimedeltaIndex
173-
174-
values = TimedeltaIndex(values)._data
175-
else:
176-
# Datetime
177-
if values.ndim > 1 and is_datetime64_ns_dtype(values.dtype):
178-
# Avoid calling the DatetimeIndex constructor as it is 1D only
179-
# Note: this is reached by DataFrame.rank calls GH#27027
180-
# TODO(EA2D): special case not needed with 2D EAs
181-
asi8 = values.view("i8")
182-
dtype = values.dtype
183-
# error: Incompatible return value type (got "Tuple[Any,
184-
# Union[dtype, ExtensionDtype, None]]", expected
185-
# "Tuple[ndarray, Union[dtype, ExtensionDtype]]")
186-
return asi8, dtype # type: ignore[return-value]
187-
188-
from pandas import DatetimeIndex
189-
190-
values = DatetimeIndex(values)._data
191-
dtype = values.dtype
192-
return values.asi8, dtype
164+
elif needs_i8_conversion(values.dtype):
165+
if isinstance(values, np.ndarray):
166+
values = sanitize_to_nanoseconds(values)
167+
npvalues = values.view("i8")
168+
npvalues = cast(np.ndarray, npvalues)
169+
return npvalues, values.dtype
193170

194171
elif is_categorical_dtype(values.dtype):
195172
values = cast("Categorical", values)
196173
values = values.codes
197174
dtype = pandas_dtype("category")
198-
199-
# we are actually coercing to int64
200-
# until our algos support int* directly (not all do)
201-
values = ensure_int64(values)
202175
return values, dtype
203176

204177
# we have failed, return object
@@ -268,8 +241,15 @@ def _ensure_arraylike(values) -> ArrayLike:
268241

269242
_hashtables = {
270243
"float64": htable.Float64HashTable,
244+
"float32": htable.Float32HashTable,
271245
"uint64": htable.UInt64HashTable,
246+
"uint32": htable.UInt32HashTable,
247+
"uint16": htable.UInt16HashTable,
248+
"uint8": htable.UInt8HashTable,
272249
"int64": htable.Int64HashTable,
250+
"int32": htable.Int32HashTable,
251+
"int16": htable.Int16HashTable,
252+
"int8": htable.Int8HashTable,
273253
"string": htable.StringHashTable,
274254
"object": htable.PyObjectHashTable,
275255
}
@@ -298,6 +278,10 @@ def _get_values_for_rank(values: ArrayLike) -> np.ndarray:
298278
values = cast("Categorical", values)._values_for_rank()
299279

300280
values, _ = _ensure_data(values)
281+
if values.dtype.kind in ["i", "u", "f"]:
282+
# rank_t includes only object, int64, uint64, float64
283+
dtype = values.dtype.kind + "8"
284+
values = values.astype(dtype, copy=False)
301285
return values
302286

303287

pandas/core/arrays/sparse/array.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ def _from_factorized(cls, values, original):
550550
# Data
551551
# ------------------------------------------------------------------------
552552
@property
553-
def sp_index(self):
553+
def sp_index(self) -> SparseIndex:
554554
"""
555555
The SparseIndex containing the location of non- ``fill_value`` points.
556556
"""
@@ -570,7 +570,7 @@ def sp_values(self) -> np.ndarray:
570570
return self._sparse_values
571571

572572
@property
573-
def dtype(self):
573+
def dtype(self) -> SparseDtype:
574574
return self._dtype
575575

576576
@property
@@ -597,7 +597,7 @@ def kind(self) -> str:
597597
return "block"
598598

599599
@property
600-
def _valid_sp_values(self):
600+
def _valid_sp_values(self) -> np.ndarray:
601601
sp_vals = self.sp_values
602602
mask = notna(sp_vals)
603603
return sp_vals[mask]
@@ -620,7 +620,7 @@ def nbytes(self) -> int:
620620
return self.sp_values.nbytes + self.sp_index.nbytes
621621

622622
@property
623-
def density(self):
623+
def density(self) -> float:
624624
"""
625625
The percent of non- ``fill_value`` points, as decimal.
626626

pandas/tests/test_algos.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -1756,14 +1756,15 @@ def _check(arr):
17561756
_check(np.array([np.nan, np.nan, 5.0, 5.0, 5.0, np.nan, 1, 2, 3, np.nan]))
17571757
_check(np.array([4.0, np.nan, 5.0, 5.0, 5.0, np.nan, 1, 2, 4.0, np.nan]))
17581758

1759-
def test_basic(self, writable):
1759+
@pytest.mark.parametrize("dtype", np.typecodes["AllInteger"])
1760+
def test_basic(self, writable, dtype):
17601761
exp = np.array([1, 2], dtype=np.float64)
17611762

1762-
for dtype in np.typecodes["AllInteger"]:
1763-
data = np.array([1, 100], dtype=dtype)
1764-
data.setflags(write=writable)
1765-
s = Series(data)
1766-
tm.assert_numpy_array_equal(algos.rank(s), exp)
1763+
data = np.array([1, 100], dtype=dtype)
1764+
data.setflags(write=writable)
1765+
ser = Series(data)
1766+
result = algos.rank(ser)
1767+
tm.assert_numpy_array_equal(result, exp)
17671768

17681769
def test_uint64_overflow(self):
17691770
exp = np.array([1, 2], dtype=np.float64)

0 commit comments

Comments
 (0)