Skip to content

Commit 6788421

Browse files
jreback authored and analyticalmonk committed
PERF: better perf on _ensure_data in core/algorithms, helping perf of unique, duplicated, factorize (pandas-dev#16046)
1 parent 1d2a0c9 commit 6788421

File tree

2 files changed

+84
-37
lines changed

2 files changed

+84
-37
lines changed

pandas/core/algorithms.py

+34 -37
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
from pandas.core.dtypes.common import (
1515
is_unsigned_integer_dtype, is_signed_integer_dtype,
1616
is_integer_dtype, is_complex_dtype,
17+
is_object_dtype,
1718
is_categorical_dtype, is_sparse,
1819
is_period_dtype,
1920
is_numeric_dtype, is_float_dtype,
@@ -63,6 +64,35 @@ def _ensure_data(values, dtype=None):
6364
6465
"""
6566

67+
# we check some simple dtypes first
68+
try:
69+
if is_bool_dtype(values) or is_bool_dtype(dtype):
70+
# we are actually coercing to uint64
71+
# until our algos support uint8 directly (see TODO)
72+
return np.asarray(values).astype('uint64'), 'bool', 'uint64'
73+
elif is_signed_integer_dtype(values) or is_signed_integer_dtype(dtype):
74+
return _ensure_int64(values), 'int64', 'int64'
75+
elif (is_unsigned_integer_dtype(values) or
76+
is_unsigned_integer_dtype(dtype)):
77+
return _ensure_uint64(values), 'uint64', 'uint64'
78+
elif is_float_dtype(values) or is_float_dtype(dtype):
79+
return _ensure_float64(values), 'float64', 'float64'
80+
elif is_object_dtype(values) and dtype is None:
81+
return _ensure_object(np.asarray(values)), 'object', 'object'
82+
elif is_complex_dtype(values) or is_complex_dtype(dtype):
83+
84+
# ignore the fact that we are casting to float
85+
# which discards complex parts
86+
with catch_warnings(record=True):
87+
values = _ensure_float64(values)
88+
return values, 'float64', 'float64'
89+
90+
except (TypeError, ValueError):
91+
# if we are trying to coerce to a dtype
92+
# and it is incompat this will fall thru to here
93+
return _ensure_object(values), 'object', 'object'
94+
95+
# datetimelike
6696
if (needs_i8_conversion(values) or
6797
is_period_dtype(dtype) or
6898
is_datetime64_any_dtype(dtype) or
@@ -94,43 +124,9 @@ def _ensure_data(values, dtype=None):
94124

95125
return values, dtype, 'int64'
96126

127+
# we have failed, return object
97128
values = np.asarray(values)
98-
99-
try:
100-
if is_bool_dtype(values) or is_bool_dtype(dtype):
101-
# we are actually coercing to uint64
102-
# until our algos support uint8 directly (see TODO)
103-
values = values.astype('uint64')
104-
dtype = 'bool'
105-
ndtype = 'uint64'
106-
elif is_signed_integer_dtype(values) or is_signed_integer_dtype(dtype):
107-
values = _ensure_int64(values)
108-
ndtype = dtype = 'int64'
109-
elif (is_unsigned_integer_dtype(values) or
110-
is_unsigned_integer_dtype(dtype)):
111-
values = _ensure_uint64(values)
112-
ndtype = dtype = 'uint64'
113-
elif is_complex_dtype(values) or is_complex_dtype(dtype):
114-
115-
# ignore the fact that we are casting to float
116-
# which discards complex parts
117-
with catch_warnings(record=True):
118-
values = _ensure_float64(values)
119-
ndtype = dtype = 'float64'
120-
elif is_float_dtype(values) or is_float_dtype(dtype):
121-
values = _ensure_float64(values)
122-
ndtype = dtype = 'float64'
123-
else:
124-
values = _ensure_object(values)
125-
ndtype = dtype = 'object'
126-
127-
except (TypeError, ValueError):
128-
# if we are trying to coerce to a dtype
129-
# and it is incompat this will fall thru to here
130-
values = _ensure_object(values)
131-
ndtype = dtype = 'object'
132-
133-
return values, dtype, ndtype
129+
return _ensure_object(values), 'object', 'object'
134130

135131

136132
def _reconstruct_data(values, dtype, original):
@@ -465,7 +461,7 @@ def safe_sort(values, labels=None, na_sentinel=-1, assume_unique=False):
465461
if not is_list_like(values):
466462
raise TypeError("Only list-like objects are allowed to be passed to"
467463
"safe_sort as values")
468-
values = np.array(values, copy=False)
464+
values = np.asarray(values)
469465

470466
def sort_mixed(values):
471467
# order ints before strings, safe in py3
@@ -547,6 +543,7 @@ def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None):
547543
PeriodIndex
548544
"""
549545

546+
values = _ensure_arraylike(values)
550547
original = values
551548
values, dtype, _ = _ensure_data(values)
552549
(hash_klass, vec_klass), values = _get_data_algo(values, _hashtables)

pandas/core/dtypes/common.py

+50
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,8 @@ def _ensure_categorical(arr):
8282

8383

8484
def is_object_dtype(arr_or_dtype):
85+
if arr_or_dtype is None:
86+
return False
8587
tipo = _get_dtype_type(arr_or_dtype)
8688
return issubclass(tipo, np.object_)
8789

@@ -120,6 +122,8 @@ def is_period(array):
120122

121123

122124
def is_datetime64_dtype(arr_or_dtype):
125+
if arr_or_dtype is None:
126+
return False
123127
try:
124128
tipo = _get_dtype_type(arr_or_dtype)
125129
except TypeError:
@@ -128,23 +132,33 @@ def is_datetime64_dtype(arr_or_dtype):
128132

129133

130134
def is_datetime64tz_dtype(arr_or_dtype):
135+
if arr_or_dtype is None:
136+
return False
131137
return DatetimeTZDtype.is_dtype(arr_or_dtype)
132138

133139

134140
def is_timedelta64_dtype(arr_or_dtype):
141+
if arr_or_dtype is None:
142+
return False
135143
tipo = _get_dtype_type(arr_or_dtype)
136144
return issubclass(tipo, np.timedelta64)
137145

138146

139147
def is_period_dtype(arr_or_dtype):
148+
if arr_or_dtype is None:
149+
return False
140150
return PeriodDtype.is_dtype(arr_or_dtype)
141151

142152

143153
def is_interval_dtype(arr_or_dtype):
154+
if arr_or_dtype is None:
155+
return False
144156
return IntervalDtype.is_dtype(arr_or_dtype)
145157

146158

147159
def is_categorical_dtype(arr_or_dtype):
160+
if arr_or_dtype is None:
161+
return False
148162
return CategoricalDtype.is_dtype(arr_or_dtype)
149163

150164

@@ -178,6 +192,8 @@ def is_string_dtype(arr_or_dtype):
178192

179193
# TODO: gh-15585: consider making the checks stricter.
180194

195+
if arr_or_dtype is None:
196+
return False
181197
try:
182198
dtype = _get_dtype(arr_or_dtype)
183199
return dtype.kind in ('O', 'S', 'U') and not is_period_dtype(dtype)
@@ -224,45 +240,61 @@ def is_dtype_equal(source, target):
224240

225241

226242
def is_any_int_dtype(arr_or_dtype):
243+
if arr_or_dtype is None:
244+
return False
227245
tipo = _get_dtype_type(arr_or_dtype)
228246
return issubclass(tipo, np.integer)
229247

230248

231249
def is_integer_dtype(arr_or_dtype):
250+
if arr_or_dtype is None:
251+
return False
232252
tipo = _get_dtype_type(arr_or_dtype)
233253
return (issubclass(tipo, np.integer) and
234254
not issubclass(tipo, (np.datetime64, np.timedelta64)))
235255

236256

237257
def is_signed_integer_dtype(arr_or_dtype):
258+
if arr_or_dtype is None:
259+
return False
238260
tipo = _get_dtype_type(arr_or_dtype)
239261
return (issubclass(tipo, np.signedinteger) and
240262
not issubclass(tipo, (np.datetime64, np.timedelta64)))
241263

242264

243265
def is_unsigned_integer_dtype(arr_or_dtype):
266+
if arr_or_dtype is None:
267+
return False
244268
tipo = _get_dtype_type(arr_or_dtype)
245269
return (issubclass(tipo, np.unsignedinteger) and
246270
not issubclass(tipo, (np.datetime64, np.timedelta64)))
247271

248272

249273
def is_int64_dtype(arr_or_dtype):
274+
if arr_or_dtype is None:
275+
return False
250276
tipo = _get_dtype_type(arr_or_dtype)
251277
return issubclass(tipo, np.int64)
252278

253279

254280
def is_int_or_datetime_dtype(arr_or_dtype):
281+
if arr_or_dtype is None:
282+
return False
255283
tipo = _get_dtype_type(arr_or_dtype)
256284
return (issubclass(tipo, np.integer) or
257285
issubclass(tipo, (np.datetime64, np.timedelta64)))
258286

259287

260288
def is_datetime64_any_dtype(arr_or_dtype):
289+
if arr_or_dtype is None:
290+
return False
261291
return (is_datetime64_dtype(arr_or_dtype) or
262292
is_datetime64tz_dtype(arr_or_dtype))
263293

264294

265295
def is_datetime64_ns_dtype(arr_or_dtype):
296+
if arr_or_dtype is None:
297+
return False
266298
try:
267299
tipo = _get_dtype(arr_or_dtype)
268300
except TypeError:
@@ -303,6 +335,8 @@ def is_timedelta64_ns_dtype(arr_or_dtype):
303335
False
304336
"""
305337

338+
if arr_or_dtype is None:
339+
return False
306340
try:
307341
tipo = _get_dtype(arr_or_dtype)
308342
return tipo == _TD_DTYPE
@@ -311,6 +345,8 @@ def is_timedelta64_ns_dtype(arr_or_dtype):
311345

312346

313347
def is_datetime_or_timedelta_dtype(arr_or_dtype):
348+
if arr_or_dtype is None:
349+
return False
314350
tipo = _get_dtype_type(arr_or_dtype)
315351
return issubclass(tipo, (np.datetime64, np.timedelta64))
316352

@@ -398,12 +434,16 @@ def is_object(x):
398434

399435

400436
def needs_i8_conversion(arr_or_dtype):
437+
if arr_or_dtype is None:
438+
return False
401439
return (is_datetime_or_timedelta_dtype(arr_or_dtype) or
402440
is_datetime64tz_dtype(arr_or_dtype) or
403441
is_period_dtype(arr_or_dtype))
404442

405443

406444
def is_numeric_dtype(arr_or_dtype):
445+
if arr_or_dtype is None:
446+
return False
407447
tipo = _get_dtype_type(arr_or_dtype)
408448
return (issubclass(tipo, (np.number, np.bool_)) and
409449
not issubclass(tipo, (np.datetime64, np.timedelta64)))
@@ -438,6 +478,8 @@ def is_string_like_dtype(arr_or_dtype):
438478
False
439479
"""
440480

481+
if arr_or_dtype is None:
482+
return False
441483
try:
442484
dtype = _get_dtype(arr_or_dtype)
443485
return dtype.kind in ('S', 'U')
@@ -446,16 +488,22 @@ def is_string_like_dtype(arr_or_dtype):
446488

447489

448490
def is_float_dtype(arr_or_dtype):
491+
if arr_or_dtype is None:
492+
return False
449493
tipo = _get_dtype_type(arr_or_dtype)
450494
return issubclass(tipo, np.floating)
451495

452496

453497
def is_floating_dtype(arr_or_dtype):
498+
if arr_or_dtype is None:
499+
return False
454500
tipo = _get_dtype_type(arr_or_dtype)
455501
return isinstance(tipo, np.floating)
456502

457503

458504
def is_bool_dtype(arr_or_dtype):
505+
if arr_or_dtype is None:
506+
return False
459507
try:
460508
tipo = _get_dtype_type(arr_or_dtype)
461509
except ValueError:
@@ -479,6 +527,8 @@ def is_extension_type(value):
479527

480528

481529
def is_complex_dtype(arr_or_dtype):
530+
if arr_or_dtype is None:
531+
return False
482532
tipo = _get_dtype_type(arr_or_dtype)
483533
return issubclass(tipo, np.complexfloating)
484534

0 commit comments

Comments
 (0)