Skip to content

Commit 08672e3

Browse files
committed
Merge pull request #3154 from jreback/ne2
ENH: added numexpr support for where operations
2 parents fe9d526 + 8bde194 commit 08672e3

File tree

4 files changed

+95
-20
lines changed

4 files changed

+95
-20
lines changed

pandas/core/expressions.py

+60-9
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515

1616
_USE_NUMEXPR = _NUMEXPR_INSTALLED
1717
_evaluate = None
18+
_where = None
1819

1920
# the set of dtypes that we will allow to be passed to numexpr
20-
_ALLOWED_DTYPES = set(['int64','int32','float64','float32','bool'])
21+
_ALLOWED_DTYPES = dict(evaluate = set(['int64','int32','float64','float32','bool']),
22+
where = set(['int64','float64','bool']))
2123

2224
# the minimum prod shape that we will use numexpr
2325
_MIN_ELEMENTS = 10000
@@ -26,17 +28,16 @@ def set_use_numexpr(v = True):
2628
# set/unset to use numexpr
2729
global _USE_NUMEXPR
2830
if _NUMEXPR_INSTALLED:
29-
#print "setting use_numexpr : was->%s, now->%s" % (_USE_NUMEXPR,v)
3031
_USE_NUMEXPR = v
3132

3233
# choose what we are going to do
33-
global _evaluate
34+
global _evaluate, _where
3435
if not _USE_NUMEXPR:
3536
_evaluate = _evaluate_standard
37+
_where = _where_standard
3638
else:
3739
_evaluate = _evaluate_numexpr
38-
39-
#print "evaluate -> %s" % _evaluate
40+
_where = _where_numexpr
4041

4142
def set_numexpr_threads(n = None):
4243
# if we are using numexpr, set the threads to n
@@ -54,7 +55,7 @@ def _evaluate_standard(op, op_str, a, b, raise_on_error=True):
5455
""" standard evaluation """
5556
return op(a,b)
5657

57-
def _can_use_numexpr(op, op_str, a, b):
58+
def _can_use_numexpr(op, op_str, a, b, dtype_check):
5859
""" return a boolean if we WILL be using numexpr """
5960
if op_str is not None:
6061

@@ -73,15 +74,15 @@ def _can_use_numexpr(op, op_str, a, b):
7374
dtypes |= set([o.dtype.name])
7475

7576
# allowed are a superset
76-
if not len(dtypes) or _ALLOWED_DTYPES >= dtypes:
77+
if not len(dtypes) or _ALLOWED_DTYPES[dtype_check] >= dtypes:
7778
return True
7879

7980
return False
8081

8182
def _evaluate_numexpr(op, op_str, a, b, raise_on_error = False):
8283
result = None
8384

84-
if _can_use_numexpr(op, op_str, a, b):
85+
if _can_use_numexpr(op, op_str, a, b, 'evaluate'):
8586
try:
8687
a_value, b_value = a, b
8788
if hasattr(a_value,'values'):
@@ -104,6 +105,40 @@ def _evaluate_numexpr(op, op_str, a, b, raise_on_error = False):
104105

105106
return result
106107

108+
def _where_standard(cond, a, b, raise_on_error=True):
    """ standard (numpy) where: take from a where cond holds, otherwise from b

    raise_on_error is accepted but not used by this implementation
    """
    selected = np.where(cond, a, b)
    return selected
110+
111+
def _where_numexpr(cond, a, b, raise_on_error=False):
    """ numexpr-backed where, falling back to _where_standard

    Parameters
    ----------
    cond : a boolean array (or an object exposing .values)
    a : values taken where cond is True
    b : values taken where cond is False
    raise_on_error : if True, a non-ValueError failure inside numexpr is
                     re-raised as a TypeError (default is False)

    Returns
    -------
    the result of ne.evaluate when numexpr was used and produced a result,
    otherwise the result of _where_standard
    """
    result = None

    if _can_use_numexpr(None, 'where', a, b, 'where'):

        try:
            # hand numexpr the underlying .values when the operand has them
            cond_value, a_value, b_value = cond, a, b
            if hasattr(cond_value, 'values'):
                cond_value = cond_value.values
            if hasattr(a_value, 'values'):
                a_value = a_value.values
            if hasattr(b_value, 'values'):
                b_value = b_value.values
            result = ne.evaluate('where(cond_value,a_value,b_value)',
                                 local_dict={'cond_value': cond_value,
                                             'a_value': a_value,
                                             'b_value': b_value},
                                 casting='safe')
        except ValueError:
            # e.g. 'unknown type object'; any ValueError leaves result as
            # None so that we fall back to the standard path below
            pass
        except Exception as detail:
            if raise_on_error:
                raise TypeError(str(detail))

    if result is None:
        result = _where_standard(cond, a, b, raise_on_error)

    return result
140+
141+
107142
# turn myself on
108143
set_use_numexpr(True)
109144

@@ -126,4 +161,20 @@ def evaluate(op, op_str, a, b, raise_on_error=False, use_numexpr=True):
126161
return _evaluate(op, op_str, a, b, raise_on_error=raise_on_error)
127162
return _evaluate_standard(op, op_str, a, b, raise_on_error=raise_on_error)
128163

129-
164+
def where(cond, a, b, raise_on_error=False, use_numexpr=True):
    """ evaluate the where condition cond on a and b

        Parameters
        ----------

        cond : a boolean array
        a : returned where cond is True
        b : returned where cond is False
        raise_on_error : passed through to the implementation that is called
                         (default is False)
        use_numexpr : if True (the default), call the currently selected _where
                      implementation; otherwise call _where_standard directly
        """
    implementation = _where if use_numexpr else _where_standard
    return implementation(cond, a, b, raise_on_error=raise_on_error)

pandas/core/frame.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -3721,7 +3721,8 @@ def combine_first(self, other):
37213721
-------
37223722
combined : DataFrame
37233723
"""
3724-
combiner = lambda x, y: np.where(isnull(x), y, x)
3724+
def combiner(x, y):
3725+
return expressions.where(isnull(x), y, x, raise_on_error=True)
37253726
return self.combine(other, combiner, overwrite=False)
37263727

37273728
def update(self, other, join='left', overwrite=True, filter_func=None,
@@ -3772,7 +3773,7 @@ def update(self, other, join='left', overwrite=True, filter_func=None,
37723773
else:
37733774
mask = notnull(this)
37743775

3775-
self[col] = np.where(mask, this, that)
3776+
self[col] = expressions.where(mask, this, that, raise_on_error=True)
37763777

37773778
#----------------------------------------------------------------------
37783779
# Misc methods

pandas/core/internals.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pandas.core.common as com
1111
import pandas.lib as lib
1212
import pandas.tslib as tslib
13+
import pandas.core.expressions as expressions
1314

1415
from pandas.tslib import Timestamp
1516
from pandas.util import py3compat
@@ -506,7 +507,7 @@ def func(c,v,o):
506507
return v
507508

508509
try:
509-
return np.where(c,v,o)
510+
return expressions.where(c, v, o, raise_on_error=True)
510511
except (Exception), detail:
511512
if raise_on_error:
512513
raise TypeError('Could not operate [%s] with block values [%s]'

pandas/tests/test_expressions.py

+30-8
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,19 @@ def setUp(self):
4646
def test_invalid(self):
4747

4848
# no op
49-
result = expr._can_use_numexpr(operator.add, None, self.frame, self.frame)
49+
result = expr._can_use_numexpr(operator.add, None, self.frame, self.frame, 'evaluate')
5050
self.assert_(result == False)
5151

5252
# mixed
53-
result = expr._can_use_numexpr(operator.add, '+', self.mixed, self.frame)
53+
result = expr._can_use_numexpr(operator.add, '+', self.mixed, self.frame, 'evaluate')
5454
self.assert_(result == False)
5555

5656
# min elements
57-
result = expr._can_use_numexpr(operator.add, '+', self.frame2, self.frame2)
57+
result = expr._can_use_numexpr(operator.add, '+', self.frame2, self.frame2, 'evaluate')
5858
self.assert_(result == False)
5959

6060
# ok, we only check on first part of expression
61-
result = expr._can_use_numexpr(operator.add, '+', self.frame, self.frame2)
61+
result = expr._can_use_numexpr(operator.add, '+', self.frame, self.frame2, 'evaluate')
6262
self.assert_(result == True)
6363

6464
def test_binary_ops(self):
@@ -70,14 +70,14 @@ def testit():
7070
for op, op_str in [('add','+'),('sub','-'),('mul','*'),('div','/'),('pow','**')]:
7171

7272
op = getattr(operator,op)
73-
result = expr._can_use_numexpr(op, op_str, f, f)
73+
result = expr._can_use_numexpr(op, op_str, f, f, 'evaluate')
7474
self.assert_(result == (not f._is_mixed_type))
7575

7676
result = expr.evaluate(op, op_str, f, f, use_numexpr=True)
7777
expected = expr.evaluate(op, op_str, f, f, use_numexpr=False)
7878
assert_array_equal(result,expected.values)
7979

80-
result = expr._can_use_numexpr(op, op_str, f2, f2)
80+
result = expr._can_use_numexpr(op, op_str, f2, f2, 'evaluate')
8181
self.assert_(result == False)
8282

8383

@@ -105,14 +105,14 @@ def testit():
105105

106106
op = getattr(operator,op)
107107

108-
result = expr._can_use_numexpr(op, op_str, f11, f12)
108+
result = expr._can_use_numexpr(op, op_str, f11, f12, 'evaluate')
109109
self.assert_(result == (not f11._is_mixed_type))
110110

111111
result = expr.evaluate(op, op_str, f11, f12, use_numexpr=True)
112112
expected = expr.evaluate(op, op_str, f11, f12, use_numexpr=False)
113113
assert_array_equal(result,expected.values)
114114

115-
result = expr._can_use_numexpr(op, op_str, f21, f22)
115+
result = expr._can_use_numexpr(op, op_str, f21, f22, 'evaluate')
116116
self.assert_(result == False)
117117

118118
expr.set_use_numexpr(False)
@@ -123,6 +123,28 @@ def testit():
123123
expr.set_numexpr_threads()
124124
testit()
125125

126+
def test_where(self):
    # expr.where must give the same answer as np.where for all-True and
    # all-False masks on every fixture frame

    def check_all_frames():
        frames = [self.frame, self.frame2, self.mixed, self.mixed2]
        for frame in frames:
            for flag in (True, False):
                mask = np.empty(frame.shape, dtype=np.bool_)
                mask.fill(flag)

                chosen = frame.values
                alternative = frame.values + 1
                result = expr.where(mask, chosen, alternative)
                expected = np.where(mask, chosen, alternative)
                assert_array_equal(result, expected)

    # numexpr switched off
    expr.set_use_numexpr(False)
    check_all_frames()

    # numexpr switched on, threads set to 1
    expr.set_use_numexpr(True)
    expr.set_numexpr_threads(1)
    check_all_frames()

    # numexpr still on, set_numexpr_threads called with no argument
    expr.set_numexpr_threads()
    check_all_frames()
147+
126148
if __name__ == '__main__':
127149
# unittest.main()
128150
import nose

0 commit comments

Comments
 (0)