1
1
""" miscellaneous sorting / groupby utilities """
2
- from typing import Callable , Dict , Union
2
+ from typing import Callable , Optional
3
3
4
4
import numpy as np
5
5
@@ -299,32 +299,20 @@ def nargsort(
299
299
return indexer
300
300
301
301
302
- def apply_key_name (values , key , name ):
303
- if isinstance (key , dict ):
304
- key = key .get (name , None )
305
-
306
- if key is None :
307
- return values
308
-
309
- return key (values )
310
-
311
-
312
- def ensure_key_mapped_dataframe (df , key , levels = None , axis = 0 ):
302
+ def ensure_key_mapped_dataframe (df , key : Callable , levels = None , axis = 0 ):
313
303
"""
314
304
Returns a new DataFrame in which key has been applied
315
305
to all levels specified in levels (or all levels if levels
316
- is None). Used for key sorting for DataFrames.
306
+ is None) along an axis . Used for key sorting for DataFrames.
317
307
318
308
Parameters
319
309
----------
320
310
df : DataFrame
321
311
DataFrame to which to apply the key function on the
322
312
specified levels.
323
- key : Callable or Dict[Any, Callable]
324
- If Callable, function that takes a Series and returns
325
- a Series of the same shape. This key is applied to each
326
- level separately. If dict, name or index of each column
327
- or row is used to index the key object to get a Callable.
313
+ key : Callable
314
+ Function that takes a Series and returns a Series of
315
+ the same shape. This key is applied to each level separately.
328
316
levels : list-like, int or str, default None
329
317
Level or list of levels to apply the key function to.
330
318
If None, key function is applied to all levels. Other
@@ -345,9 +333,9 @@ def ensure_key_mapped_dataframe(df, key, levels=None, axis=0):
345
333
)
346
334
347
335
if axis == 0 :
348
- axis_levels = df .columns ._values
336
+ axis_levels = list ( df .columns ._values ) # makes mypy happy
349
337
else :
350
- axis_levels = df .index ._values
338
+ axis_levels = list ( df .index ._values )
351
339
352
340
if levels is not None :
353
341
if isinstance (levels , (str , int )):
@@ -357,17 +345,16 @@ def ensure_key_mapped_dataframe(df, key, levels=None, axis=0):
357
345
else :
358
346
sort_levels = axis_levels
359
347
360
- new_levels = [
361
- ensure_key_mapped (
362
- Series (df ._get_label_or_level_values (name , axis = axis ), name = name ),
363
- key ,
364
- name = name ,
365
- )
366
- if name in sort_levels
367
- else df ._get_label_or_level_values (name , axis = axis )
348
+ values = [
349
+ (name , Series (df ._get_label_or_level_values (name , axis = axis ), name = name ))
368
350
for name in axis_levels
369
351
]
370
352
353
+ new_levels = [
354
+ ensure_key_mapped (series , key ) if name in sort_levels else series
355
+ for (name , series ) in values
356
+ ]
357
+
371
358
if axis == 0 :
372
359
new_df = DataFrame ._from_arrays (new_levels , df .columns , df .index )
373
360
else :
@@ -419,11 +406,11 @@ def ensure_key_mapped_multiindex(index, key: Callable, levels=None):
419
406
else :
420
407
sort_levels = list (range (index .nlevels )) # satisfies mypy
421
408
409
+ values = [(level , index ._get_level_values (level )) for level in range (index .nlevels )]
410
+
422
411
mapped = [
423
- ensure_key_mapped (index ._get_level_values (level ), key , name = level )
424
- if level in sort_levels
425
- else index ._get_level_values (level )
426
- for level in range (index .nlevels )
412
+ ensure_key_mapped (idx , key ) if level in sort_levels else idx
413
+ for (level , idx ) in values
427
414
]
428
415
429
416
labels = MultiIndex .from_arrays (mapped )
@@ -432,7 +419,7 @@ def ensure_key_mapped_multiindex(index, key: Callable, levels=None):
432
419
433
420
434
421
def ensure_key_mapped (
435
- values , key : Union [ Dict , Callable ], levels = None , name = None , axis = 0 ,
422
+ values , key : Optional [ Callable ], levels = None , axis = 0 ,
436
423
):
437
424
438
425
"""
@@ -443,19 +430,19 @@ def ensure_key_mapped(
443
430
Parameters
444
431
----------
445
432
values : Series, DataFrame, Index subclass, or ndarray
446
- key : Union[Callable, Dict[Union[str, int], Callable]
447
- key to be called on the values array. If dict, indexed by
448
- name or number of columns in values. Dict supported only for
449
- DataFrame or MultiIndex. Expected to return an object of the
450
- same shape and compatible with the original type.
433
+ key : Optional[Callable]
434
+ key to be called on the values array. Expected to
435
+ take a level or column from values and return an
436
+ object of the same shape and compatible with the original type.
437
+ For MultiIndex and DataFrame, applied to rows or columns of the
438
+ values array. For Series and Index, applied directly.
451
439
levels : list-like, int or str, default None
452
440
For MultiIndex values, level or list of levels to apply the key
453
441
function to. If None, key function is applied to all levels. Other
454
442
levels are left unchanged.
455
- name : str or int, default None
456
- Name used to index the key function if a dictionary.
457
443
axis : int, default 0
458
- Axis to use for applying the key to DataFrame values level by level.
444
+ Axis to use for applying the key to DataFrame values. 0 applies the
445
+ key to columns, 1 to rows.
459
446
"""
460
447
from pandas .core .indexes .api import Index
461
448
from pandas import DataFrame
@@ -469,18 +456,18 @@ def ensure_key_mapped(
469
456
if isinstance (values , DataFrame ): # apply the key to select levels
470
457
return ensure_key_mapped_dataframe (values , key , levels = levels , axis = axis )
471
458
472
- result = apply_key_name (values .copy (), key , name )
459
+ result = key (values .copy ())
473
460
474
461
if len (result ) != len (values ):
475
462
raise ValueError (
476
463
"User-provided `key` function must not change the shape of the array."
477
464
)
478
465
479
466
try :
480
- if isinstance (values , Index ):
467
+ if isinstance (values , Index ): # allow a new Index class
481
468
result = Index (result )
482
469
else :
483
- result = type (values )(result )
470
+ result = type (values )(result ) # try to recover otherwise
484
471
except TypeError :
485
472
raise TypeError (
486
473
"User-provided `key` function returned an invalid type {} \
0 commit comments