-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
Fix nsmallest/nlargest With Identical Values #15299
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -931,6 +931,14 @@ def select_n_slow(dropped, n, keep, method): | |
_select_methods = {'nsmallest': nsmallest, 'nlargest': nlargest} | ||
|
||
|
||
def _is_valid_dtype_n_method(dtype): | ||
"""Helper function to determine if `dtpye` is valid for | ||
nsmallest/nlargest methods | ||
""" | ||
return ((is_numeric_dtype(dtype) and not is_complex_dtype(dtype)) or | ||
needs_i8_conversion(dtype)) | ||
|
||
|
||
def select_n_series(series, n, keep, method): | ||
"""Implement n largest/smallest for pandas Series | ||
|
||
|
@@ -946,8 +954,7 @@ def select_n_series(series, n, keep, method): | |
nordered : Series | ||
""" | ||
dtype = series.dtype | ||
if not ((is_numeric_dtype(dtype) and not is_complex_dtype(dtype)) or | ||
needs_i8_conversion(dtype)): | ||
if not _is_valid_dtype_n_method(dtype): | ||
raise TypeError("Cannot use method '{method}' with " | ||
"dtype {dtype}".format(method=method, dtype=dtype)) | ||
|
||
|
@@ -981,14 +988,67 @@ def select_n_frame(frame, columns, n, method, keep): | |
------- | ||
nordered : DataFrame | ||
""" | ||
from pandas.core.series import Series | ||
from pandas import Int64Index | ||
if not is_list_like(columns): | ||
columns = [columns] | ||
columns = list(columns) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we need to do this check on all the columns upfront (its cheap). you can put this in a separate function on core/algorithms (called by there and here): https://github.com/pandas-dev/pandas/blob/master/pandas/core/algorithms.py#L949 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do u want to reference the column in the exception? IOW def _ensure_is_valid_dtype_n_method(method, dtype):
if not ((is_numeric_dtype(dtype) and not is_complex_dtype(dtype)) or
needs_i8_conversion(dtype)):
raise TypeError("Cannot use method '{method}' with "
"dtype {dtype}".format(method=method, dtype=dtype)) or something like def _is_valid_dtype_n_method(dtype):
return ((is_numeric_dtype(dtype) and not is_complex_dtype(dtype)) or
needs_i8_conversion(dtype)): and then in select_n_frame function something like if not _is_valid_dtype_n_method(dtype):
raise TypeError("Column `column` has dtype `dtype` can't use n with column") and then in select_n_series function something like if not _is_valid_dtype_n_method(dtype):
raise TypeError("Cannot use method `method` with dtype") There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
then you can have separate messages for a series calling this (alone) and a dataframe with many columns. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. though keep the message for Series pretty much the same (actually it IS tested). |
||
ser = getattr(frame[columns[0]], method)(n, keep=keep) | ||
if isinstance(ser, Series): | ||
ser = ser.to_frame() | ||
return ser.merge(frame, on=columns[0], left_index=True)[frame.columns] | ||
for column in columns: | ||
dtype = frame[column].dtype | ||
if not _is_valid_dtype_n_method(dtype): | ||
raise TypeError(( | ||
"Column {column!r} has dtype {dtype}, cannot use method " | ||
"{method!r} with this dtype" | ||
).format(column=column, dtype=dtype, method=method)) | ||
|
||
def get_indexer(current_indexer, other_indexer): | ||
"""Helper function to concat `current_indexer` and `other_indexer` | ||
depending on `method` | ||
""" | ||
if method == 'nsmallest': | ||
return current_indexer.append(other_indexer) | ||
else: | ||
return other_indexer.append(current_indexer) | ||
|
||
# Below we save and reset the index in case index contains duplicates | ||
original_index = frame.index | ||
cur_frame = frame = frame.reset_index(drop=True) | ||
cur_n = n | ||
indexer = Int64Index([]) | ||
|
||
for i, column in enumerate(columns): | ||
|
||
# For each column we apply method to cur_frame[column]. If it is the | ||
# last column in columns, or if the values returned are unique in | ||
# frame[column] we save this index and break | ||
# Otherwise we must save the index of the non duplicated values | ||
# and set the next cur_frame to cur_frame filtered on all duplcicated | ||
# values (#GH15297) | ||
series = cur_frame[column] | ||
values = getattr(series, method)(cur_n, keep=keep) | ||
is_last_column = len(columns) - 1 == i | ||
if is_last_column or values.nunique() == series.isin(values).sum(): | ||
|
||
# Last column in columns or values are unique in series => values | ||
# is all that matters | ||
indexer = get_indexer(indexer, values.index) | ||
break | ||
|
||
duplicated_filter = series.duplicated(keep=False) | ||
duplicated = values[duplicated_filter] | ||
non_duplicated = values[~duplicated_filter] | ||
indexer = get_indexer(indexer, non_duplicated.index) | ||
|
||
# Must set cur frame to include all duplicated values to consider for | ||
# the next column, we also can reduce cur_n by the current length of | ||
# the indexer | ||
cur_frame = cur_frame[series.isin(duplicated)] | ||
cur_n = n - len(indexer) | ||
|
||
frame = frame.take(indexer) | ||
|
||
# Restore the index on frame | ||
frame.index = original_index.take(indexer) | ||
return frame | ||
|
||
|
||
def _finalize_nsmallest(arr, kth_val, n, keep, narr): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,11 +7,12 @@ | |
import sys | ||
import pytest | ||
|
||
from string import ascii_lowercase | ||
from numpy import nan | ||
from numpy.random import randn | ||
import numpy as np | ||
|
||
from pandas.compat import lrange | ||
from pandas.compat import lrange, product | ||
from pandas import (compat, isnull, notnull, DataFrame, Series, | ||
MultiIndex, date_range, Timestamp) | ||
import pandas as pd | ||
|
@@ -1120,73 +1121,6 @@ def __nonzero__(self): | |
self.assertTrue(r1.all()) | ||
|
||
# ---------------------------------------------------------------------- | ||
# Top / bottom | ||
|
||
def test_nlargest(self): | ||
# GH10393 | ||
from string import ascii_lowercase | ||
df = pd.DataFrame({'a': np.random.permutation(10), | ||
'b': list(ascii_lowercase[:10])}) | ||
result = df.nlargest(5, 'a') | ||
expected = df.sort_values('a', ascending=False).head(5) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
def test_nlargest_multiple_columns(self): | ||
from string import ascii_lowercase | ||
df = pd.DataFrame({'a': np.random.permutation(10), | ||
'b': list(ascii_lowercase[:10]), | ||
'c': np.random.permutation(10).astype('float64')}) | ||
result = df.nlargest(5, ['a', 'b']) | ||
expected = df.sort_values(['a', 'b'], ascending=False).head(5) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
def test_nsmallest(self): | ||
from string import ascii_lowercase | ||
df = pd.DataFrame({'a': np.random.permutation(10), | ||
'b': list(ascii_lowercase[:10])}) | ||
result = df.nsmallest(5, 'a') | ||
expected = df.sort_values('a').head(5) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
def test_nsmallest_multiple_columns(self): | ||
from string import ascii_lowercase | ||
df = pd.DataFrame({'a': np.random.permutation(10), | ||
'b': list(ascii_lowercase[:10]), | ||
'c': np.random.permutation(10).astype('float64')}) | ||
result = df.nsmallest(5, ['a', 'c']) | ||
expected = df.sort_values(['a', 'c']).head(5) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
def test_nsmallest_nlargest_duplicate_index(self): | ||
# GH 13412 | ||
df = pd.DataFrame({'a': [1, 2, 3, 4], | ||
'b': [4, 3, 2, 1], | ||
'c': [0, 1, 2, 3]}, | ||
index=[0, 0, 1, 1]) | ||
result = df.nsmallest(4, 'a') | ||
expected = df.sort_values('a').head(4) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
result = df.nlargest(4, 'a') | ||
expected = df.sort_values('a', ascending=False).head(4) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
result = df.nsmallest(4, ['a', 'c']) | ||
expected = df.sort_values(['a', 'c']).head(4) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
result = df.nsmallest(4, ['c', 'a']) | ||
expected = df.sort_values(['c', 'a']).head(4) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
result = df.nlargest(4, ['a', 'c']) | ||
expected = df.sort_values(['a', 'c'], ascending=False).head(4) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
result = df.nlargest(4, ['c', 'a']) | ||
expected = df.sort_values(['c', 'a'], ascending=False).head(4) | ||
tm.assert_frame_equal(result, expected) | ||
# ---------------------------------------------------------------------- | ||
# Isin | ||
|
||
def test_isin(self): | ||
|
@@ -1965,3 +1899,132 @@ def test_dot(self): | |
|
||
with tm.assertRaisesRegexp(ValueError, 'aligned'): | ||
df.dot(df2) | ||
|
||
|
||
@pytest.fixture | ||
def df_duplicates(): | ||
return pd.DataFrame({'a': [1, 2, 3, 4, 4], | ||
'b': [1, 1, 1, 1, 1], | ||
'c': [0, 1, 2, 5, 4]}, | ||
index=[0, 0, 1, 1, 1]) | ||
|
||
|
||
@pytest.fixture | ||
def df_strings(): | ||
return pd.DataFrame({'a': np.random.permutation(10), | ||
'b': list(ascii_lowercase[:10]), | ||
'c': np.random.permutation(10).astype('float64')}) | ||
|
||
|
||
@pytest.fixture | ||
def df_main_dtypes(): | ||
return pd.DataFrame( | ||
{'group': [1, 1, 2], | ||
'int': [1, 2, 3], | ||
'float': [4., 5., 6.], | ||
'string': list('abc'), | ||
'category_string': pd.Series(list('abc')).astype('category'), | ||
'category_int': [7, 8, 9], | ||
'datetime': pd.date_range('20130101', periods=3), | ||
'datetimetz': pd.date_range('20130101', | ||
periods=3, | ||
tz='US/Eastern'), | ||
'timedelta': pd.timedelta_range('1 s', periods=3, freq='s')}, | ||
columns=['group', 'int', 'float', 'string', | ||
'category_string', 'category_int', | ||
'datetime', 'datetimetz', | ||
'timedelta']) | ||
|
||
|
||
class TestNLargestNSmallest(object): | ||
|
||
dtype_error_msg_template = ("Column {column!r} has dtype {dtype}, cannot " | ||
"use method {method!r} with this dtype") | ||
|
||
# ---------------------------------------------------------------------- | ||
# Top / bottom | ||
@pytest.mark.parametrize( | ||
'method, n, order', | ||
product(['nsmallest', 'nlargest'], range(1, 11), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think adding method doesn't add much here (and instead add both as comparision tests below) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. but after seeing your other methods, maybe ok. if its easier to debug then ok. |
||
[['a'], | ||
['c'], | ||
['a', 'b'], | ||
['a', 'c'], | ||
['b', 'a'], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @RogerThomas you will see this test failing for the below. The issues is that if one of the columns you are n* on is object this will fail. So I think that should raise a nice error message (you can check if any are object upfront).
|
||
['b', 'c'], | ||
['a', 'b', 'c'], | ||
['c', 'a', 'b'], | ||
['c', 'b', 'a'], | ||
['b', 'c', 'a'], | ||
['b', 'a', 'c'], | ||
|
||
# dups! | ||
['b', 'c', 'c'], | ||
|
||
])) | ||
def test_n(self, df_strings, method, n, order): | ||
# GH10393 | ||
df = df_strings | ||
if 'b' in order: | ||
|
||
error_msg = self.dtype_error_msg_template.format( | ||
column='b', method=method, dtype='object') | ||
with tm.assertRaisesRegexp(TypeError, error_msg): | ||
getattr(df, method)(n, order) | ||
else: | ||
ascending = method == 'nsmallest' | ||
result = getattr(df, method)(n, order) | ||
expected = df.sort_values(order, ascending=ascending).head(n) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
@pytest.mark.parametrize( | ||
'method, columns', | ||
product(['nsmallest', 'nlargest'], | ||
product(['group'], ['category_string', 'string']) | ||
)) | ||
def test_n_error(self, df_main_dtypes, method, columns): | ||
df = df_main_dtypes | ||
error_msg = self.dtype_error_msg_template.format( | ||
column=columns[1], method=method, dtype=df[columns[1]].dtype) | ||
with tm.assertRaisesRegexp(TypeError, error_msg): | ||
getattr(df, method)(2, columns) | ||
|
||
def test_n_all_dtypes(self, df_main_dtypes): | ||
df = df_main_dtypes | ||
df.nsmallest(2, list(set(df) - {'category_string', 'string'})) | ||
df.nlargest(2, list(set(df) - {'category_string', 'string'})) | ||
|
||
def test_n_identical_values(self): | ||
# GH15297 | ||
df = pd.DataFrame({'a': [1] * 5, 'b': [1, 2, 3, 4, 5]}) | ||
|
||
result = df.nlargest(3, 'a') | ||
expected = pd.DataFrame( | ||
{'a': [1] * 3, 'b': [1, 2, 3]}, index=[0, 1, 2] | ||
) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
result = df.nsmallest(3, 'a') | ||
expected = pd.DataFrame({'a': [1] * 3, 'b': [1, 2, 3]}) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
@pytest.mark.parametrize( | ||
'n, order', | ||
product([1, 2, 3, 4, 5], | ||
[['a', 'b', 'c'], | ||
['c', 'b', 'a'], | ||
['a'], | ||
['b'], | ||
['a', 'b'], | ||
['c', 'b']])) | ||
def test_n_duplicate_index(self, df_duplicates, n, order): | ||
# GH 13412 | ||
|
||
df = df_duplicates | ||
result = df.nsmallest(n, order) | ||
expected = df.sort_values(order).head(n) | ||
tm.assert_frame_equal(result, expected) | ||
|
||
result = df.nlargest(n, order) | ||
expected = df.sort_values(order, ascending=False).head(n) | ||
tm.assert_frame_equal(result, expected) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jreback I only put this import here as the previous algorithm imported Series here, why do you do inline imports?
I would always have all my imports at the top
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
because some of the routines from algorithms are imported at the very top level (e.g.
pd.factorize
and you would get circular imports otherwise.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AH ok, thanks!