15
15
except ImportError : # pragma: no cover
16
16
_NUMEXPR_INSTALLED = False
17
17
18
+ _TEST_MODE = None
19
+ _TEST_RESULT = None
18
20
_USE_NUMEXPR = _NUMEXPR_INSTALLED
19
21
_evaluate = None
20
22
_where = None
@@ -55,9 +57,10 @@ def set_numexpr_threads(n=None):
55
57
56
58
def _evaluate_standard (op , op_str , a , b , raise_on_error = True , ** eval_kwargs ):
57
59
""" standard evaluation """
60
+ if _TEST_MODE :
61
+ _store_test_result (False )
58
62
return op (a , b )
59
63
60
-
61
64
def _can_use_numexpr (op , op_str , a , b , dtype_check ):
62
65
""" return a boolean if we WILL be using numexpr """
63
66
if op_str is not None :
@@ -88,11 +91,8 @@ def _evaluate_numexpr(op, op_str, a, b, raise_on_error=False, **eval_kwargs):
88
91
89
92
if _can_use_numexpr (op , op_str , a , b , 'evaluate' ):
90
93
try :
91
- a_value , b_value = a , b
92
- if hasattr (a_value , 'values' ):
93
- a_value = a_value .values
94
- if hasattr (b_value , 'values' ):
95
- b_value = b_value .values
94
+ a_value = getattr (a , "values" , a )
95
+ b_value = getattr (b , "values" , b )
96
96
result = ne .evaluate ('a_value %s b_value' % op_str ,
97
97
local_dict = {'a_value' : a_value ,
98
98
'b_value' : b_value },
@@ -104,6 +104,9 @@ def _evaluate_numexpr(op, op_str, a, b, raise_on_error=False, **eval_kwargs):
104
104
if raise_on_error :
105
105
raise
106
106
107
+ if _TEST_MODE :
108
+ _store_test_result (result is not None )
109
+
107
110
if result is None :
108
111
result = _evaluate_standard (op , op_str , a , b , raise_on_error )
109
112
@@ -119,13 +122,9 @@ def _where_numexpr(cond, a, b, raise_on_error=False):
119
122
if _can_use_numexpr (None , 'where' , a , b , 'where' ):
120
123
121
124
try :
122
- cond_value , a_value , b_value = cond , a , b
123
- if hasattr (cond_value , 'values' ):
124
- cond_value = cond_value .values
125
- if hasattr (a_value , 'values' ):
126
- a_value = a_value .values
127
- if hasattr (b_value , 'values' ):
128
- b_value = b_value .values
125
+ cond_value = getattr (cond , 'values' , cond )
126
+ a_value = getattr (a , 'values' , a )
127
+ b_value = getattr (b , 'values' , b )
129
128
result = ne .evaluate ('where(cond_value, a_value, b_value)' ,
130
129
local_dict = {'cond_value' : cond_value ,
131
130
'a_value' : a_value ,
@@ -189,3 +188,28 @@ def where(cond, a, b, raise_on_error=False, use_numexpr=True):
189
188
if use_numexpr :
190
189
return _where (cond , a , b , raise_on_error = raise_on_error )
191
190
return _where_standard (cond , a , b , raise_on_error = raise_on_error )
191
+
192
+
193
+ def set_test_mode (v = True ):
194
+ """
195
+ Keeps track of whether numexpr was used. Stores an additional ``True`` for
196
+ every successful use of evaluate with numexpr since the last
197
+ ``get_test_result``
198
+ """
199
+ global _TEST_MODE , _TEST_RESULT
200
+ _TEST_MODE = v
201
+ _TEST_RESULT = []
202
+
203
+
204
+ def _store_test_result (used_numexpr ):
205
+ global _TEST_RESULT
206
+ if used_numexpr :
207
+ _TEST_RESULT .append (used_numexpr )
208
+
209
+
210
+ def get_test_result ():
211
+ """get test result and reset test_results"""
212
+ global _TEST_RESULT
213
+ res = _TEST_RESULT
214
+ _TEST_RESULT = []
215
+ return res
0 commit comments