Skip to content

Commit 446854a

Browse files
jbrockmendel authored and YYYasin19 committed
TYP: reshape (pandas-dev#48099)
1 parent 1f80a7d commit 446854a

File tree

4 files changed

+97
-41
lines changed

4 files changed

+97
-41
lines changed

pandas/core/frame.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -8942,7 +8942,7 @@ def explode(
89428942
if is_scalar(column) or isinstance(column, tuple):
89438943
columns = [column]
89448944
elif isinstance(column, list) and all(
8945-
map(lambda c: is_scalar(c) or isinstance(c, tuple), column)
8945+
is_scalar(c) or isinstance(c, tuple) for c in column
89468946
):
89478947
if not column:
89488948
raise ValueError("column must be nonempty")

pandas/core/generic.py

+13 -12
Original file line number | Diff line number | Diff line change
@@ -1613,7 +1613,7 @@ def __round__(self: NDFrameT, decimals: int = 0) -> NDFrameT:
16131613
# have consistent precedence and validation logic throughout the library.
16141614

16151615
@final
1616-
def _is_level_reference(self, key, axis=0):
1616+
def _is_level_reference(self, key: Level, axis=0) -> bool_t:
16171617
"""
16181618
Test whether a key is a level reference for a given axis.
16191619
@@ -1625,7 +1625,7 @@ def _is_level_reference(self, key, axis=0):
16251625
16261626
Parameters
16271627
----------
1628-
key : str
1628+
key : Hashable
16291629
Potential level name for the given axis
16301630
axis : int, default 0
16311631
Axis that levels are associated with (0 for index, 1 for columns)
@@ -1644,7 +1644,7 @@ def _is_level_reference(self, key, axis=0):
16441644
)
16451645

16461646
@final
1647-
def _is_label_reference(self, key, axis=0) -> bool_t:
1647+
def _is_label_reference(self, key: Level, axis=0) -> bool_t:
16481648
"""
16491649
Test whether a key is a label reference for a given axis.
16501650
@@ -1654,8 +1654,8 @@ def _is_label_reference(self, key, axis=0) -> bool_t:
16541654
16551655
Parameters
16561656
----------
1657-
key : str
1658-
Potential label name
1657+
key : Hashable
1658+
Potential label name, i.e. Index entry.
16591659
axis : int, default 0
16601660
Axis perpendicular to the axis that labels are associated with
16611661
(0 means search for column labels, 1 means search for index labels)
@@ -1674,7 +1674,7 @@ def _is_label_reference(self, key, axis=0) -> bool_t:
16741674
)
16751675

16761676
@final
1677-
def _is_label_or_level_reference(self, key: str, axis: int = 0) -> bool_t:
1677+
def _is_label_or_level_reference(self, key: Level, axis: int = 0) -> bool_t:
16781678
"""
16791679
Test whether a key is a label or level reference for a given axis.
16801680
@@ -1685,7 +1685,7 @@ def _is_label_or_level_reference(self, key: str, axis: int = 0) -> bool_t:
16851685
16861686
Parameters
16871687
----------
1688-
key : str
1688+
key : Hashable
16891689
Potential label or level name
16901690
axis : int, default 0
16911691
Axis that levels are associated with (0 for index, 1 for columns)
@@ -1699,7 +1699,7 @@ def _is_label_or_level_reference(self, key: str, axis: int = 0) -> bool_t:
16991699
)
17001700

17011701
@final
1702-
def _check_label_or_level_ambiguity(self, key, axis: int = 0) -> None:
1702+
def _check_label_or_level_ambiguity(self, key: Level, axis: int = 0) -> None:
17031703
"""
17041704
Check whether `key` is ambiguous.
17051705
@@ -1708,7 +1708,7 @@ def _check_label_or_level_ambiguity(self, key, axis: int = 0) -> None:
17081708
17091709
Parameters
17101710
----------
1711-
key : str or object
1711+
key : Hashable
17121712
Label or level name.
17131713
axis : int, default 0
17141714
Axis that levels are associated with (0 for index, 1 for columns).
@@ -1717,6 +1717,7 @@ def _check_label_or_level_ambiguity(self, key, axis: int = 0) -> None:
17171717
------
17181718
ValueError: `key` is ambiguous
17191719
"""
1720+
17201721
axis = self._get_axis_number(axis)
17211722
other_axes = (ax for ax in range(self._AXIS_LEN) if ax != axis)
17221723

@@ -1743,7 +1744,7 @@ def _check_label_or_level_ambiguity(self, key, axis: int = 0) -> None:
17431744
raise ValueError(msg)
17441745

17451746
@final
1746-
def _get_label_or_level_values(self, key: Level, axis: int = 0) -> np.ndarray:
1747+
def _get_label_or_level_values(self, key: Level, axis: int = 0) -> ArrayLike:
17471748
"""
17481749
Return a 1-D array of values associated with `key`, a label or level
17491750
from the given `axis`.
@@ -1758,14 +1759,14 @@ def _get_label_or_level_values(self, key: Level, axis: int = 0) -> np.ndarray:
17581759
17591760
Parameters
17601761
----------
1761-
key : str
1762+
key : Hashable
17621763
Label or level name.
17631764
axis : int, default 0
17641765
Axis that levels are associated with (0 for index, 1 for columns)
17651766
17661767
Returns
17671768
-------
1768-
values : np.ndarray
1769+
np.ndarray or ExtensionArray
17691770
17701771
Raises
17711772
------

pandas/core/reshape/merge.py

+53 -8
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
from typing import (
1212
TYPE_CHECKING,
1313
Hashable,
14+
Sequence,
1415
cast,
1516
)
1617
import uuid
@@ -25,6 +26,7 @@
2526
lib,
2627
)
2728
from pandas._typing import (
29+
AnyArrayLike,
2830
ArrayLike,
2931
DtypeObj,
3032
IndexLabel,
@@ -609,6 +611,21 @@ class _MergeOperation:
609611
"""
610612

611613
_merge_type = "merge"
614+
how: str
615+
on: IndexLabel | None
616+
# left_on/right_on may be None when passed, but in validate_specification
617+
# get replaced with non-None.
618+
left_on: Sequence[Hashable | AnyArrayLike]
619+
right_on: Sequence[Hashable | AnyArrayLike]
620+
left_index: bool
621+
right_index: bool
622+
axis: int
623+
bm_axis: int
624+
sort: bool
625+
suffixes: Suffixes
626+
copy: bool
627+
indicator: bool
628+
validate: str | None
612629

613630
def __init__(
614631
self,
@@ -819,8 +836,16 @@ def _maybe_restore_index_levels(self, result: DataFrame) -> None:
819836
self.join_names, self.left_on, self.right_on
820837
):
821838
if (
822-
self.orig_left._is_level_reference(left_key)
823-
and self.orig_right._is_level_reference(right_key)
839+
# Argument 1 to "_is_level_reference" of "NDFrame" has incompatible
840+
# type "Union[Hashable, ExtensionArray, Index, Series]"; expected
841+
# "Hashable"
842+
self.orig_left._is_level_reference(left_key) # type: ignore[arg-type]
843+
# Argument 1 to "_is_level_reference" of "NDFrame" has incompatible
844+
# type "Union[Hashable, ExtensionArray, Index, Series]"; expected
845+
# "Hashable"
846+
and self.orig_right._is_level_reference(
847+
right_key # type: ignore[arg-type]
848+
)
824849
and left_key == right_key
825850
and name not in result.index.names
826851
):
@@ -1049,13 +1074,13 @@ def _get_merge_keys(self):
10491074
10501075
Returns
10511076
-------
1052-
left_keys, right_keys
1077+
left_keys, right_keys, join_names
10531078
"""
1054-
left_keys = []
1055-
right_keys = []
1056-
# error: Need type annotation for 'join_names' (hint: "join_names: List[<type>]
1057-
# = ...")
1058-
join_names = [] # type: ignore[var-annotated]
1079+
# left_keys, right_keys entries can actually be anything listlike
1080+
# with a 'dtype' attr
1081+
left_keys: list[AnyArrayLike] = []
1082+
right_keys: list[AnyArrayLike] = []
1083+
join_names: list[Hashable] = []
10591084
right_drop = []
10601085
left_drop = []
10611086

@@ -1078,11 +1103,16 @@ def _get_merge_keys(self):
10781103
if _any(self.left_on) and _any(self.right_on):
10791104
for lk, rk in zip(self.left_on, self.right_on):
10801105
if is_lkey(lk):
1106+
lk = cast(AnyArrayLike, lk)
10811107
left_keys.append(lk)
10821108
if is_rkey(rk):
1109+
rk = cast(AnyArrayLike, rk)
10831110
right_keys.append(rk)
10841111
join_names.append(None) # what to do?
10851112
else:
1113+
# Then we're either Hashable or a wrong-length arraylike,
1114+
# the latter of which will raise
1115+
rk = cast(Hashable, rk)
10861116
if rk is not None:
10871117
right_keys.append(right._get_label_or_level_values(rk))
10881118
join_names.append(rk)
@@ -1092,6 +1122,9 @@ def _get_merge_keys(self):
10921122
join_names.append(right.index.name)
10931123
else:
10941124
if not is_rkey(rk):
1125+
# Then we're either Hashable or a wrong-length arraylike,
1126+
# the latter of which will raise
1127+
rk = cast(Hashable, rk)
10951128
if rk is not None:
10961129
right_keys.append(right._get_label_or_level_values(rk))
10971130
else:
@@ -1104,8 +1137,12 @@ def _get_merge_keys(self):
11041137
else:
11051138
left_drop.append(lk)
11061139
else:
1140+
rk = cast(AnyArrayLike, rk)
11071141
right_keys.append(rk)
11081142
if lk is not None:
1143+
# Then we're either Hashable or a wrong-length arraylike,
1144+
# the latter of which will raise
1145+
lk = cast(Hashable, lk)
11091146
left_keys.append(left._get_label_or_level_values(lk))
11101147
join_names.append(lk)
11111148
else:
@@ -1115,9 +1152,13 @@ def _get_merge_keys(self):
11151152
elif _any(self.left_on):
11161153
for k in self.left_on:
11171154
if is_lkey(k):
1155+
k = cast(AnyArrayLike, k)
11181156
left_keys.append(k)
11191157
join_names.append(None)
11201158
else:
1159+
# Then we're either Hashable or a wrong-length arraylike,
1160+
# the latter of which will raise
1161+
k = cast(Hashable, k)
11211162
left_keys.append(left._get_label_or_level_values(k))
11221163
join_names.append(k)
11231164
if isinstance(self.right.index, MultiIndex):
@@ -1132,9 +1173,13 @@ def _get_merge_keys(self):
11321173
elif _any(self.right_on):
11331174
for k in self.right_on:
11341175
if is_rkey(k):
1176+
k = cast(AnyArrayLike, k)
11351177
right_keys.append(k)
11361178
join_names.append(None)
11371179
else:
1180+
# Then we're either Hashable or a wrong-length arraylike,
1181+
# the latter of which will raise
1182+
k = cast(Hashable, k)
11381183
right_keys.append(right._get_label_or_level_values(k))
11391184
join_names.append(k)
11401185
if isinstance(self.left.index, MultiIndex):

pandas/core/reshape/reshape.py

+30 -20
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,10 @@
11
from __future__ import annotations
22

33
import itertools
4-
from typing import TYPE_CHECKING
4+
from typing import (
5+
TYPE_CHECKING,
6+
cast,
7+
)
58
import warnings
69

710
import numpy as np
@@ -452,7 +455,7 @@ def _unstack_multiple(data, clocs, fill_value=None):
452455
return unstacked
453456

454457

455-
def unstack(obj, level, fill_value=None):
458+
def unstack(obj: Series | DataFrame, level, fill_value=None):
456459

457460
if isinstance(level, (tuple, list)):
458461
if len(level) != 1:
@@ -489,19 +492,20 @@ def unstack(obj, level, fill_value=None):
489492
)
490493

491494

492-
def _unstack_frame(obj, level, fill_value=None):
495+
def _unstack_frame(obj: DataFrame, level, fill_value=None):
496+
assert isinstance(obj.index, MultiIndex) # checked by caller
497+
unstacker = _Unstacker(obj.index, level=level, constructor=obj._constructor)
498+
493499
if not obj._can_fast_transpose:
494-
unstacker = _Unstacker(obj.index, level=level)
495500
mgr = obj._mgr.unstack(unstacker, fill_value=fill_value)
496501
return obj._constructor(mgr)
497502
else:
498-
unstacker = _Unstacker(obj.index, level=level, constructor=obj._constructor)
499503
return unstacker.get_result(
500504
obj._values, value_columns=obj.columns, fill_value=fill_value
501505
)
502506

503507

504-
def _unstack_extension_series(series, level, fill_value):
508+
def _unstack_extension_series(series: Series, level, fill_value) -> DataFrame:
505509
"""
506510
Unstack an ExtensionArray-backed Series.
507511
@@ -534,14 +538,14 @@ def _unstack_extension_series(series, level, fill_value):
534538
return result
535539

536540

537-
def stack(frame, level=-1, dropna=True):
541+
def stack(frame: DataFrame, level=-1, dropna: bool = True):
538542
"""
539543
Convert DataFrame to Series with multi-level Index. Columns become the
540544
second level of the resulting hierarchical index
541545
542546
Returns
543547
-------
544-
stacked : Series
548+
stacked : Series or DataFrame
545549
"""
546550

547551
def factorize(index):
@@ -676,8 +680,10 @@ def _stack_multi_column_index(columns: MultiIndex) -> MultiIndex:
676680
)
677681

678682

679-
def _stack_multi_columns(frame, level_num=-1, dropna=True):
680-
def _convert_level_number(level_num: int, columns):
683+
def _stack_multi_columns(
684+
frame: DataFrame, level_num: int = -1, dropna: bool = True
685+
) -> DataFrame:
686+
def _convert_level_number(level_num: int, columns: Index):
681687
"""
682688
Logic for converting the level number to something we can safely pass
683689
to swaplevel.
@@ -690,32 +696,36 @@ def _convert_level_number(level_num: int, columns):
690696

691697
return level_num
692698

693-
this = frame.copy()
699+
this = frame.copy(deep=False)
700+
mi_cols = this.columns # cast(MultiIndex, this.columns)
701+
assert isinstance(mi_cols, MultiIndex) # caller is responsible
694702

695703
# this makes life much simpler
696-
if level_num != frame.columns.nlevels - 1:
704+
if level_num != mi_cols.nlevels - 1:
697705
# roll levels to put selected level at end
698-
roll_columns = this.columns
699-
for i in range(level_num, frame.columns.nlevels - 1):
706+
roll_columns = mi_cols
707+
for i in range(level_num, mi_cols.nlevels - 1):
700708
# Need to check if the ints conflict with level names
701709
lev1 = _convert_level_number(i, roll_columns)
702710
lev2 = _convert_level_number(i + 1, roll_columns)
703711
roll_columns = roll_columns.swaplevel(lev1, lev2)
704-
this.columns = roll_columns
712+
this.columns = mi_cols = roll_columns
705713

706-
if not this.columns._is_lexsorted():
714+
if not mi_cols._is_lexsorted():
707715
# Workaround the edge case where 0 is one of the column names,
708716
# which interferes with trying to sort based on the first
709717
# level
710-
level_to_sort = _convert_level_number(0, this.columns)
718+
level_to_sort = _convert_level_number(0, mi_cols)
711719
this = this.sort_index(level=level_to_sort, axis=1)
720+
mi_cols = this.columns
712721

713-
new_columns = _stack_multi_column_index(this.columns)
722+
mi_cols = cast(MultiIndex, mi_cols)
723+
new_columns = _stack_multi_column_index(mi_cols)
714724

715725
# time to ravel the values
716726
new_data = {}
717-
level_vals = this.columns.levels[-1]
718-
level_codes = sorted(set(this.columns.codes[-1]))
727+
level_vals = mi_cols.levels[-1]
728+
level_codes = sorted(set(mi_cols.codes[-1]))
719729
level_vals_nan = level_vals.insert(len(level_vals), None)
720730

721731
level_vals_used = np.take(level_vals_nan, level_codes)

0 commit comments

Comments
 (0)