@@ -54,14 +54,12 @@ def run_arithmetic(self, df, other):
54
54
operations = ["add" , "sub" , "mul" , "mod" , "truediv" , "floordiv" ]
55
55
for test_flex in [True , False ]:
56
56
for arith in operations :
57
-
58
- operator_name = arith
59
-
57
+ # TODO: share with run_binary
60
58
if test_flex :
61
59
op = lambda x , y : getattr (x , arith )(y )
62
60
op .__name__ = arith
63
61
else :
64
- op = getattr (operator , operator_name )
62
+ op = getattr (operator , arith )
65
63
expr .set_use_numexpr (False )
66
64
expected = op (df , other )
67
65
expr .set_use_numexpr (True )
@@ -87,13 +85,14 @@ def run_binary(self, df, other):
87
85
for test_flex in [True , False ]:
88
86
for arith in operations :
89
87
if test_flex :
90
- op = lambda x , y : getattr (df , arith )(y )
88
+ op = lambda x , y : getattr (x , arith )(y )
91
89
op .__name__ = arith
92
90
else :
93
91
op = getattr (operator , arith )
94
92
expr .set_use_numexpr (False )
95
93
expected = op (df , other )
96
94
expr .set_use_numexpr (True )
95
+
97
96
expr .get_test_result ()
98
97
result = op (df , other )
99
98
used_numexpr = expr .get_test_result ()
@@ -167,29 +166,29 @@ def test_invalid(self):
167
166
"opname,op_str" ,
168
167
[("add" , "+" ), ("sub" , "-" ), ("mul" , "*" ), ("truediv" , "/" ), ("pow" , "**" )],
169
168
)
170
- def test_binary_ops (self , opname , op_str ):
169
+ @pytest .mark .parametrize ("left,right" , [(_frame , _frame2 ), (_mixed , _mixed2 )])
170
+ def test_binary_ops (self , opname , op_str , left , right ):
171
171
def testit ():
172
172
173
- for f , f2 in [(self .frame , self .frame2 ), (self .mixed , self .mixed2 )]:
173
+ if opname == "pow" :
174
+ # TODO: get this working
175
+ return
174
176
175
- if opname == "pow" :
176
- continue
177
+ op = getattr (operator , opname )
177
178
178
- op = getattr (operator , opname )
179
+ result = expr ._can_use_numexpr (op , op_str , left , left , "evaluate" )
180
+ assert result != left ._is_mixed_type
179
181
180
- result = expr ._can_use_numexpr (op , op_str , f , f , "evaluate" )
181
- assert result != f . _is_mixed_type
182
+ result = expr .evaluate (op , op_str , left , left , use_numexpr = True )
183
+ expected = expr . evaluate ( op , op_str , left , left , use_numexpr = False )
182
184
183
- result = expr .evaluate (op , op_str , f , f , use_numexpr = True )
184
- expected = expr .evaluate (op , op_str , f , f , use_numexpr = False )
185
+ if isinstance (result , DataFrame ):
186
+ tm .assert_frame_equal (result , expected )
187
+ else :
188
+ tm .assert_numpy_array_equal (result , expected .values )
185
189
186
- if isinstance (result , DataFrame ):
187
- tm .assert_frame_equal (result , expected )
188
- else :
189
- tm .assert_numpy_array_equal (result , expected .values )
190
-
191
- result = expr ._can_use_numexpr (op , op_str , f2 , f2 , "evaluate" )
192
- assert not result
190
+ result = expr ._can_use_numexpr (op , op_str , right , right , "evaluate" )
191
+ assert not result
193
192
194
193
expr .set_use_numexpr (False )
195
194
testit ()
@@ -210,30 +209,26 @@ def testit():
210
209
("ne" , "!=" ),
211
210
],
212
211
)
213
- def test_comparison_ops (self , opname , op_str ):
212
+ @pytest .mark .parametrize ("left,right" , [(_frame , _frame2 ), (_mixed , _mixed2 )])
213
+ def test_comparison_ops (self , opname , op_str , left , right ):
214
214
def testit ():
215
- for f , f2 in [(self .frame , self .frame2 ), (self .mixed , self .mixed2 )]:
216
-
217
- f11 = f
218
- f12 = f + 1
215
+ f12 = left + 1
216
+ f22 = right + 1
219
217
220
- f21 = f2
221
- f22 = f2 + 1
218
+ op = getattr (operator , opname )
222
219
223
- op = getattr (operator , opname )
220
+ result = expr ._can_use_numexpr (op , op_str , left , f12 , "evaluate" )
221
+ assert result != left ._is_mixed_type
224
222
225
- result = expr ._can_use_numexpr (op , op_str , f11 , f12 , "evaluate" )
226
- assert result != f11 ._is_mixed_type
223
+ result = expr .evaluate (op , op_str , left , f12 , use_numexpr = True )
224
+ expected = expr .evaluate (op , op_str , left , f12 , use_numexpr = False )
225
+ if isinstance (result , DataFrame ):
226
+ tm .assert_frame_equal (result , expected )
227
+ else :
228
+ tm .assert_numpy_array_equal (result , expected .values )
227
229
228
- result = expr .evaluate (op , op_str , f11 , f12 , use_numexpr = True )
229
- expected = expr .evaluate (op , op_str , f11 , f12 , use_numexpr = False )
230
- if isinstance (result , DataFrame ):
231
- tm .assert_frame_equal (result , expected )
232
- else :
233
- tm .assert_numpy_array_equal (result , expected .values )
234
-
235
- result = expr ._can_use_numexpr (op , op_str , f21 , f22 , "evaluate" )
236
- assert not result
230
+ result = expr ._can_use_numexpr (op , op_str , right , f22 , "evaluate" )
231
+ assert not result
237
232
238
233
expr .set_use_numexpr (False )
239
234
testit ()
@@ -244,15 +239,14 @@ def testit():
244
239
testit ()
245
240
246
241
@pytest .mark .parametrize ("cond" , [True , False ])
247
- def test_where (self , cond ):
242
+ @pytest .mark .parametrize ("df" , [_frame , _frame2 , _mixed , _mixed2 ])
243
+ def test_where (self , cond , df ):
248
244
def testit ():
249
- for f in [self .frame , self .frame2 , self .mixed , self .mixed2 ]:
250
-
251
- c = np .empty (f .shape , dtype = np .bool_ )
252
- c .fill (cond )
253
- result = expr .where (c , f .values , f .values + 1 )
254
- expected = np .where (c , f .values , f .values + 1 )
255
- tm .assert_numpy_array_equal (result , expected )
245
+ c = np .empty (df .shape , dtype = np .bool_ )
246
+ c .fill (cond )
247
+ result = expr .where (c , df .values , df .values + 1 )
248
+ expected = np .where (c , df .values , df .values + 1 )
249
+ tm .assert_numpy_array_equal (result , expected )
256
250
257
251
expr .set_use_numexpr (False )
258
252
testit ()
@@ -263,7 +257,7 @@ def testit():
263
257
testit ()
264
258
265
259
@pytest .mark .parametrize (
266
- "op_str,opname" , list ( zip ([ "/" , "//" , "**" ], [ "truediv" , "floordiv " , "pow" ]))
260
+ "op_str,opname" , [( "/" , "truediv" ), ( "//" , "floordiv" ), ( "** " , "pow" )]
267
261
)
268
262
def test_bool_ops_raise_on_arithmetic (self , op_str , opname ):
269
263
df = DataFrame ({"a" : np .random .rand (10 ) > 0.5 , "b" : np .random .rand (10 ) > 0.5 })
@@ -291,7 +285,7 @@ def test_bool_ops_raise_on_arithmetic(self, op_str, opname):
291
285
f (df , True )
292
286
293
287
@pytest .mark .parametrize (
294
- "op_str,opname" , list ( zip ([ "+" , "*" , "-" ], [ "add" , "mul " , "sub" ]))
288
+ "op_str,opname" , [( "+" , "add" ), ( "*" , "mul" ), ( "- " , "sub" )]
295
289
)
296
290
def test_bool_ops_warn_on_arithmetic (self , op_str , opname ):
297
291
n = 10
0 commit comments