Skip to content

Commit 3bc1f17

Browse files
committed
ENH: move _ensure_{dtype} functions to Cython for speedup, close #1221
1 parent 56051cf commit 3bc1f17

File tree

5 files changed

+88
-30
lines changed

5 files changed

+88
-30
lines changed

pandas/core/common.py

+5-29
Original file line numberDiff line numberDiff line change
@@ -711,36 +711,12 @@ def is_float_dtype(arr_or_dtype):
711711
return issubclass(tipo, np.floating)
712712

713713

714-
def _ensure_float64(arr):
715-
if arr.dtype != np.float64:
716-
arr = arr.astype(np.float64)
717-
return arr
718-
719-
def _ensure_int64(arr):
720-
try:
721-
if arr.dtype != np.int64:
722-
arr = arr.astype(np.int64)
723-
return arr
724-
except AttributeError:
725-
return np.array(arr, dtype=np.int64)
714+
# Dtype-coercion helpers: each private name below is bound directly to the
# same-named ensure_* function exposed by the _algos module.
_ensure_float64 = _algos.ensure_float64
715+
_ensure_int64 = _algos.ensure_int64
716+
_ensure_int32 = _algos.ensure_int32
717+
_ensure_platform_int = _algos.ensure_platform_int
718+
_ensure_object = _algos.ensure_object
726719

727-
def _ensure_platform_int(labels):
728-
try:
729-
if labels.dtype != np.int_: # pragma: no cover
730-
labels = labels.astype(np.int_)
731-
return labels
732-
except AttributeError:
733-
return np.array(labels, dtype=np.int_)
734-
735-
def _ensure_int32(arr):
736-
if arr.dtype != np.int32:
737-
arr = arr.astype(np.int32)
738-
return arr
739-
740-
def _ensure_object(arr):
741-
if arr.dtype != np.object_:
742-
arr = arr.astype('O')
743-
return arr
744720

745721
def _astype_nansafe(arr, dtype):
746722
if (np.issubdtype(arr.dtype, np.floating) and

pandas/src/generate_code.py

+31
Original file line numberDiff line numberDiff line change
@@ -810,6 +810,35 @@ def outer_join_indexer_%(name)s(ndarray[%(c_type)s] left,
810810
811811
"""
812812

813+
# ensure_dtype functions

# Cython source emitted once per entry of ``ensure_functions``.  The generated
# function returns its argument untouched when it is already an ndarray whose
# type number equals NPY_<ctype>, converts any other ndarray with ``astype``,
# and builds a new array from non-ndarray input.
ensure_dtype_template = """
cpdef ensure_%(name)s(object arr):
    if util.is_array(arr):
        if (<ndarray> arr).descr.type_num == NPY_%(ctype)s:
            return arr
        else:
            return arr.astype(np.%(dtype)s)
    else:
        return np.array(arr, dtype=np.%(dtype)s)

"""

# Each entry is (function-name suffix, NPY_* type-number suffix, numpy dtype
# attribute).  The last two items must describe the same C type, otherwise the
# generated fast path hands back arrays of the wrong dtype.
ensure_functions = [
    ('float64', 'FLOAT64', 'float64'),
    ('int32', 'INT32', 'int32'),
    ('int64', 'INT64', 'int64'),
    # np.int_ is C ``long`` (type number NPY_LONG).  NPY_INT is C ``int``:
    # pairing it with np.int_ let intc arrays through unconverted on LP64
    # platforms and forced a copy of every genuine np.int_ array.
    ('platform_int', 'LONG', 'int_'),
    ('object', 'OBJECT', 'object_'),
]


def generate_ensure_dtypes():
    """Return the Cython source of every ``ensure_<name>`` function.

    The template is filled once per entry of ``ensure_functions`` and the
    results are concatenated in list order.
    """
    # An explicit mapping (rather than ``locals()``) makes it obvious which
    # names the template consumes.
    return ''.join(
        ensure_dtype_template % {'name': name, 'ctype': ctype, 'dtype': dtype}
        for name, ctype, dtype in ensure_functions)
841+
813842
#----------------------------------------------------------------------
814843
# Fast "put" logic for speeding up interleaving logic
815844

@@ -916,6 +945,8 @@ def generate_take_cython_file(path='generated.pyx'):
916945
for template in nobool_1d_templates:
917946
print >> f, generate_from_template(template, exclude=['bool'])
918947

948+
print >> f, generate_ensure_dtypes()
949+
919950
# print >> f, generate_put_functions()
920951

921952
if __name__ == '__main__':

pandas/src/generated.pyx

+51
Original file line numberDiff line numberDiff line change
@@ -3306,3 +3306,54 @@ def inner_join_indexer_int64(ndarray[int64_t] left,
33063306
return result, lindexer, rindexer
33073307

33083308

3309+
3310+
# ensure_<dtype> helpers.  This section is generated; its template lives in
# generate_code.py.  Each function returns `arr` itself when it is already an
# ndarray whose type number matches the target, converts any other ndarray
# with astype, and builds a new array from non-ndarray input.

cpdef ensure_float64(object arr):
    if util.is_array(arr):
        if (<ndarray> arr).descr.type_num == NPY_FLOAT64:
            return arr
        else:
            return arr.astype(np.float64)
    else:
        return np.array(arr, dtype=np.float64)


cpdef ensure_int32(object arr):
    if util.is_array(arr):
        if (<ndarray> arr).descr.type_num == NPY_INT32:
            return arr
        else:
            return arr.astype(np.int32)
    else:
        return np.array(arr, dtype=np.int32)


cpdef ensure_int64(object arr):
    if util.is_array(arr):
        if (<ndarray> arr).descr.type_num == NPY_INT64:
            return arr
        else:
            return arr.astype(np.int64)
    else:
        return np.array(arr, dtype=np.int64)


cpdef ensure_platform_int(object arr):
    if util.is_array(arr):
        # np.int_ is C long, so its type number is NPY_LONG.  Comparing
        # against NPY_INT (C int) returned intc arrays unconverted and copied
        # every array that already had dtype np.int_.
        if (<ndarray> arr).descr.type_num == NPY_LONG:
            return arr
        else:
            return arr.astype(np.int_)
    else:
        return np.array(arr, dtype=np.int_)


cpdef ensure_object(object arr):
    if util.is_array(arr):
        if (<ndarray> arr).descr.type_num == NPY_OBJECT:
            return arr
        else:
            return arr.astype(np.object_)
    else:
        return np.array(arr, dtype=np.object_)
3358+
3359+

pandas/src/tseries.pyx

+1
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,7 @@ def value_count_int64(ndarray[int64_t] values):
665665

666666
return result_keys, result_counts
667667

668+
668669
include "hashtable.pyx"
669670
include "datetime.pyx"
670671
include "skiplist.pyx"

pandas/src/util.pxd

-1
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,3 @@ cdef inline bint _checknull(object val):
6060

6161
cdef inline bint _checknan(object val):
6262
return not cnp.PyArray_Check(val) and val != val
63-

0 commit comments

Comments
 (0)