@@ -810,6 +810,35 @@ def outer_join_indexer_%(name)s(ndarray[%(c_type)s] left,
810
810
811
811
"""
812
812
813
+ # ensure_dtype functions
814
+
815
+ ensure_dtype_template = """
816
+ cpdef ensure_%(name)s(object arr):
817
+ if util.is_array(arr):
818
+ if (<ndarray> arr).descr.type_num == NPY_%(ctype)s:
819
+ return arr
820
+ else:
821
+ return arr.astype(np.%(dtype)s)
822
+ else:
823
+ return np.array(arr, dtype=np.%(dtype)s)
824
+
825
+ """
826
+
827
+ ensure_functions = [
828
+ ('float64' , 'FLOAT64' , 'float64' ),
829
+ ('int32' , 'INT32' , 'int32' ),
830
+ ('int64' , 'INT64' , 'int64' ),
831
+ ('platform_int' , 'INT' , 'int_' ),
832
+ ('object' , 'OBJECT' , 'object_' ),
833
+ ]
834
+
835
+ def generate_ensure_dtypes ():
836
+ output = StringIO ()
837
+ for name , ctype , dtype in ensure_functions :
838
+ filled = ensure_dtype_template % locals ()
839
+ output .write (filled )
840
+ return output .getvalue ()
841
+
813
842
#----------------------------------------------------------------------
814
843
# Fast "put" logic for speeding up interleaving logic
815
844
@@ -916,6 +945,8 @@ def generate_take_cython_file(path='generated.pyx'):
916
945
for template in nobool_1d_templates :
917
946
print >> f , generate_from_template (template , exclude = ['bool' ])
918
947
948
+ print >> f , generate_ensure_dtypes ()
949
+
919
950
# print >> f, generate_put_functions()
920
951
921
952
if __name__ == '__main__' :
0 commit comments