12
12
from numpy .random import randn
13
13
import numpy as np
14
14
15
- from pandas .compat import lrange , product , PY35
15
+ from pandas .compat import lrange , PY35
16
16
from pandas import (compat , isna , notna , DataFrame , Series ,
17
17
MultiIndex , date_range , Timestamp , Categorical ,
18
18
_np_version_under1p12 , _np_version_under1p15 ,
@@ -2260,54 +2260,49 @@ class TestNLargestNSmallest(object):
2260
2260
2261
2261
# ----------------------------------------------------------------------
2262
2262
# Top / bottom
2263
- @pytest .mark .parametrize (
2264
- 'method, n, order' ,
2265
- product (['nsmallest' , 'nlargest' ], range (1 , 11 ),
2266
- [['a' ],
2267
- ['c' ],
2268
- ['a' , 'b' ],
2269
- ['a' , 'c' ],
2270
- ['b' , 'a' ],
2271
- ['b' , 'c' ],
2272
- ['a' , 'b' , 'c' ],
2273
- ['c' , 'a' , 'b' ],
2274
- ['c' , 'b' , 'a' ],
2275
- ['b' , 'c' , 'a' ],
2276
- ['b' , 'a' , 'c' ],
2277
-
2278
- # dups!
2279
- ['b' , 'c' , 'c' ],
2280
-
2281
- ]))
2282
- def test_n (self , df_strings , method , n , order ):
2263
+ @pytest .mark .parametrize ('order' , [
2264
+ ['a' ],
2265
+ ['c' ],
2266
+ ['a' , 'b' ],
2267
+ ['a' , 'c' ],
2268
+ ['b' , 'a' ],
2269
+ ['b' , 'c' ],
2270
+ ['a' , 'b' , 'c' ],
2271
+ ['c' , 'a' , 'b' ],
2272
+ ['c' , 'b' , 'a' ],
2273
+ ['b' , 'c' , 'a' ],
2274
+ ['b' , 'a' , 'c' ],
2275
+
2276
+ # dups!
2277
+ ['b' , 'c' , 'c' ]])
2278
+ @pytest .mark .parametrize ('n' , range (1 , 11 ))
2279
+ def test_n (self , df_strings , nselect_method , n , order ):
2283
2280
# GH10393
2284
2281
df = df_strings
2285
2282
if 'b' in order :
2286
2283
2287
2284
error_msg = self .dtype_error_msg_template .format (
2288
- column = 'b' , method = method , dtype = 'object' )
2285
+ column = 'b' , method = nselect_method , dtype = 'object' )
2289
2286
with tm .assert_raises_regex (TypeError , error_msg ):
2290
- getattr (df , method )(n , order )
2287
+ getattr (df , nselect_method )(n , order )
2291
2288
else :
2292
- ascending = method == 'nsmallest'
2293
- result = getattr (df , method )(n , order )
2289
+ ascending = nselect_method == 'nsmallest'
2290
+ result = getattr (df , nselect_method )(n , order )
2294
2291
expected = df .sort_values (order , ascending = ascending ).head (n )
2295
2292
tm .assert_frame_equal (result , expected )
2296
2293
2297
- @pytest .mark .parametrize (
2298
- 'method, columns' ,
2299
- product (['nsmallest' , 'nlargest' ],
2300
- product (['group' ], ['category_string' , 'string' ])
2301
- ))
2302
- def test_n_error (self , df_main_dtypes , method , columns ):
2294
+ @pytest .mark .parametrize ('columns' , [
2295
+ ('group' , 'category_string' ), ('group' , 'string' )])
2296
+ def test_n_error (self , df_main_dtypes , nselect_method , columns ):
2303
2297
df = df_main_dtypes
2298
+ col = columns [1 ]
2304
2299
error_msg = self .dtype_error_msg_template .format (
2305
- column = columns [ 1 ] , method = method , dtype = df [columns [ 1 ] ].dtype )
2300
+ column = col , method = nselect_method , dtype = df [col ].dtype )
2306
2301
# escape some characters that may be in the repr
2307
2302
error_msg = (error_msg .replace ('(' , '\\ (' ).replace (")" , "\\ )" )
2308
2303
.replace ("[" , "\\ [" ).replace ("]" , "\\ ]" ))
2309
2304
with tm .assert_raises_regex (TypeError , error_msg ):
2310
- getattr (df , method )(2 , columns )
2305
+ getattr (df , nselect_method )(2 , columns )
2311
2306
2312
2307
def test_n_all_dtypes (self , df_main_dtypes ):
2313
2308
df = df_main_dtypes
@@ -2328,15 +2323,14 @@ def test_n_identical_values(self):
2328
2323
expected = pd .DataFrame ({'a' : [1 ] * 3 , 'b' : [1 , 2 , 3 ]})
2329
2324
tm .assert_frame_equal (result , expected )
2330
2325
2331
- @pytest .mark .parametrize (
2332
- 'n, order' ,
2333
- product ([1 , 2 , 3 , 4 , 5 ],
2334
- [['a' , 'b' , 'c' ],
2335
- ['c' , 'b' , 'a' ],
2336
- ['a' ],
2337
- ['b' ],
2338
- ['a' , 'b' ],
2339
- ['c' , 'b' ]]))
2326
+ @pytest .mark .parametrize ('order' , [
2327
+ ['a' , 'b' , 'c' ],
2328
+ ['c' , 'b' , 'a' ],
2329
+ ['a' ],
2330
+ ['b' ],
2331
+ ['a' , 'b' ],
2332
+ ['c' , 'b' ]])
2333
+ @pytest .mark .parametrize ('n' , range (1 , 6 ))
2340
2334
def test_n_duplicate_index (self , df_duplicates , n , order ):
2341
2335
# GH 13412
2342
2336
0 commit comments