Skip to content

Commit 35cc2b3

Browse files
temporarily reverted dict keys
1 parent ad6b188 commit 35cc2b3

File tree

4 files changed

+36
-73
lines changed

4 files changed

+36
-73
lines changed

pandas/core/series.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3052,7 +3052,8 @@ def _try_kind_sort(arr):
30523052
ser = ensure_key_mapped(self, key=key)
30533053
sorted_index = np.empty(len(self), dtype=np.int32)
30543054

3055-
bad = isna(ser._values)
3055+
sort_values = ser._values
3056+
bad = isna(sort_values)
30563057

30573058
good = ~bad
30583059
idx = ibase.default_index(len(self))

pandas/core/sorting.py

+31-44
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
""" miscellaneous sorting / groupby utilities """
2-
from typing import Callable, Dict, Union
2+
from typing import Callable, Optional
33

44
import numpy as np
55

@@ -299,32 +299,20 @@ def nargsort(
299299
return indexer
300300

301301

302-
def apply_key_name(values, key, name):
303-
if isinstance(key, dict):
304-
key = key.get(name, None)
305-
306-
if key is None:
307-
return values
308-
309-
return key(values)
310-
311-
312-
def ensure_key_mapped_dataframe(df, key, levels=None, axis=0):
302+
def ensure_key_mapped_dataframe(df, key: Callable, levels=None, axis=0):
313303
"""
314304
Returns a new DataFrame in which key has been applied
315305
to all levels specified in level (or all levels if level
316-
is None). Used for key sorting for DataFrames.
306+
is None) along an axis. Used for key sorting for DataFrames.
317307
318308
Parameters
319309
----------
320310
df : DataFrame
321311
DataFrame to which to apply the key function on the
322312
specified levels.
323-
key : Callable or Dict[Any, Callable]
324-
If Callable, function that takes a Series and returns
325-
a Series of the same shape. This key is applied to each
326-
level separately. If dict, name or index of each column
327-
or row is used to index the key object to get a Callable.
313+
key : Callable
314+
Function that takes a Series and returns a Series of
315+
the same shape. This key is applied to each level separately.
328316
levels : list-like, int or str, default None
329317
Level or list of levels to apply the key function to.
330318
If None, key function is applied to all levels. Other
@@ -345,9 +333,9 @@ def ensure_key_mapped_dataframe(df, key, levels=None, axis=0):
345333
)
346334

347335
if axis == 0:
348-
axis_levels = df.columns._values
336+
axis_levels = list(df.columns._values) # makes mypy happy
349337
else:
350-
axis_levels = df.index._values
338+
axis_levels = list(df.index._values)
351339

352340
if levels is not None:
353341
if isinstance(levels, (str, int)):
@@ -357,17 +345,16 @@ def ensure_key_mapped_dataframe(df, key, levels=None, axis=0):
357345
else:
358346
sort_levels = axis_levels
359347

360-
new_levels = [
361-
ensure_key_mapped(
362-
Series(df._get_label_or_level_values(name, axis=axis), name=name),
363-
key,
364-
name=name,
365-
)
366-
if name in sort_levels
367-
else df._get_label_or_level_values(name, axis=axis)
348+
values = [
349+
(name, Series(df._get_label_or_level_values(name, axis=axis), name=name))
368350
for name in axis_levels
369351
]
370352

353+
new_levels = [
354+
ensure_key_mapped(series, key) if name in sort_levels else series
355+
for (name, series) in values
356+
]
357+
371358
if axis == 0:
372359
new_df = DataFrame._from_arrays(new_levels, df.columns, df.index)
373360
else:
@@ -419,11 +406,11 @@ def ensure_key_mapped_multiindex(index, key: Callable, levels=None):
419406
else:
420407
sort_levels = list(range(index.nlevels)) # satisfies mypy
421408

409+
values = [(level, index._get_level_values(level)) for level in range(index.nlevels)]
410+
422411
mapped = [
423-
ensure_key_mapped(index._get_level_values(level), key, name=level)
424-
if level in sort_levels
425-
else index._get_level_values(level)
426-
for level in range(index.nlevels)
412+
ensure_key_mapped(idx, key) if level in sort_levels else idx
413+
for (level, idx) in values
427414
]
428415

429416
labels = MultiIndex.from_arrays(mapped)
@@ -432,7 +419,7 @@ def ensure_key_mapped_multiindex(index, key: Callable, levels=None):
432419

433420

434421
def ensure_key_mapped(
435-
values, key: Union[Dict, Callable], levels=None, name=None, axis=0,
422+
values, key: Optional[Callable], levels=None, axis=0,
436423
):
437424

438425
"""
@@ -443,19 +430,19 @@ def ensure_key_mapped(
443430
Parameters
444431
----------
445432
values : Series, DataFrame, Index subclass, or ndarray
446-
key : Union[Callable, Dict[Union[str, int], Callable]
447-
key to be called on the values array. If dict, indexed by
448-
name or number of columns in values. Dict supported only for
449-
DataFrame or MultiIndex. Expected to return an object of the
450-
same shape and compatible with the original type.
433+
key : Optional[Callable]
434+
key to be called on the values array. Expected to
435+
take a level or column from values and return an
436+
object of the same shape and compatible with the original type.
437+
For MultiIndex and DataFrame, applied to rows or columns of the
438+
values array. For Series and Index, applied directly.
451439
levels : list-like, int or str, default None
452440
For MultiIndex values, level or list of levels to apply the key
453441
function to. If None, key function is applied to all levels. Other
454442
levels are left unchanged.
455-
name : str or int, default None
456-
Name used to index the key function if a dictionary.
457443
axis : int, default 0
458-
Axis to use for applying the key to DataFrame values level by level.
444+
Axis to use for applying the key to DataFrame values. 0 applies the
445+
key to columns, 1 to rows.
459446
"""
460447
from pandas.core.indexes.api import Index
461448
from pandas import DataFrame
@@ -469,18 +456,18 @@ def ensure_key_mapped(
469456
if isinstance(values, DataFrame): # apply the key to select levels
470457
return ensure_key_mapped_dataframe(values, key, levels=levels, axis=axis)
471458

472-
result = apply_key_name(values.copy(), key, name)
459+
result = key(values.copy())
473460

474461
if len(result) != len(values):
475462
raise ValueError(
476463
"User-provided `key` function must not change the shape of the array."
477464
)
478465

479466
try:
480-
if isinstance(values, Index):
467+
if isinstance(values, Index): # allow a new Index class
481468
result = Index(result)
482469
else:
483-
result = type(values)(result)
470+
result = type(values)(result) # try to recover otherwise
484471
except TypeError:
485472
raise TypeError(
486473
"User-provided `key` function returned an invalid type {} \

pandas/tests/frame/methods/test_sort_index.py

-17
Original file line numberDiff line numberDiff line change
@@ -396,20 +396,3 @@ def test_changes_length_raises(self):
396396
df = pd.DataFrame({"A": [1, 2, 3]})
397397
with pytest.raises(ValueError, match="change the shape"):
398398
df.sort_index(key=lambda x: x[:1])
399-
400-
def test_sort_index_key_dict(self):
401-
df = DataFrame({0: ["Hello", "goodbye"], 1: [0, 1], 2: [3, 4]}).set_index(
402-
[0, 1]
403-
)
404-
405-
result = df.sort_index(level=0, key=lambda col: col.str.lower())
406-
expected = df[::-1]
407-
tm.assert_frame_equal(result, expected)
408-
409-
result = df.sort_index(level=[0, 1], key={0: lambda col: col.str.lower()})
410-
expected = df[::-1]
411-
tm.assert_frame_equal(result, expected)
412-
413-
result = df.sort_index([0, 1], key={1: lambda col: -col})
414-
expected = df
415-
tm.assert_frame_equal(result, expected)

pandas/tests/frame/methods/test_sort_values.py

+3-11
Original file line numberDiff line numberDiff line change
@@ -657,32 +657,24 @@ def test_changes_length_raises(self):
657657
with pytest.raises(ValueError, match="change the shape"):
658658
df.sort_values("A", key=lambda x: x[:1])
659659

660-
def test_sort_values_key_dict(self):
660+
def test_sort_values_key_axes(self):
661661
df = DataFrame({0: ["Hello", "goodbye"], 1: [0, 1]})
662662

663663
result = df.sort_values(0, key=lambda col: col.str.lower())
664664
expected = df[::-1]
665665
tm.assert_frame_equal(result, expected)
666666

667-
result = df.sort_values([0, 1], key={0: lambda col: col.str.lower()})
667+
result = df.sort_values(1, key=lambda col: -col)
668668
expected = df[::-1]
669669
tm.assert_frame_equal(result, expected)
670670

671-
result = df.sort_values([0, 1], key={1: lambda col: -col})
672-
expected = df
673-
tm.assert_frame_equal(result, expected)
674-
675671
def test_sort_values_key_dict_axis(self):
676672
df = DataFrame({0: ["Hello", 0], 1: ["goodbye", 1]})
677673

678674
result = df.sort_values(0, key=lambda col: col.str.lower(), axis=1)
679675
expected = df.loc[:, ::-1]
680676
tm.assert_frame_equal(result, expected)
681677

682-
result = df.sort_values([0, 1], key={0: lambda col: col.str.lower()}, axis=1)
678+
result = df.sort_values(1, key=lambda col: -col, axis=1)
683679
expected = df.loc[:, ::-1]
684680
tm.assert_frame_equal(result, expected)
685-
686-
result = df.sort_values([0, 1], key={1: lambda col: -col}, axis=1)
687-
expected = df
688-
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)