18
18
from pandas .tests .frame .common import zip_frames
19
19
20
20
21
+ @pytest .fixture (params = ["python" , "numba" ])
22
+ def engine (request ):
23
+ if request .param == "numba" :
24
+ pytest .importorskip ("numba" )
25
+ return request .param
26
+
27
+
21
28
def test_apply (float_frame ):
22
29
with np .errstate (all = "ignore" ):
23
30
# ufunc
@@ -234,36 +241,42 @@ def test_apply_broadcast_series_lambda_func(int_frame_const_col):
234
241
235
242
236
243
@pytest .mark .parametrize ("axis" , [0 , 1 ])
237
- def test_apply_raw_float_frame (float_frame , axis ):
244
+ def test_apply_raw_float_frame (float_frame , axis , engine ):
245
+ if engine == "numba" :
246
+ pytest .skip ("numba can't handle when UDF returns None." )
247
+
238
248
def _assert_raw (x ):
239
249
assert isinstance (x , np .ndarray )
240
250
assert x .ndim == 1
241
251
242
- float_frame .apply (_assert_raw , axis = axis , raw = True )
252
+ float_frame .apply (_assert_raw , axis = axis , engine = engine , raw = True )
243
253
244
254
245
255
@pytest .mark .parametrize ("axis" , [0 , 1 ])
246
- def test_apply_raw_float_frame_lambda (float_frame , axis ):
247
- result = float_frame .apply (np .mean , axis = axis , raw = True )
256
+ def test_apply_raw_float_frame_lambda (float_frame , axis , engine ):
257
+ result = float_frame .apply (np .mean , axis = axis , engine = engine , raw = True )
248
258
expected = float_frame .apply (lambda x : x .values .mean (), axis = axis )
249
259
tm .assert_series_equal (result , expected )
250
260
251
261
252
- def test_apply_raw_float_frame_no_reduction (float_frame ):
262
+ def test_apply_raw_float_frame_no_reduction (float_frame , engine ):
253
263
# no reduction
254
- result = float_frame .apply (lambda x : x * 2 , raw = True )
264
+ result = float_frame .apply (lambda x : x * 2 , engine = engine , raw = True )
255
265
expected = float_frame * 2
256
266
tm .assert_frame_equal (result , expected )
257
267
258
268
259
269
@pytest .mark .parametrize ("axis" , [0 , 1 ])
260
- def test_apply_raw_mixed_type_frame (mixed_type_frame , axis ):
270
+ def test_apply_raw_mixed_type_frame (mixed_type_frame , axis , engine ):
271
+ if engine == "numba" :
272
+ pytest .skip ("isinstance check doesn't work with numba" )
273
+
261
274
def _assert_raw (x ):
262
275
assert isinstance (x , np .ndarray )
263
276
assert x .ndim == 1
264
277
265
278
# Mixed dtype (GH-32423)
266
- mixed_type_frame .apply (_assert_raw , axis = axis , raw = True )
279
+ mixed_type_frame .apply (_assert_raw , axis = axis , engine = engine , raw = True )
267
280
268
281
269
282
def test_apply_axis1 (float_frame ):
@@ -300,14 +313,20 @@ def test_apply_mixed_dtype_corner_indexing():
300
313
)
301
314
@pytest .mark .parametrize ("raw" , [True , False ])
302
315
@pytest .mark .parametrize ("axis" , [0 , 1 ])
303
- def test_apply_empty_infer_type (ax , func , raw , axis ):
316
+ def test_apply_empty_infer_type (ax , func , raw , axis , engine , request ):
304
317
df = DataFrame (** {ax : ["a" , "b" , "c" ]})
305
318
306
319
with np .errstate (all = "ignore" ):
307
320
test_res = func (np .array ([], dtype = "f8" ))
308
321
is_reduction = not isinstance (test_res , np .ndarray )
309
322
310
- result = df .apply (func , axis = axis , raw = raw )
323
+ if engine == "numba" and raw is False :
324
+ mark = pytest .mark .xfail (
325
+ reason = "numba engine only supports raw=True at the moment"
326
+ )
327
+ request .node .add_marker (mark )
328
+
329
+ result = df .apply (func , axis = axis , engine = engine , raw = raw )
311
330
if is_reduction :
312
331
agg_axis = df ._get_agg_axis (axis )
313
332
assert isinstance (result , Series )
@@ -607,8 +626,10 @@ def non_reducing_function(row):
607
626
assert names == list (df .index )
608
627
609
628
610
- def test_apply_raw_function_runs_once ():
629
+ def test_apply_raw_function_runs_once (engine ):
611
630
# https://github.com/pandas-dev/pandas/issues/34506
631
+ if engine == "numba" :
632
+ pytest .skip ("appending to list outside of numba func is not supported" )
612
633
613
634
df = DataFrame ({"a" : [1 , 2 , 3 ]})
614
635
values = [] # Save row values function is applied to
@@ -623,7 +644,7 @@ def non_reducing_function(row):
623
644
for func in [reducing_function , non_reducing_function ]:
624
645
del values [:]
625
646
626
- df .apply (func , raw = True , axis = 1 )
647
+ df .apply (func , engine = engine , raw = True , axis = 1 )
627
648
assert values == list (df .a .to_list ())
628
649
629
650
@@ -1449,10 +1470,12 @@ def test_apply_no_suffix_index():
1449
1470
tm .assert_frame_equal (result , expected )
1450
1471
1451
1472
1452
- def test_apply_raw_returns_string ():
1473
+ def test_apply_raw_returns_string (engine ):
1453
1474
# https://github.com/pandas-dev/pandas/issues/35940
1475
+ if engine == "numba" :
1476
+ pytest .skip ("No object dtype support in numba" )
1454
1477
df = DataFrame ({"A" : ["aa" , "bbb" ]})
1455
- result = df .apply (lambda x : x [0 ], axis = 1 , raw = True )
1478
+ result = df .apply (lambda x : x [0 ], engine = engine , axis = 1 , raw = True )
1456
1479
expected = Series (["aa" , "bbb" ])
1457
1480
tm .assert_series_equal (result , expected )
1458
1481
@@ -1632,3 +1655,14 @@ def test_agg_dist_like_and_nonunique_columns():
1632
1655
result = df .agg ({"A" : "count" })
1633
1656
expected = df ["A" ].count ()
1634
1657
tm .assert_series_equal (result , expected )
1658
+
1659
+
1660
+ def test_numba_unsupported ():
1661
+ df = DataFrame (
1662
+ {"A" : [None , 2 , 3 ], "B" : [1.0 , np .nan , 3.0 ], "C" : ["foo" , None , "bar" ]}
1663
+ )
1664
+ with pytest .raises (
1665
+ ValueError ,
1666
+ match = "The numba engine in DataFrame.apply can only be used when raw=True" ,
1667
+ ):
1668
+ df .apply (lambda x : x , engine = "numba" , raw = False )
0 commit comments