19
19
has_c16 = hasattr (np , "complex128" )
20
20
21
21
22
+ @pytest .fixture (params = [True , False ])
23
+ def skipna (request ):
24
+ """
25
+ Fixture to pass skipna to nanops functions.
26
+ """
27
+ return request .param
28
+
29
+
22
30
class TestnanopsDataFrame :
23
31
def setup_method (self , method ):
24
32
np .random .seed (11235 )
@@ -89,38 +97,22 @@ def teardown_method(self, method):
89
97
90
98
def check_results (self , targ , res , axis , check_dtype = True ):
91
99
res = getattr (res , "asm8" , res )
92
- res = getattr (res , "values" , res )
93
-
94
- # timedeltas are a beast here
95
- def _coerce_tds (targ , res ):
96
- if hasattr (targ , "dtype" ) and targ .dtype == "m8[ns]" :
97
- if len (targ ) == 1 :
98
- targ = targ [0 ].item ()
99
- res = res .item ()
100
- else :
101
- targ = targ .view ("i8" )
102
- return targ , res
103
100
104
- try :
105
- if (
106
- axis != 0
107
- and hasattr (targ , "shape" )
108
- and targ .ndim
109
- and targ .shape != res .shape
110
- ):
111
- res = np .split (res , [targ .shape [0 ]], axis = 0 )[0 ]
112
- except (ValueError , IndexError ):
113
- targ , res = _coerce_tds (targ , res )
101
+ if (
102
+ axis != 0
103
+ and hasattr (targ , "shape" )
104
+ and targ .ndim
105
+ and targ .shape != res .shape
106
+ ):
107
+ res = np .split (res , [targ .shape [0 ]], axis = 0 )[0 ]
114
108
115
109
try :
116
110
tm .assert_almost_equal (targ , res , check_dtype = check_dtype )
117
111
except AssertionError :
118
112
119
113
# handle timedelta dtypes
120
114
if hasattr (targ , "dtype" ) and targ .dtype == "m8[ns]" :
121
- targ , res = _coerce_tds (targ , res )
122
- tm .assert_almost_equal (targ , res , check_dtype = check_dtype )
123
- return
115
+ raise
124
116
125
117
# There are sometimes rounding errors with
126
118
# complex and object dtypes.
@@ -149,29 +141,29 @@ def check_fun_data(
149
141
targfunc ,
150
142
testarval ,
151
143
targarval ,
144
+ skipna ,
152
145
check_dtype = True ,
153
146
empty_targfunc = None ,
154
147
** kwargs ,
155
148
):
156
149
for axis in list (range (targarval .ndim )) + [None ]:
157
- for skipna in [False , True ]:
158
- targartempval = targarval if skipna else testarval
159
- if skipna and empty_targfunc and isna (targartempval ).all ():
160
- targ = empty_targfunc (targartempval , axis = axis , ** kwargs )
161
- else :
162
- targ = targfunc (targartempval , axis = axis , ** kwargs )
150
+ targartempval = targarval if skipna else testarval
151
+ if skipna and empty_targfunc and isna (targartempval ).all ():
152
+ targ = empty_targfunc (targartempval , axis = axis , ** kwargs )
153
+ else :
154
+ targ = targfunc (targartempval , axis = axis , ** kwargs )
163
155
164
- res = testfunc (testarval , axis = axis , skipna = skipna , ** kwargs )
156
+ res = testfunc (testarval , axis = axis , skipna = skipna , ** kwargs )
157
+ self .check_results (targ , res , axis , check_dtype = check_dtype )
158
+ if skipna :
159
+ res = testfunc (testarval , axis = axis , ** kwargs )
160
+ self .check_results (targ , res , axis , check_dtype = check_dtype )
161
+ if axis is None :
162
+ res = testfunc (testarval , skipna = skipna , ** kwargs )
163
+ self .check_results (targ , res , axis , check_dtype = check_dtype )
164
+ if skipna and axis is None :
165
+ res = testfunc (testarval , ** kwargs )
165
166
self .check_results (targ , res , axis , check_dtype = check_dtype )
166
- if skipna :
167
- res = testfunc (testarval , axis = axis , ** kwargs )
168
- self .check_results (targ , res , axis , check_dtype = check_dtype )
169
- if axis is None :
170
- res = testfunc (testarval , skipna = skipna , ** kwargs )
171
- self .check_results (targ , res , axis , check_dtype = check_dtype )
172
- if skipna and axis is None :
173
- res = testfunc (testarval , ** kwargs )
174
- self .check_results (targ , res , axis , check_dtype = check_dtype )
175
167
176
168
if testarval .ndim <= 1 :
177
169
return
@@ -184,12 +176,15 @@ def check_fun_data(
184
176
targfunc ,
185
177
testarval2 ,
186
178
targarval2 ,
179
+ skipna = skipna ,
187
180
check_dtype = check_dtype ,
188
181
empty_targfunc = empty_targfunc ,
189
182
** kwargs ,
190
183
)
191
184
192
- def check_fun (self , testfunc , targfunc , testar , empty_targfunc = None , ** kwargs ):
185
+ def check_fun (
186
+ self , testfunc , targfunc , testar , skipna , empty_targfunc = None , ** kwargs
187
+ ):
193
188
194
189
targar = testar
195
190
if testar .endswith ("_nan" ) and hasattr (self , testar [:- 4 ]):
@@ -202,6 +197,7 @@ def check_fun(self, testfunc, targfunc, testar, empty_targfunc=None, **kwargs):
202
197
targfunc ,
203
198
testarval ,
204
199
targarval ,
200
+ skipna = skipna ,
205
201
empty_targfunc = empty_targfunc ,
206
202
** kwargs ,
207
203
)
@@ -210,36 +206,37 @@ def check_funs(
210
206
self ,
211
207
testfunc ,
212
208
targfunc ,
209
+ skipna ,
213
210
allow_complex = True ,
214
211
allow_all_nan = True ,
215
212
allow_date = True ,
216
213
allow_tdelta = True ,
217
214
allow_obj = True ,
218
215
** kwargs ,
219
216
):
220
- self .check_fun (testfunc , targfunc , "arr_float" , ** kwargs )
221
- self .check_fun (testfunc , targfunc , "arr_float_nan" , ** kwargs )
222
- self .check_fun (testfunc , targfunc , "arr_int" , ** kwargs )
223
- self .check_fun (testfunc , targfunc , "arr_bool" , ** kwargs )
217
+ self .check_fun (testfunc , targfunc , "arr_float" , skipna , ** kwargs )
218
+ self .check_fun (testfunc , targfunc , "arr_float_nan" , skipna , ** kwargs )
219
+ self .check_fun (testfunc , targfunc , "arr_int" , skipna , ** kwargs )
220
+ self .check_fun (testfunc , targfunc , "arr_bool" , skipna , ** kwargs )
224
221
objs = [
225
222
self .arr_float .astype ("O" ),
226
223
self .arr_int .astype ("O" ),
227
224
self .arr_bool .astype ("O" ),
228
225
]
229
226
230
227
if allow_all_nan :
231
- self .check_fun (testfunc , targfunc , "arr_nan" , ** kwargs )
228
+ self .check_fun (testfunc , targfunc , "arr_nan" , skipna , ** kwargs )
232
229
233
230
if allow_complex :
234
- self .check_fun (testfunc , targfunc , "arr_complex" , ** kwargs )
235
- self .check_fun (testfunc , targfunc , "arr_complex_nan" , ** kwargs )
231
+ self .check_fun (testfunc , targfunc , "arr_complex" , skipna , ** kwargs )
232
+ self .check_fun (testfunc , targfunc , "arr_complex_nan" , skipna , ** kwargs )
236
233
if allow_all_nan :
237
- self .check_fun (testfunc , targfunc , "arr_nan_nanj" , ** kwargs )
234
+ self .check_fun (testfunc , targfunc , "arr_nan_nanj" , skipna , ** kwargs )
238
235
objs += [self .arr_complex .astype ("O" )]
239
236
240
237
if allow_date :
241
238
targfunc (self .arr_date )
242
- self .check_fun (testfunc , targfunc , "arr_date" , ** kwargs )
239
+ self .check_fun (testfunc , targfunc , "arr_date" , skipna , ** kwargs )
243
240
objs += [self .arr_date .astype ("O" )]
244
241
245
242
if allow_tdelta :
@@ -248,7 +245,7 @@ def check_funs(
248
245
except TypeError :
249
246
pass
250
247
else :
251
- self .check_fun (testfunc , targfunc , "arr_tdelta" , ** kwargs )
248
+ self .check_fun (testfunc , targfunc , "arr_tdelta" , skipna , ** kwargs )
252
249
objs += [self .arr_tdelta .astype ("O" )]
253
250
254
251
if allow_obj :
@@ -260,7 +257,7 @@ def check_funs(
260
257
targfunc = partial (
261
258
self ._badobj_wrap , func = targfunc , allow_complex = allow_complex
262
259
)
263
- self .check_fun (testfunc , targfunc , "arr_obj" , ** kwargs )
260
+ self .check_fun (testfunc , targfunc , "arr_obj" , skipna , ** kwargs )
264
261
265
262
def _badobj_wrap (self , value , func , allow_complex = True , ** kwargs ):
266
263
if value .dtype .kind == "O" :
@@ -273,28 +270,22 @@ def _badobj_wrap(self, value, func, allow_complex=True, **kwargs):
273
270
@pytest .mark .parametrize (
274
271
"nan_op,np_op" , [(nanops .nanany , np .any ), (nanops .nanall , np .all )]
275
272
)
276
- def test_nan_funcs (self , nan_op , np_op ):
277
- # TODO: allow tdelta, doesn't break tests
278
- self .check_funs (
279
- nan_op , np_op , allow_all_nan = False , allow_date = False , allow_tdelta = False
280
- )
273
+ def test_nan_funcs (self , nan_op , np_op , skipna ):
274
+ self .check_funs (nan_op , np_op , skipna , allow_all_nan = False , allow_date = False )
281
275
282
- def test_nansum (self ):
276
+ def test_nansum (self , skipna ):
283
277
self .check_funs (
284
278
nanops .nansum ,
285
279
np .sum ,
280
+ skipna ,
286
281
allow_date = False ,
287
282
check_dtype = False ,
288
283
empty_targfunc = np .nansum ,
289
284
)
290
285
291
- def test_nanmean (self ):
286
+ def test_nanmean (self , skipna ):
292
287
self .check_funs (
293
- nanops .nanmean ,
294
- np .mean ,
295
- allow_complex = False , # TODO: allow this, doesn't break test
296
- allow_obj = False ,
297
- allow_date = False ,
288
+ nanops .nanmean , np .mean , skipna , allow_obj = False , allow_date = False ,
298
289
)
299
290
300
291
def test_nanmean_overflow (self ):
@@ -336,33 +327,36 @@ def test_returned_dtype(self, dtype):
336
327
else :
337
328
assert result .dtype == dtype
338
329
339
- def test_nanmedian (self ):
330
+ def test_nanmedian (self , skipna ):
340
331
with warnings .catch_warnings (record = True ):
341
332
warnings .simplefilter ("ignore" , RuntimeWarning )
342
333
self .check_funs (
343
334
nanops .nanmedian ,
344
335
np .median ,
336
+ skipna ,
345
337
allow_complex = False ,
346
338
allow_date = False ,
347
339
allow_obj = "convert" ,
348
340
)
349
341
350
342
@pytest .mark .parametrize ("ddof" , range (3 ))
351
- def test_nanvar (self , ddof ):
343
+ def test_nanvar (self , ddof , skipna ):
352
344
self .check_funs (
353
345
nanops .nanvar ,
354
346
np .var ,
347
+ skipna ,
355
348
allow_complex = False ,
356
349
allow_date = False ,
357
350
allow_obj = "convert" ,
358
351
ddof = ddof ,
359
352
)
360
353
361
354
@pytest .mark .parametrize ("ddof" , range (3 ))
362
- def test_nanstd (self , ddof ):
355
+ def test_nanstd (self , ddof , skipna ):
363
356
self .check_funs (
364
357
nanops .nanstd ,
365
358
np .std ,
359
+ skipna ,
366
360
allow_complex = False ,
367
361
allow_date = False ,
368
362
allow_obj = "convert" ,
@@ -371,13 +365,14 @@ def test_nanstd(self, ddof):
371
365
372
366
@td .skip_if_no_scipy
373
367
@pytest .mark .parametrize ("ddof" , range (3 ))
374
- def test_nansem (self , ddof ):
368
+ def test_nansem (self , ddof , skipna ):
375
369
from scipy .stats import sem
376
370
377
371
with np .errstate (invalid = "ignore" ):
378
372
self .check_funs (
379
373
nanops .nansem ,
380
374
sem ,
375
+ skipna ,
381
376
allow_complex = False ,
382
377
allow_date = False ,
383
378
allow_tdelta = False ,
@@ -388,10 +383,10 @@ def test_nansem(self, ddof):
388
383
@pytest .mark .parametrize (
389
384
"nan_op,np_op" , [(nanops .nanmin , np .min ), (nanops .nanmax , np .max )]
390
385
)
391
- def test_nanops_with_warnings (self , nan_op , np_op ):
386
+ def test_nanops_with_warnings (self , nan_op , np_op , skipna ):
392
387
with warnings .catch_warnings (record = True ):
393
388
warnings .simplefilter ("ignore" , RuntimeWarning )
394
- self .check_funs (nan_op , np_op , allow_obj = False )
389
+ self .check_funs (nan_op , np_op , skipna , allow_obj = False )
395
390
396
391
def _argminmax_wrap (self , value , axis = None , func = None ):
397
392
res = func (value , axis )
@@ -408,17 +403,17 @@ def _argminmax_wrap(self, value, axis=None, func=None):
408
403
res = - 1
409
404
return res
410
405
411
- def test_nanargmax (self ):
406
+ def test_nanargmax (self , skipna ):
412
407
with warnings .catch_warnings (record = True ):
413
408
warnings .simplefilter ("ignore" , RuntimeWarning )
414
409
func = partial (self ._argminmax_wrap , func = np .argmax )
415
- self .check_funs (nanops .nanargmax , func , allow_obj = False )
410
+ self .check_funs (nanops .nanargmax , func , skipna , allow_obj = False )
416
411
417
- def test_nanargmin (self ):
412
+ def test_nanargmin (self , skipna ):
418
413
with warnings .catch_warnings (record = True ):
419
414
warnings .simplefilter ("ignore" , RuntimeWarning )
420
415
func = partial (self ._argminmax_wrap , func = np .argmin )
421
- self .check_funs (nanops .nanargmin , func , allow_obj = False )
416
+ self .check_funs (nanops .nanargmin , func , skipna , allow_obj = False )
422
417
423
418
def _skew_kurt_wrap (self , values , axis = None , func = None ):
424
419
if not isinstance (values .dtype .type , np .floating ):
@@ -433,21 +428,22 @@ def _skew_kurt_wrap(self, values, axis=None, func=None):
433
428
return result
434
429
435
430
@td .skip_if_no_scipy
436
- def test_nanskew (self ):
431
+ def test_nanskew (self , skipna ):
437
432
from scipy .stats import skew
438
433
439
434
func = partial (self ._skew_kurt_wrap , func = skew )
440
435
with np .errstate (invalid = "ignore" ):
441
436
self .check_funs (
442
437
nanops .nanskew ,
443
438
func ,
439
+ skipna ,
444
440
allow_complex = False ,
445
441
allow_date = False ,
446
442
allow_tdelta = False ,
447
443
)
448
444
449
445
@td .skip_if_no_scipy
450
- def test_nankurt (self ):
446
+ def test_nankurt (self , skipna ):
451
447
from scipy .stats import kurtosis
452
448
453
449
func1 = partial (kurtosis , fisher = True )
@@ -456,15 +452,17 @@ def test_nankurt(self):
456
452
self .check_funs (
457
453
nanops .nankurt ,
458
454
func ,
455
+ skipna ,
459
456
allow_complex = False ,
460
457
allow_date = False ,
461
458
allow_tdelta = False ,
462
459
)
463
460
464
- def test_nanprod (self ):
461
+ def test_nanprod (self , skipna ):
465
462
self .check_funs (
466
463
nanops .nanprod ,
467
464
np .prod ,
465
+ skipna ,
468
466
allow_date = False ,
469
467
allow_tdelta = False ,
470
468
empty_targfunc = np .nanprod ,
0 commit comments