Skip to content

Commit dbc1654

Browse files
committed
TST: better testing of Series.nlargest/nsmallest
xref #15299 Author: Jeff Reback <[email protected]> Closes #15902 from jreback/series_n and squashes the following commits: 657eac8 [Jeff Reback] TST: better testing of Series.nlargest/nsmallest
1 parent 0a37067 commit dbc1654

File tree

2 files changed

+151
-85
lines changed

2 files changed

+151
-85
lines changed

pandas/core/algorithms.py

+45-11
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pandas.types.common import (is_unsigned_integer_dtype,
1313
is_signed_integer_dtype,
1414
is_integer_dtype,
15+
is_complex_dtype,
1516
is_categorical_dtype,
1617
is_extension_type,
1718
is_datetimetz,
@@ -40,6 +41,44 @@
4041
from pandas._libs.tslib import iNaT
4142

4243

44+
# --------------- #
45+
# dtype access #
46+
# --------------- #
47+
48+
def _ensure_data_view(values):
49+
"""
50+
helper routine to ensure that our data is of the correct
51+
input dtype for lower-level routines
52+
53+
Parameters
54+
----------
55+
values : array-like
56+
"""
57+
58+
if needs_i8_conversion(values):
59+
values = values.view(np.int64)
60+
elif is_period_arraylike(values):
61+
from pandas.tseries.period import PeriodIndex
62+
values = PeriodIndex(values).asi8
63+
elif is_categorical_dtype(values):
64+
values = values.values.codes
65+
elif isinstance(values, (ABCSeries, ABCIndex)):
66+
values = values.values
67+
68+
if is_signed_integer_dtype(values):
69+
values = _ensure_int64(values)
70+
elif is_unsigned_integer_dtype(values):
71+
values = _ensure_uint64(values)
72+
elif is_complex_dtype(values):
73+
values = _ensure_float64(values)
74+
elif is_float_dtype(values):
75+
values = _ensure_float64(values)
76+
else:
77+
values = _ensure_object(values)
78+
79+
return values
80+
81+
4382
# --------------- #
4483
# top-level algos #
4584
# --------------- #
@@ -867,9 +906,7 @@ def nsmallest(arr, n, keep='first'):
867906
narr = len(arr)
868907
n = min(n, narr)
869908

870-
sdtype = str(arr.dtype)
871-
arr = arr.view(_dtype_map.get(sdtype, sdtype))
872-
909+
arr = _ensure_data_view(arr)
873910
kth_val = algos.kth_smallest(arr.copy(), n - 1)
874911
return _finalize_nsmallest(arr, kth_val, n, keep, narr)
875912

@@ -880,8 +917,7 @@ def nlargest(arr, n, keep='first'):
880917
881918
Note: Fails silently with NaN.
882919
"""
883-
sdtype = str(arr.dtype)
884-
arr = arr.view(_dtype_map.get(sdtype, sdtype))
920+
arr = _ensure_data_view(arr)
885921
return nsmallest(-arr, n, keep=keep)
886922

887923

@@ -910,9 +946,10 @@ def select_n_series(series, n, keep, method):
910946
nordered : Series
911947
"""
912948
dtype = series.dtype
913-
if not issubclass(dtype.type, (np.integer, np.floating, np.datetime64,
914-
np.timedelta64)):
915-
raise TypeError("Cannot use method %r with dtype %s" % (method, dtype))
949+
if not ((is_numeric_dtype(dtype) and not is_complex_dtype(dtype)) or
950+
needs_i8_conversion(dtype)):
951+
raise TypeError("Cannot use method '{method}' with "
952+
"dtype {dtype}".format(method=method, dtype=dtype))
916953

917954
if keep not in ('first', 'last'):
918955
raise ValueError('keep must be either "first", "last"')
@@ -964,9 +1001,6 @@ def _finalize_nsmallest(arr, kth_val, n, keep, narr):
9641001
return inds
9651002

9661003

967-
_dtype_map = {'datetime64[ns]': 'int64', 'timedelta64[ns]': 'int64'}
968-
969-
9701004
# ------- #
9711005
# helpers #
9721006
# ------- #

pandas/tests/series/test_analytics.py

+106-74
Original file line numberDiff line numberDiff line change
@@ -1381,80 +1381,6 @@ def test_is_monotonic(self):
13811381
self.assertFalse(s.is_monotonic)
13821382
self.assertTrue(s.is_monotonic_decreasing)
13831383

1384-
def test_nsmallest_nlargest(self):
1385-
# float, int, datetime64 (use i8), timedelts64 (same),
1386-
# object that are numbers, object that are strings
1387-
1388-
base = [3, 2, 1, 2, 5]
1389-
1390-
s_list = [
1391-
Series(base, dtype='int8'),
1392-
Series(base, dtype='int16'),
1393-
Series(base, dtype='int32'),
1394-
Series(base, dtype='int64'),
1395-
Series(base, dtype='float32'),
1396-
Series(base, dtype='float64'),
1397-
Series(base, dtype='uint8'),
1398-
Series(base, dtype='uint16'),
1399-
Series(base, dtype='uint32'),
1400-
Series(base, dtype='uint64'),
1401-
Series(base).astype('timedelta64[ns]'),
1402-
Series(pd.to_datetime(['2003', '2002', '2001', '2002', '2005'])),
1403-
]
1404-
1405-
raising = [
1406-
Series([3., 2, 1, 2, '5'], dtype='object'),
1407-
Series([3., 2, 1, 2, 5], dtype='object'),
1408-
# not supported on some archs
1409-
# Series([3., 2, 1, 2, 5], dtype='complex256'),
1410-
Series([3., 2, 1, 2, 5], dtype='complex128'),
1411-
]
1412-
1413-
for r in raising:
1414-
dt = r.dtype
1415-
msg = "Cannot use method 'n(larg|small)est' with dtype %s" % dt
1416-
args = 2, len(r), 0, -1
1417-
methods = r.nlargest, r.nsmallest
1418-
for method, arg in product(methods, args):
1419-
with tm.assertRaisesRegexp(TypeError, msg):
1420-
method(arg)
1421-
1422-
for s in s_list:
1423-
1424-
assert_series_equal(s.nsmallest(2), s.iloc[[2, 1]])
1425-
assert_series_equal(s.nsmallest(2, keep='last'), s.iloc[[2, 3]])
1426-
1427-
empty = s.iloc[0:0]
1428-
assert_series_equal(s.nsmallest(0), empty)
1429-
assert_series_equal(s.nsmallest(-1), empty)
1430-
assert_series_equal(s.nlargest(0), empty)
1431-
assert_series_equal(s.nlargest(-1), empty)
1432-
1433-
assert_series_equal(s.nsmallest(len(s)), s.sort_values())
1434-
assert_series_equal(s.nsmallest(len(s) + 1), s.sort_values())
1435-
assert_series_equal(s.nlargest(len(s)), s.iloc[[4, 0, 1, 3, 2]])
1436-
assert_series_equal(s.nlargest(len(s) + 1),
1437-
s.iloc[[4, 0, 1, 3, 2]])
1438-
1439-
s = Series([3., np.nan, 1, 2, 5])
1440-
assert_series_equal(s.nlargest(), s.iloc[[4, 0, 3, 2]])
1441-
assert_series_equal(s.nsmallest(), s.iloc[[2, 3, 0, 4]])
1442-
1443-
msg = 'keep must be either "first", "last"'
1444-
with tm.assertRaisesRegexp(ValueError, msg):
1445-
s.nsmallest(keep='invalid')
1446-
with tm.assertRaisesRegexp(ValueError, msg):
1447-
s.nlargest(keep='invalid')
1448-
1449-
# GH 13412
1450-
s = Series([1, 4, 3, 2], index=[0, 0, 1, 1])
1451-
result = s.nlargest(3)
1452-
expected = s.sort_values(ascending=False).head(3)
1453-
assert_series_equal(result, expected)
1454-
result = s.nsmallest(3)
1455-
expected = s.sort_values().head(3)
1456-
assert_series_equal(result, expected)
1457-
14581384
def test_sort_index_level(self):
14591385
mi = MultiIndex.from_tuples([[1, 1, 3], [1, 1, 1]], names=list('ABC'))
14601386
s = Series([1, 2], mi)
@@ -1729,3 +1655,109 @@ def test_value_counts_categorical_not_ordered(self):
17291655
index=exp_idx, name='xxx')
17301656
tm.assert_series_equal(s.value_counts(normalize=True), exp)
17311657
tm.assert_series_equal(idx.value_counts(normalize=True), exp)
1658+
1659+
1660+
@pytest.fixture
1661+
def s_main_dtypes():
1662+
df = pd.DataFrame(
1663+
{'datetime': pd.to_datetime(['2003', '2002',
1664+
'2001', '2002',
1665+
'2005']),
1666+
'datetimetz': pd.to_datetime(
1667+
['2003', '2002',
1668+
'2001', '2002',
1669+
'2005']).tz_localize('US/Eastern'),
1670+
'timedelta': pd.to_timedelta(['3d', '2d', '1d',
1671+
'2d', '5d'])})
1672+
1673+
for dtype in ['int8', 'int16', 'int32', 'int64',
1674+
'float32', 'float64',
1675+
'uint8', 'uint16', 'uint32', 'uint64']:
1676+
df[dtype] = Series([3, 2, 1, 2, 5], dtype=dtype)
1677+
1678+
return df
1679+
1680+
1681+
class TestNLargestNSmallest(object):
1682+
1683+
@pytest.mark.parametrize(
1684+
"r", [Series([3., 2, 1, 2, '5'], dtype='object'),
1685+
Series([3., 2, 1, 2, 5], dtype='object'),
1686+
# not supported on some archs
1687+
# Series([3., 2, 1, 2, 5], dtype='complex256'),
1688+
Series([3., 2, 1, 2, 5], dtype='complex128'),
1689+
Series(list('abcde'), dtype='category'),
1690+
Series(list('abcde'))])
1691+
def test_error(self, r):
1692+
dt = r.dtype
1693+
msg = ("Cannot use method 'n(larg|small)est' with "
1694+
"dtype {dt}".format(dt=dt))
1695+
args = 2, len(r), 0, -1
1696+
methods = r.nlargest, r.nsmallest
1697+
for method, arg in product(methods, args):
1698+
with tm.assertRaisesRegexp(TypeError, msg):
1699+
method(arg)
1700+
1701+
@pytest.mark.parametrize(
1702+
"s",
1703+
[v for k, v in s_main_dtypes().iteritems()])
1704+
def test_nsmallest_nlargest(self, s):
1705+
# float, int, datetime64 (use i8), timedelts64 (same),
1706+
# object that are numbers, object that are strings
1707+
1708+
assert_series_equal(s.nsmallest(2), s.iloc[[2, 1]])
1709+
assert_series_equal(s.nsmallest(2, keep='last'), s.iloc[[2, 3]])
1710+
1711+
empty = s.iloc[0:0]
1712+
assert_series_equal(s.nsmallest(0), empty)
1713+
assert_series_equal(s.nsmallest(-1), empty)
1714+
assert_series_equal(s.nlargest(0), empty)
1715+
assert_series_equal(s.nlargest(-1), empty)
1716+
1717+
assert_series_equal(s.nsmallest(len(s)), s.sort_values())
1718+
assert_series_equal(s.nsmallest(len(s) + 1), s.sort_values())
1719+
assert_series_equal(s.nlargest(len(s)), s.iloc[[4, 0, 1, 3, 2]])
1720+
assert_series_equal(s.nlargest(len(s) + 1),
1721+
s.iloc[[4, 0, 1, 3, 2]])
1722+
1723+
def test_misc(self):
1724+
1725+
s = Series([3., np.nan, 1, 2, 5])
1726+
assert_series_equal(s.nlargest(), s.iloc[[4, 0, 3, 2]])
1727+
assert_series_equal(s.nsmallest(), s.iloc[[2, 3, 0, 4]])
1728+
1729+
msg = 'keep must be either "first", "last"'
1730+
with tm.assertRaisesRegexp(ValueError, msg):
1731+
s.nsmallest(keep='invalid')
1732+
with tm.assertRaisesRegexp(ValueError, msg):
1733+
s.nlargest(keep='invalid')
1734+
1735+
# GH 15297
1736+
s = Series([1] * 5, index=[1, 2, 3, 4, 5])
1737+
expected_first = Series([1] * 3, index=[1, 2, 3])
1738+
expected_last = Series([1] * 3, index=[5, 4, 3])
1739+
1740+
result = s.nsmallest(3)
1741+
assert_series_equal(result, expected_first)
1742+
1743+
result = s.nsmallest(3, keep='last')
1744+
assert_series_equal(result, expected_last)
1745+
1746+
result = s.nlargest(3)
1747+
assert_series_equal(result, expected_first)
1748+
1749+
result = s.nlargest(3, keep='last')
1750+
assert_series_equal(result, expected_last)
1751+
1752+
@pytest.mark.parametrize('n', range(1, 5))
1753+
def test_n(self, n):
1754+
1755+
# GH 13412
1756+
s = Series([1, 4, 3, 2], index=[0, 0, 1, 1])
1757+
result = s.nlargest(n)
1758+
expected = s.sort_values(ascending=False).head(n)
1759+
assert_series_equal(result, expected)
1760+
1761+
result = s.nsmallest(n)
1762+
expected = s.sort_values().head(n)
1763+
assert_series_equal(result, expected)

0 commit comments

Comments
 (0)