15
15
16
16
_USE_NUMEXPR = _NUMEXPR_INSTALLED
17
17
_evaluate = None
18
+ _where = None
18
19
19
20
# the set of dtypes that we will allow pass to numexpr
20
- _ALLOWED_DTYPES = set (['int64' ,'int32' ,'float64' ,'float32' ,'bool' ])
21
+ _ALLOWED_DTYPES = dict (evaluate = set (['int64' ,'int32' ,'float64' ,'float32' ,'bool' ]),
22
+ where = set (['int64' ,'float64' ,'bool' ]))
21
23
22
24
# the minimum prod shape that we will use numexpr
23
25
_MIN_ELEMENTS = 10000
@@ -26,17 +28,16 @@ def set_use_numexpr(v = True):
26
28
# set/unset to use numexpr
27
29
global _USE_NUMEXPR
28
30
if _NUMEXPR_INSTALLED :
29
- #print "setting use_numexpr : was->%s, now->%s" % (_USE_NUMEXPR,v)
30
31
_USE_NUMEXPR = v
31
32
32
33
# choose what we are going to do
33
- global _evaluate
34
+ global _evaluate , _where
34
35
if not _USE_NUMEXPR :
35
36
_evaluate = _evaluate_standard
37
+ _where = _where_standard
36
38
else :
37
39
_evaluate = _evaluate_numexpr
38
-
39
- #print "evaluate -> %s" % _evaluate
40
+ _where = _where_numexpr
40
41
41
42
def set_numexpr_threads (n = None ):
42
43
# if we are using numexpr, set the threads to n
@@ -54,7 +55,7 @@ def _evaluate_standard(op, op_str, a, b, raise_on_error=True):
54
55
""" standard evaluation """
55
56
return op (a ,b )
56
57
57
- def _can_use_numexpr (op , op_str , a , b ):
58
+ def _can_use_numexpr (op , op_str , a , b , dtype_check ):
58
59
""" return a boolean if we WILL be using numexpr """
59
60
if op_str is not None :
60
61
@@ -73,15 +74,15 @@ def _can_use_numexpr(op, op_str, a, b):
73
74
dtypes |= set ([o .dtype .name ])
74
75
75
76
# allowed are a superset
76
- if not len (dtypes ) or _ALLOWED_DTYPES >= dtypes :
77
+ if not len (dtypes ) or _ALLOWED_DTYPES [ dtype_check ] >= dtypes :
77
78
return True
78
79
79
80
return False
80
81
81
82
def _evaluate_numexpr (op , op_str , a , b , raise_on_error = False ):
82
83
result = None
83
84
84
- if _can_use_numexpr (op , op_str , a , b ):
85
+ if _can_use_numexpr (op , op_str , a , b , 'evaluate' ):
85
86
try :
86
87
a_value , b_value = a , b
87
88
if hasattr (a_value ,'values' ):
@@ -104,6 +105,40 @@ def _evaluate_numexpr(op, op_str, a, b, raise_on_error = False):
104
105
105
106
return result
106
107
108
+ def _where_standard (cond , a , b , raise_on_error = True ):
109
+ return np .where (cond , a , b )
110
+
111
+ def _where_numexpr (cond , a , b , raise_on_error = False ):
112
+ result = None
113
+
114
+ if _can_use_numexpr (None , 'where' , a , b , 'where' ):
115
+
116
+ try :
117
+ cond_value , a_value , b_value = cond , a , b
118
+ if hasattr (cond_value ,'values' ):
119
+ cond_value = cond_value .values
120
+ if hasattr (a_value ,'values' ):
121
+ a_value = a_value .values
122
+ if hasattr (b_value ,'values' ):
123
+ b_value = b_value .values
124
+ result = ne .evaluate ('where(cond_value,a_value,b_value)' ,
125
+ local_dict = { 'cond_value' : cond_value ,
126
+ 'a_value' : a_value ,
127
+ 'b_value' : b_value },
128
+ casting = 'safe' )
129
+ except (ValueError ), detail :
130
+ if 'unknown type object' in str (detail ):
131
+ pass
132
+ except (Exception ), detail :
133
+ if raise_on_error :
134
+ raise TypeError (str (detail ))
135
+
136
+ if result is None :
137
+ result = _where_standard (cond ,a ,b ,raise_on_error )
138
+
139
+ return result
140
+
141
+
107
142
# turn myself on
108
143
set_use_numexpr (True )
109
144
@@ -126,4 +161,20 @@ def evaluate(op, op_str, a, b, raise_on_error=False, use_numexpr=True):
126
161
return _evaluate (op , op_str , a , b , raise_on_error = raise_on_error )
127
162
return _evaluate_standard (op , op_str , a , b , raise_on_error = raise_on_error )
128
163
129
-
164
+ def where (cond , a , b , raise_on_error = False , use_numexpr = True ):
165
+ """ evaluate the where condition cond on a and b
166
+
167
+ Parameters
168
+ ----------
169
+
170
+ cond : a boolean array
171
+ a : return if cond is True
172
+ b : return if cond is False
173
+ raise_on_error : pass the error to the higher level if indicated (default is False),
174
+ otherwise evaluate the op with and return the results
175
+ use_numexpr : whether to try to use numexpr (default True)
176
+ """
177
+
178
+ if use_numexpr :
179
+ return _where (cond , a , b , raise_on_error = raise_on_error )
180
+ return _where_standard (cond , a , b , raise_on_error = raise_on_error )
0 commit comments