Skip to content

Commit 102db32

Browse files
jbrockmendelSeeminSyed
authored andcommitted
TST: fixturize skipna in test_nanops (pandas-dev#32607)
1 parent abdae04 commit 102db32

File tree

1 file changed

+75
-77
lines changed

1 file changed

+75
-77
lines changed

pandas/tests/test_nanops.py

+75-77
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@
1919
has_c16 = hasattr(np, "complex128")
2020

2121

22+
@pytest.fixture(params=[True, False])
23+
def skipna(request):
24+
"""
25+
Fixture to pass skipna to nanops functions.
26+
"""
27+
return request.param
28+
29+
2230
class TestnanopsDataFrame:
2331
def setup_method(self, method):
2432
np.random.seed(11235)
@@ -89,38 +97,22 @@ def teardown_method(self, method):
8997

9098
def check_results(self, targ, res, axis, check_dtype=True):
9199
res = getattr(res, "asm8", res)
92-
res = getattr(res, "values", res)
93-
94-
# timedeltas are a beast here
95-
def _coerce_tds(targ, res):
96-
if hasattr(targ, "dtype") and targ.dtype == "m8[ns]":
97-
if len(targ) == 1:
98-
targ = targ[0].item()
99-
res = res.item()
100-
else:
101-
targ = targ.view("i8")
102-
return targ, res
103100

104-
try:
105-
if (
106-
axis != 0
107-
and hasattr(targ, "shape")
108-
and targ.ndim
109-
and targ.shape != res.shape
110-
):
111-
res = np.split(res, [targ.shape[0]], axis=0)[0]
112-
except (ValueError, IndexError):
113-
targ, res = _coerce_tds(targ, res)
101+
if (
102+
axis != 0
103+
and hasattr(targ, "shape")
104+
and targ.ndim
105+
and targ.shape != res.shape
106+
):
107+
res = np.split(res, [targ.shape[0]], axis=0)[0]
114108

115109
try:
116110
tm.assert_almost_equal(targ, res, check_dtype=check_dtype)
117111
except AssertionError:
118112

119113
# handle timedelta dtypes
120114
if hasattr(targ, "dtype") and targ.dtype == "m8[ns]":
121-
targ, res = _coerce_tds(targ, res)
122-
tm.assert_almost_equal(targ, res, check_dtype=check_dtype)
123-
return
115+
raise
124116

125117
# There are sometimes rounding errors with
126118
# complex and object dtypes.
@@ -149,29 +141,29 @@ def check_fun_data(
149141
targfunc,
150142
testarval,
151143
targarval,
144+
skipna,
152145
check_dtype=True,
153146
empty_targfunc=None,
154147
**kwargs,
155148
):
156149
for axis in list(range(targarval.ndim)) + [None]:
157-
for skipna in [False, True]:
158-
targartempval = targarval if skipna else testarval
159-
if skipna and empty_targfunc and isna(targartempval).all():
160-
targ = empty_targfunc(targartempval, axis=axis, **kwargs)
161-
else:
162-
targ = targfunc(targartempval, axis=axis, **kwargs)
150+
targartempval = targarval if skipna else testarval
151+
if skipna and empty_targfunc and isna(targartempval).all():
152+
targ = empty_targfunc(targartempval, axis=axis, **kwargs)
153+
else:
154+
targ = targfunc(targartempval, axis=axis, **kwargs)
163155

164-
res = testfunc(testarval, axis=axis, skipna=skipna, **kwargs)
156+
res = testfunc(testarval, axis=axis, skipna=skipna, **kwargs)
157+
self.check_results(targ, res, axis, check_dtype=check_dtype)
158+
if skipna:
159+
res = testfunc(testarval, axis=axis, **kwargs)
160+
self.check_results(targ, res, axis, check_dtype=check_dtype)
161+
if axis is None:
162+
res = testfunc(testarval, skipna=skipna, **kwargs)
163+
self.check_results(targ, res, axis, check_dtype=check_dtype)
164+
if skipna and axis is None:
165+
res = testfunc(testarval, **kwargs)
165166
self.check_results(targ, res, axis, check_dtype=check_dtype)
166-
if skipna:
167-
res = testfunc(testarval, axis=axis, **kwargs)
168-
self.check_results(targ, res, axis, check_dtype=check_dtype)
169-
if axis is None:
170-
res = testfunc(testarval, skipna=skipna, **kwargs)
171-
self.check_results(targ, res, axis, check_dtype=check_dtype)
172-
if skipna and axis is None:
173-
res = testfunc(testarval, **kwargs)
174-
self.check_results(targ, res, axis, check_dtype=check_dtype)
175167

176168
if testarval.ndim <= 1:
177169
return
@@ -184,12 +176,15 @@ def check_fun_data(
184176
targfunc,
185177
testarval2,
186178
targarval2,
179+
skipna=skipna,
187180
check_dtype=check_dtype,
188181
empty_targfunc=empty_targfunc,
189182
**kwargs,
190183
)
191184

192-
def check_fun(self, testfunc, targfunc, testar, empty_targfunc=None, **kwargs):
185+
def check_fun(
186+
self, testfunc, targfunc, testar, skipna, empty_targfunc=None, **kwargs
187+
):
193188

194189
targar = testar
195190
if testar.endswith("_nan") and hasattr(self, testar[:-4]):
@@ -202,6 +197,7 @@ def check_fun(self, testfunc, targfunc, testar, empty_targfunc=None, **kwargs):
202197
targfunc,
203198
testarval,
204199
targarval,
200+
skipna=skipna,
205201
empty_targfunc=empty_targfunc,
206202
**kwargs,
207203
)
@@ -210,36 +206,37 @@ def check_funs(
210206
self,
211207
testfunc,
212208
targfunc,
209+
skipna,
213210
allow_complex=True,
214211
allow_all_nan=True,
215212
allow_date=True,
216213
allow_tdelta=True,
217214
allow_obj=True,
218215
**kwargs,
219216
):
220-
self.check_fun(testfunc, targfunc, "arr_float", **kwargs)
221-
self.check_fun(testfunc, targfunc, "arr_float_nan", **kwargs)
222-
self.check_fun(testfunc, targfunc, "arr_int", **kwargs)
223-
self.check_fun(testfunc, targfunc, "arr_bool", **kwargs)
217+
self.check_fun(testfunc, targfunc, "arr_float", skipna, **kwargs)
218+
self.check_fun(testfunc, targfunc, "arr_float_nan", skipna, **kwargs)
219+
self.check_fun(testfunc, targfunc, "arr_int", skipna, **kwargs)
220+
self.check_fun(testfunc, targfunc, "arr_bool", skipna, **kwargs)
224221
objs = [
225222
self.arr_float.astype("O"),
226223
self.arr_int.astype("O"),
227224
self.arr_bool.astype("O"),
228225
]
229226

230227
if allow_all_nan:
231-
self.check_fun(testfunc, targfunc, "arr_nan", **kwargs)
228+
self.check_fun(testfunc, targfunc, "arr_nan", skipna, **kwargs)
232229

233230
if allow_complex:
234-
self.check_fun(testfunc, targfunc, "arr_complex", **kwargs)
235-
self.check_fun(testfunc, targfunc, "arr_complex_nan", **kwargs)
231+
self.check_fun(testfunc, targfunc, "arr_complex", skipna, **kwargs)
232+
self.check_fun(testfunc, targfunc, "arr_complex_nan", skipna, **kwargs)
236233
if allow_all_nan:
237-
self.check_fun(testfunc, targfunc, "arr_nan_nanj", **kwargs)
234+
self.check_fun(testfunc, targfunc, "arr_nan_nanj", skipna, **kwargs)
238235
objs += [self.arr_complex.astype("O")]
239236

240237
if allow_date:
241238
targfunc(self.arr_date)
242-
self.check_fun(testfunc, targfunc, "arr_date", **kwargs)
239+
self.check_fun(testfunc, targfunc, "arr_date", skipna, **kwargs)
243240
objs += [self.arr_date.astype("O")]
244241

245242
if allow_tdelta:
@@ -248,7 +245,7 @@ def check_funs(
248245
except TypeError:
249246
pass
250247
else:
251-
self.check_fun(testfunc, targfunc, "arr_tdelta", **kwargs)
248+
self.check_fun(testfunc, targfunc, "arr_tdelta", skipna, **kwargs)
252249
objs += [self.arr_tdelta.astype("O")]
253250

254251
if allow_obj:
@@ -260,7 +257,7 @@ def check_funs(
260257
targfunc = partial(
261258
self._badobj_wrap, func=targfunc, allow_complex=allow_complex
262259
)
263-
self.check_fun(testfunc, targfunc, "arr_obj", **kwargs)
260+
self.check_fun(testfunc, targfunc, "arr_obj", skipna, **kwargs)
264261

265262
def _badobj_wrap(self, value, func, allow_complex=True, **kwargs):
266263
if value.dtype.kind == "O":
@@ -273,28 +270,22 @@ def _badobj_wrap(self, value, func, allow_complex=True, **kwargs):
273270
@pytest.mark.parametrize(
274271
"nan_op,np_op", [(nanops.nanany, np.any), (nanops.nanall, np.all)]
275272
)
276-
def test_nan_funcs(self, nan_op, np_op):
277-
# TODO: allow tdelta, doesn't break tests
278-
self.check_funs(
279-
nan_op, np_op, allow_all_nan=False, allow_date=False, allow_tdelta=False
280-
)
273+
def test_nan_funcs(self, nan_op, np_op, skipna):
274+
self.check_funs(nan_op, np_op, skipna, allow_all_nan=False, allow_date=False)
281275

282-
def test_nansum(self):
276+
def test_nansum(self, skipna):
283277
self.check_funs(
284278
nanops.nansum,
285279
np.sum,
280+
skipna,
286281
allow_date=False,
287282
check_dtype=False,
288283
empty_targfunc=np.nansum,
289284
)
290285

291-
def test_nanmean(self):
286+
def test_nanmean(self, skipna):
292287
self.check_funs(
293-
nanops.nanmean,
294-
np.mean,
295-
allow_complex=False, # TODO: allow this, doesn't break test
296-
allow_obj=False,
297-
allow_date=False,
288+
nanops.nanmean, np.mean, skipna, allow_obj=False, allow_date=False,
298289
)
299290

300291
def test_nanmean_overflow(self):
@@ -336,33 +327,36 @@ def test_returned_dtype(self, dtype):
336327
else:
337328
assert result.dtype == dtype
338329

339-
def test_nanmedian(self):
330+
def test_nanmedian(self, skipna):
340331
with warnings.catch_warnings(record=True):
341332
warnings.simplefilter("ignore", RuntimeWarning)
342333
self.check_funs(
343334
nanops.nanmedian,
344335
np.median,
336+
skipna,
345337
allow_complex=False,
346338
allow_date=False,
347339
allow_obj="convert",
348340
)
349341

350342
@pytest.mark.parametrize("ddof", range(3))
351-
def test_nanvar(self, ddof):
343+
def test_nanvar(self, ddof, skipna):
352344
self.check_funs(
353345
nanops.nanvar,
354346
np.var,
347+
skipna,
355348
allow_complex=False,
356349
allow_date=False,
357350
allow_obj="convert",
358351
ddof=ddof,
359352
)
360353

361354
@pytest.mark.parametrize("ddof", range(3))
362-
def test_nanstd(self, ddof):
355+
def test_nanstd(self, ddof, skipna):
363356
self.check_funs(
364357
nanops.nanstd,
365358
np.std,
359+
skipna,
366360
allow_complex=False,
367361
allow_date=False,
368362
allow_obj="convert",
@@ -371,13 +365,14 @@ def test_nanstd(self, ddof):
371365

372366
@td.skip_if_no_scipy
373367
@pytest.mark.parametrize("ddof", range(3))
374-
def test_nansem(self, ddof):
368+
def test_nansem(self, ddof, skipna):
375369
from scipy.stats import sem
376370

377371
with np.errstate(invalid="ignore"):
378372
self.check_funs(
379373
nanops.nansem,
380374
sem,
375+
skipna,
381376
allow_complex=False,
382377
allow_date=False,
383378
allow_tdelta=False,
@@ -388,10 +383,10 @@ def test_nansem(self, ddof):
388383
@pytest.mark.parametrize(
389384
"nan_op,np_op", [(nanops.nanmin, np.min), (nanops.nanmax, np.max)]
390385
)
391-
def test_nanops_with_warnings(self, nan_op, np_op):
386+
def test_nanops_with_warnings(self, nan_op, np_op, skipna):
392387
with warnings.catch_warnings(record=True):
393388
warnings.simplefilter("ignore", RuntimeWarning)
394-
self.check_funs(nan_op, np_op, allow_obj=False)
389+
self.check_funs(nan_op, np_op, skipna, allow_obj=False)
395390

396391
def _argminmax_wrap(self, value, axis=None, func=None):
397392
res = func(value, axis)
@@ -408,17 +403,17 @@ def _argminmax_wrap(self, value, axis=None, func=None):
408403
res = -1
409404
return res
410405

411-
def test_nanargmax(self):
406+
def test_nanargmax(self, skipna):
412407
with warnings.catch_warnings(record=True):
413408
warnings.simplefilter("ignore", RuntimeWarning)
414409
func = partial(self._argminmax_wrap, func=np.argmax)
415-
self.check_funs(nanops.nanargmax, func, allow_obj=False)
410+
self.check_funs(nanops.nanargmax, func, skipna, allow_obj=False)
416411

417-
def test_nanargmin(self):
412+
def test_nanargmin(self, skipna):
418413
with warnings.catch_warnings(record=True):
419414
warnings.simplefilter("ignore", RuntimeWarning)
420415
func = partial(self._argminmax_wrap, func=np.argmin)
421-
self.check_funs(nanops.nanargmin, func, allow_obj=False)
416+
self.check_funs(nanops.nanargmin, func, skipna, allow_obj=False)
422417

423418
def _skew_kurt_wrap(self, values, axis=None, func=None):
424419
if not isinstance(values.dtype.type, np.floating):
@@ -433,21 +428,22 @@ def _skew_kurt_wrap(self, values, axis=None, func=None):
433428
return result
434429

435430
@td.skip_if_no_scipy
436-
def test_nanskew(self):
431+
def test_nanskew(self, skipna):
437432
from scipy.stats import skew
438433

439434
func = partial(self._skew_kurt_wrap, func=skew)
440435
with np.errstate(invalid="ignore"):
441436
self.check_funs(
442437
nanops.nanskew,
443438
func,
439+
skipna,
444440
allow_complex=False,
445441
allow_date=False,
446442
allow_tdelta=False,
447443
)
448444

449445
@td.skip_if_no_scipy
450-
def test_nankurt(self):
446+
def test_nankurt(self, skipna):
451447
from scipy.stats import kurtosis
452448

453449
func1 = partial(kurtosis, fisher=True)
@@ -456,15 +452,17 @@ def test_nankurt(self):
456452
self.check_funs(
457453
nanops.nankurt,
458454
func,
455+
skipna,
459456
allow_complex=False,
460457
allow_date=False,
461458
allow_tdelta=False,
462459
)
463460

464-
def test_nanprod(self):
461+
def test_nanprod(self, skipna):
465462
self.check_funs(
466463
nanops.nanprod,
467464
np.prod,
465+
skipna,
468466
allow_date=False,
469467
allow_tdelta=False,
470468
empty_targfunc=np.nanprod,

0 commit comments

Comments
 (0)