Skip to content

TST: fixturize skipna in test_nanops #32607

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 11, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 75 additions & 77 deletions pandas/tests/test_nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@
has_c16 = hasattr(np, "complex128")


@pytest.fixture(params=[True, False])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could be in pandas/conftest.py (can move later)

def skipna(request):
"""
Fixture to pass skipna to nanops functions.
"""
return request.param


class TestnanopsDataFrame:
def setup_method(self, method):
np.random.seed(11235)
Expand Down Expand Up @@ -89,38 +97,22 @@ def teardown_method(self, method):

def check_results(self, targ, res, axis, check_dtype=True):
res = getattr(res, "asm8", res)
res = getattr(res, "values", res)

# timedeltas are a beast here
def _coerce_tds(targ, res):
if hasattr(targ, "dtype") and targ.dtype == "m8[ns]":
if len(targ) == 1:
targ = targ[0].item()
res = res.item()
else:
targ = targ.view("i8")
return targ, res

try:
if (
axis != 0
and hasattr(targ, "shape")
and targ.ndim
and targ.shape != res.shape
):
res = np.split(res, [targ.shape[0]], axis=0)[0]
except (ValueError, IndexError):
targ, res = _coerce_tds(targ, res)
if (
axis != 0
and hasattr(targ, "shape")
and targ.ndim
and targ.shape != res.shape
):
res = np.split(res, [targ.shape[0]], axis=0)[0]

try:
tm.assert_almost_equal(targ, res, check_dtype=check_dtype)
except AssertionError:

# handle timedelta dtypes
if hasattr(targ, "dtype") and targ.dtype == "m8[ns]":
targ, res = _coerce_tds(targ, res)
tm.assert_almost_equal(targ, res, check_dtype=check_dtype)
return
raise

# There are sometimes rounding errors with
# complex and object dtypes.
Expand Down Expand Up @@ -149,29 +141,29 @@ def check_fun_data(
targfunc,
testarval,
targarval,
skipna,
check_dtype=True,
empty_targfunc=None,
**kwargs,
):
for axis in list(range(targarval.ndim)) + [None]:
for skipna in [False, True]:
targartempval = targarval if skipna else testarval
if skipna and empty_targfunc and isna(targartempval).all():
targ = empty_targfunc(targartempval, axis=axis, **kwargs)
else:
targ = targfunc(targartempval, axis=axis, **kwargs)
targartempval = targarval if skipna else testarval
if skipna and empty_targfunc and isna(targartempval).all():
targ = empty_targfunc(targartempval, axis=axis, **kwargs)
else:
targ = targfunc(targartempval, axis=axis, **kwargs)

res = testfunc(testarval, axis=axis, skipna=skipna, **kwargs)
res = testfunc(testarval, axis=axis, skipna=skipna, **kwargs)
self.check_results(targ, res, axis, check_dtype=check_dtype)
if skipna:
res = testfunc(testarval, axis=axis, **kwargs)
self.check_results(targ, res, axis, check_dtype=check_dtype)
if axis is None:
res = testfunc(testarval, skipna=skipna, **kwargs)
self.check_results(targ, res, axis, check_dtype=check_dtype)
if skipna and axis is None:
res = testfunc(testarval, **kwargs)
self.check_results(targ, res, axis, check_dtype=check_dtype)
if skipna:
res = testfunc(testarval, axis=axis, **kwargs)
self.check_results(targ, res, axis, check_dtype=check_dtype)
if axis is None:
res = testfunc(testarval, skipna=skipna, **kwargs)
self.check_results(targ, res, axis, check_dtype=check_dtype)
if skipna and axis is None:
res = testfunc(testarval, **kwargs)
self.check_results(targ, res, axis, check_dtype=check_dtype)

if testarval.ndim <= 1:
return
Expand All @@ -184,12 +176,15 @@ def check_fun_data(
targfunc,
testarval2,
targarval2,
skipna=skipna,
check_dtype=check_dtype,
empty_targfunc=empty_targfunc,
**kwargs,
)

def check_fun(self, testfunc, targfunc, testar, empty_targfunc=None, **kwargs):
def check_fun(
self, testfunc, targfunc, testar, skipna, empty_targfunc=None, **kwargs
):

targar = testar
if testar.endswith("_nan") and hasattr(self, testar[:-4]):
Expand All @@ -202,6 +197,7 @@ def check_fun(self, testfunc, targfunc, testar, empty_targfunc=None, **kwargs):
targfunc,
testarval,
targarval,
skipna=skipna,
empty_targfunc=empty_targfunc,
**kwargs,
)
Expand All @@ -210,36 +206,37 @@ def check_funs(
self,
testfunc,
targfunc,
skipna,
allow_complex=True,
allow_all_nan=True,
allow_date=True,
allow_tdelta=True,
allow_obj=True,
**kwargs,
):
self.check_fun(testfunc, targfunc, "arr_float", **kwargs)
self.check_fun(testfunc, targfunc, "arr_float_nan", **kwargs)
self.check_fun(testfunc, targfunc, "arr_int", **kwargs)
self.check_fun(testfunc, targfunc, "arr_bool", **kwargs)
self.check_fun(testfunc, targfunc, "arr_float", skipna, **kwargs)
self.check_fun(testfunc, targfunc, "arr_float_nan", skipna, **kwargs)
self.check_fun(testfunc, targfunc, "arr_int", skipna, **kwargs)
self.check_fun(testfunc, targfunc, "arr_bool", skipna, **kwargs)
objs = [
self.arr_float.astype("O"),
self.arr_int.astype("O"),
self.arr_bool.astype("O"),
]

if allow_all_nan:
self.check_fun(testfunc, targfunc, "arr_nan", **kwargs)
self.check_fun(testfunc, targfunc, "arr_nan", skipna, **kwargs)

if allow_complex:
self.check_fun(testfunc, targfunc, "arr_complex", **kwargs)
self.check_fun(testfunc, targfunc, "arr_complex_nan", **kwargs)
self.check_fun(testfunc, targfunc, "arr_complex", skipna, **kwargs)
self.check_fun(testfunc, targfunc, "arr_complex_nan", skipna, **kwargs)
if allow_all_nan:
self.check_fun(testfunc, targfunc, "arr_nan_nanj", **kwargs)
self.check_fun(testfunc, targfunc, "arr_nan_nanj", skipna, **kwargs)
objs += [self.arr_complex.astype("O")]

if allow_date:
targfunc(self.arr_date)
self.check_fun(testfunc, targfunc, "arr_date", **kwargs)
self.check_fun(testfunc, targfunc, "arr_date", skipna, **kwargs)
objs += [self.arr_date.astype("O")]

if allow_tdelta:
Expand All @@ -248,7 +245,7 @@ def check_funs(
except TypeError:
pass
else:
self.check_fun(testfunc, targfunc, "arr_tdelta", **kwargs)
self.check_fun(testfunc, targfunc, "arr_tdelta", skipna, **kwargs)
objs += [self.arr_tdelta.astype("O")]

if allow_obj:
Expand All @@ -260,7 +257,7 @@ def check_funs(
targfunc = partial(
self._badobj_wrap, func=targfunc, allow_complex=allow_complex
)
self.check_fun(testfunc, targfunc, "arr_obj", **kwargs)
self.check_fun(testfunc, targfunc, "arr_obj", skipna, **kwargs)

def _badobj_wrap(self, value, func, allow_complex=True, **kwargs):
if value.dtype.kind == "O":
Expand All @@ -273,28 +270,22 @@ def _badobj_wrap(self, value, func, allow_complex=True, **kwargs):
@pytest.mark.parametrize(
"nan_op,np_op", [(nanops.nanany, np.any), (nanops.nanall, np.all)]
)
def test_nan_funcs(self, nan_op, np_op):
# TODO: allow tdelta, doesn't break tests
self.check_funs(
nan_op, np_op, allow_all_nan=False, allow_date=False, allow_tdelta=False
)
def test_nan_funcs(self, nan_op, np_op, skipna):
self.check_funs(nan_op, np_op, skipna, allow_all_nan=False, allow_date=False)

def test_nansum(self):
def test_nansum(self, skipna):
self.check_funs(
nanops.nansum,
np.sum,
skipna,
allow_date=False,
check_dtype=False,
empty_targfunc=np.nansum,
)

def test_nanmean(self):
def test_nanmean(self, skipna):
self.check_funs(
nanops.nanmean,
np.mean,
allow_complex=False, # TODO: allow this, doesn't break test
allow_obj=False,
allow_date=False,
nanops.nanmean, np.mean, skipna, allow_obj=False, allow_date=False,
)

def test_nanmean_overflow(self):
Expand Down Expand Up @@ -336,33 +327,36 @@ def test_returned_dtype(self, dtype):
else:
assert result.dtype == dtype

def test_nanmedian(self):
def test_nanmedian(self, skipna):
with warnings.catch_warnings(record=True):
warnings.simplefilter("ignore", RuntimeWarning)
self.check_funs(
nanops.nanmedian,
np.median,
skipna,
allow_complex=False,
allow_date=False,
allow_obj="convert",
)

@pytest.mark.parametrize("ddof", range(3))
def test_nanvar(self, ddof):
def test_nanvar(self, ddof, skipna):
self.check_funs(
nanops.nanvar,
np.var,
skipna,
allow_complex=False,
allow_date=False,
allow_obj="convert",
ddof=ddof,
)

@pytest.mark.parametrize("ddof", range(3))
def test_nanstd(self, ddof):
def test_nanstd(self, ddof, skipna):
self.check_funs(
nanops.nanstd,
np.std,
skipna,
allow_complex=False,
allow_date=False,
allow_obj="convert",
Expand All @@ -371,13 +365,14 @@ def test_nanstd(self, ddof):

@td.skip_if_no_scipy
@pytest.mark.parametrize("ddof", range(3))
def test_nansem(self, ddof):
def test_nansem(self, ddof, skipna):
from scipy.stats import sem

with np.errstate(invalid="ignore"):
self.check_funs(
nanops.nansem,
sem,
skipna,
allow_complex=False,
allow_date=False,
allow_tdelta=False,
Expand All @@ -388,10 +383,10 @@ def test_nansem(self, ddof):
@pytest.mark.parametrize(
"nan_op,np_op", [(nanops.nanmin, np.min), (nanops.nanmax, np.max)]
)
def test_nanops_with_warnings(self, nan_op, np_op):
def test_nanops_with_warnings(self, nan_op, np_op, skipna):
with warnings.catch_warnings(record=True):
warnings.simplefilter("ignore", RuntimeWarning)
self.check_funs(nan_op, np_op, allow_obj=False)
self.check_funs(nan_op, np_op, skipna, allow_obj=False)

def _argminmax_wrap(self, value, axis=None, func=None):
res = func(value, axis)
Expand All @@ -408,17 +403,17 @@ def _argminmax_wrap(self, value, axis=None, func=None):
res = -1
return res

def test_nanargmax(self):
def test_nanargmax(self, skipna):
with warnings.catch_warnings(record=True):
warnings.simplefilter("ignore", RuntimeWarning)
func = partial(self._argminmax_wrap, func=np.argmax)
self.check_funs(nanops.nanargmax, func, allow_obj=False)
self.check_funs(nanops.nanargmax, func, skipna, allow_obj=False)

def test_nanargmin(self):
def test_nanargmin(self, skipna):
with warnings.catch_warnings(record=True):
warnings.simplefilter("ignore", RuntimeWarning)
func = partial(self._argminmax_wrap, func=np.argmin)
self.check_funs(nanops.nanargmin, func, allow_obj=False)
self.check_funs(nanops.nanargmin, func, skipna, allow_obj=False)

def _skew_kurt_wrap(self, values, axis=None, func=None):
if not isinstance(values.dtype.type, np.floating):
Expand All @@ -433,21 +428,22 @@ def _skew_kurt_wrap(self, values, axis=None, func=None):
return result

@td.skip_if_no_scipy
def test_nanskew(self):
def test_nanskew(self, skipna):
from scipy.stats import skew

func = partial(self._skew_kurt_wrap, func=skew)
with np.errstate(invalid="ignore"):
self.check_funs(
nanops.nanskew,
func,
skipna,
allow_complex=False,
allow_date=False,
allow_tdelta=False,
)

@td.skip_if_no_scipy
def test_nankurt(self):
def test_nankurt(self, skipna):
from scipy.stats import kurtosis

func1 = partial(kurtosis, fisher=True)
Expand All @@ -456,15 +452,17 @@ def test_nankurt(self):
self.check_funs(
nanops.nankurt,
func,
skipna,
allow_complex=False,
allow_date=False,
allow_tdelta=False,
)

def test_nanprod(self):
def test_nanprod(self, skipna):
self.check_funs(
nanops.nanprod,
np.prod,
skipna,
allow_date=False,
allow_tdelta=False,
empty_targfunc=np.nanprod,
Expand Down