TYP: reshape #48099

Merged
merged 2 commits on Aug 16, 2022

2 changes: 1 addition & 1 deletion pandas/core/frame.py
@@ -8942,7 +8942,7 @@ def explode(
if is_scalar(column) or isinstance(column, tuple):
columns = [column]
elif isinstance(column, list) and all(
map(lambda c: is_scalar(c) or isinstance(c, tuple), column)
is_scalar(c) or isinstance(c, tuple) for c in column
):
if not column:
raise ValueError("column must be nonempty")
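The check touched here validates the list form of `DataFrame.explode`; the rewrite from `map(lambda ...)` to a generator expression inside `all()` does not change behaviour. A small sketch of what the validation accepts and rejects (exploding multiple columns assumes pandas >= 1.3):

```python
import pandas as pd

df = pd.DataFrame({"A": [[1, 2], [3]], "B": [["x", "y"], ["z"]]})

# Column names must be scalars (or tuples, for MultiIndex columns).
print(df.explode(["A", "B"]))

# An empty list of columns is rejected.
try:
    df.explode([])
except ValueError as err:
    print(err)  # column must be nonempty
```
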
25 changes: 13 additions & 12 deletions pandas/core/generic.py
@@ -1613,7 +1613,7 @@ def __round__(self: NDFrameT, decimals: int = 0) -> NDFrameT:
# have consistent precedence and validation logic throughout the library.

@final
def _is_level_reference(self, key, axis=0):
def _is_level_reference(self, key: Level, axis=0) -> bool_t:
"""
Test whether a key is a level reference for a given axis.

@@ -1625,7 +1625,7 @@ def _is_level_reference(self, key, axis=0):

Parameters
----------
key : str
key : Hashable
Potential level name for the given axis
axis : int, default 0
Axis that levels are associated with (0 for index, 1 for columns)
@@ -1644,7 +1644,7 @@ def _is_level_reference(self, key, axis=0):
)

@final
def _is_label_reference(self, key, axis=0) -> bool_t:
def _is_label_reference(self, key: Level, axis=0) -> bool_t:
"""
Test whether a key is a label reference for a given axis.

@@ -1654,8 +1654,8 @@ def _is_label_reference(self, key, axis=0) -> bool_t:

Parameters
----------
key : str
Potential label name
key : Hashable
Potential label name, i.e. Index entry.
axis : int, default 0
Axis perpendicular to the axis that labels are associated with
(0 means search for column labels, 1 means search for index labels)
@@ -1674,7 +1674,7 @@ def _is_label_reference(self, key, axis=0) -> bool_t:
)

@final
def _is_label_or_level_reference(self, key: str, axis: int = 0) -> bool_t:
def _is_label_or_level_reference(self, key: Level, axis: int = 0) -> bool_t:
"""
Test whether a key is a label or level reference for a given axis.

@@ -1685,7 +1685,7 @@ def _is_label_or_level_reference(self, key: str, axis: int = 0) -> bool_t:

Parameters
----------
key : str
key : Hashable
Potential label or level name
axis : int, default 0
Axis that levels are associated with (0 for index, 1 for columns)
@@ -1699,7 +1699,7 @@ def _is_label_or_level_reference(self, key: str, axis: int = 0) -> bool_t:
)

@final
def _check_label_or_level_ambiguity(self, key, axis: int = 0) -> None:
def _check_label_or_level_ambiguity(self, key: Level, axis: int = 0) -> None:
"""
Check whether `key` is ambiguous.

@@ -1708,7 +1708,7 @@ def _check_label_or_level_ambiguity(self, key, axis: int = 0) -> None:

Parameters
----------
key : str or object
key : Hashable
Label or level name.
axis : int, default 0
Axis that levels are associated with (0 for index, 1 for columns).
@@ -1717,6 +1718,7 @@ def _check_label_or_level_ambiguity(self, key, axis: int = 0) -> None:
------
ValueError: `key` is ambiguous
"""

axis = self._get_axis_number(axis)
other_axes = (ax for ax in range(self._AXIS_LEN) if ax != axis)

@@ -1743,7 +1744,7 @@ def _check_label_or_level_ambiguity(self, key, axis: int = 0) -> None:
raise ValueError(msg)

@final
def _get_label_or_level_values(self, key: Level, axis: int = 0) -> np.ndarray:
def _get_label_or_level_values(self, key: Level, axis: int = 0) -> ArrayLike:
"""
Return a 1-D array of values associated with `key`, a label or level
from the given `axis`.
@@ -1758,14 +1759,14 @@ def _get_label_or_level_values(self, key: Level, axis: int = 0) -> np.ndarray:

Parameters
----------
key : str
key : Hashable
Label or level name.
axis : int, default 0
Axis that levels are associated with (0 for index, 1 for columns)

Returns
-------
values : np.ndarray
np.ndarray or ExtensionArray

Raises
------
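These private helpers back the label-or-level resolution used by public operations such as merge and groupby: a key may name a column (a label) or an index level, and a name that is both is rejected. A toy sketch of both cases through the public merge API, assuming a reasonably recent pandas (index-level merge keys have been supported since 0.23):

```python
import pandas as pd

left = pd.DataFrame({"key": ["a", "b"], "lval": [1, 2]}).set_index("key")
right = pd.DataFrame({"key": ["a", "b"], "rval": [3, 4]})

# "key" is an index level on the left and a column label on the right;
# both resolve through the label-or-level machinery.
print(left.merge(right, on="key"))

# A name that is both an index level and a column label is ambiguous and
# is expected to raise the ValueError documented above.
ambiguous = pd.DataFrame(
    {"key": ["a", "b"], "val": [1, 2]},
    index=pd.Index(["a", "b"], name="key"),
)
try:
    ambiguous.merge(right, on="key")
except ValueError as err:
    print(err)
```
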
61 changes: 53 additions & 8 deletions pandas/core/reshape/merge.py
@@ -11,6 +11,7 @@
from typing import (
TYPE_CHECKING,
Hashable,
Sequence,
cast,
)
import uuid
@@ -25,6 +26,7 @@
lib,
)
from pandas._typing import (
AnyArrayLike,
ArrayLike,
DtypeObj,
IndexLabel,
@@ -609,6 +611,21 @@ class _MergeOperation:
"""

_merge_type = "merge"
how: str
on: IndexLabel | None
# left_on/right_on may be None when passed, but in validate_specification
# get replaced with non-None.
left_on: Sequence[Hashable | AnyArrayLike]
right_on: Sequence[Hashable | AnyArrayLike]
left_index: bool
right_index: bool
axis: int
bm_axis: int
sort: bool
suffixes: Suffixes
copy: bool
indicator: bool
validate: str | None

def __init__(
self,
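The class-level annotations added above record that, once validated, left_on/right_on are sequences whose elements are either hashable labels or array-likes. Both kinds of keys are reachable from the public API; a toy sketch (frame contents are made up):

```python
import numpy as np
import pandas as pd

left = pd.DataFrame({"k": ["a", "b", "c"], "lval": [1, 2, 3]})
right = pd.DataFrame({"k": ["a", "b", "c"], "rval": [4, 5, 6]})

# Hashable label as the key on both sides.
by_label = pd.merge(left, right, on="k")

# Array-like key on the left (length must match the frame), label on the right.
by_array = pd.merge(left, right, left_on=np.array(["a", "b", "c"]), right_on="k")

print(by_label.columns.tolist())
print(by_array.columns.tolist())
```
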
@@ -819,8 +836,16 @@ def _maybe_restore_index_levels(self, result: DataFrame) -> None:
self.join_names, self.left_on, self.right_on
):
if (
self.orig_left._is_level_reference(left_key)
and self.orig_right._is_level_reference(right_key)
# Argument 1 to "_is_level_reference" of "NDFrame" has incompatible
# type "Union[Hashable, ExtensionArray, Index, Series]"; expected
# "Hashable"
self.orig_left._is_level_reference(left_key) # type: ignore[arg-type]
# Argument 1 to "_is_level_reference" of "NDFrame" has incompatible
# type "Union[Hashable, ExtensionArray, Index, Series]"; expected
# "Hashable"
and self.orig_right._is_level_reference(
right_key # type: ignore[arg-type]
)
and left_key == right_key
and name not in result.index.names
):
@@ -1049,13 +1074,13 @@ def _get_merge_keys(self):

Returns
-------
left_keys, right_keys
left_keys, right_keys, join_names
"""
left_keys = []
right_keys = []
# error: Need type annotation for 'join_names' (hint: "join_names: List[<type>]
# = ...")
join_names = [] # type: ignore[var-annotated]
# left_keys, right_keys entries can actually be anything listlike
# with a 'dtype' attr
left_keys: list[AnyArrayLike] = []
right_keys: list[AnyArrayLike] = []
join_names: list[Hashable] = []
right_drop = []
left_drop = []

@@ -1078,11 +1103,16 @@ def _get_merge_keys(self):
if _any(self.left_on) and _any(self.right_on):
for lk, rk in zip(self.left_on, self.right_on):
if is_lkey(lk):
lk = cast(AnyArrayLike, lk)
left_keys.append(lk)
if is_rkey(rk):
rk = cast(AnyArrayLike, rk)
right_keys.append(rk)
join_names.append(None) # what to do?
else:
# Then we're either Hashable or a wrong-length arraylike,
# the latter of which will raise
rk = cast(Hashable, rk)
if rk is not None:
right_keys.append(right._get_label_or_level_values(rk))
join_names.append(rk)
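The cast(...) calls added in this hunk are purely for the type checker: at runtime typing.cast returns its argument unchanged, and the comments spell out why the narrowing is safe (a wrong-length array-like will raise later anyway). A minimal sketch of the pattern; Key, _is_array_key and describe_key are illustrative names standing in for merge's is_lkey/is_rkey predicates, not pandas APIs:

```python
from typing import Hashable, Union, cast

import numpy as np

Key = Union[Hashable, np.ndarray]


def _is_array_key(key: Key) -> bool:
    # A helper predicate, like is_lkey() in merge.py; mypy cannot narrow through it.
    return isinstance(key, np.ndarray)


def describe_key(key: Key) -> str:
    if _is_array_key(key):
        arr = cast(np.ndarray, key)  # runtime no-op; tells mypy which branch this is
        return f"array key of length {len(arr)}"
    label = cast(Hashable, key)      # likewise: we know the key is a plain label here
    return f"label key {label!r}"


print(describe_key(np.array(["a", "b", "c"])))  # array key of length 3
print(describe_key("k"))                        # label key 'k'
```
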
@@ -1092,6 +1122,9 @@ def _get_merge_keys(self):
join_names.append(right.index.name)
else:
if not is_rkey(rk):
# Then we're either Hashable or a wrong-length arraylike,
# the latter of which will raise
rk = cast(Hashable, rk)
if rk is not None:
right_keys.append(right._get_label_or_level_values(rk))
else:
@@ -1104,8 +1137,12 @@ def _get_merge_keys(self):
else:
left_drop.append(lk)
else:
rk = cast(AnyArrayLike, rk)
right_keys.append(rk)
if lk is not None:
# Then we're either Hashable or a wrong-length arraylike,
# the latter of which will raise
lk = cast(Hashable, lk)
left_keys.append(left._get_label_or_level_values(lk))
join_names.append(lk)
else:
@@ -1115,9 +1152,13 @@ def _get_merge_keys(self):
elif _any(self.left_on):
for k in self.left_on:
if is_lkey(k):
k = cast(AnyArrayLike, k)
left_keys.append(k)
join_names.append(None)
else:
# Then we're either Hashable or a wrong-length arraylike,
# the latter of which will raise
k = cast(Hashable, k)
left_keys.append(left._get_label_or_level_values(k))
join_names.append(k)
if isinstance(self.right.index, MultiIndex):
@@ -1132,9 +1173,13 @@ def _get_merge_keys(self):
elif _any(self.right_on):
for k in self.right_on:
if is_rkey(k):
k = cast(AnyArrayLike, k)
right_keys.append(k)
join_names.append(None)
else:
# Then we're either Hashable or a wrong-length arraylike,
# the latter of which will raise
k = cast(Hashable, k)
right_keys.append(right._get_label_or_level_values(k))
join_names.append(k)
if isinstance(self.left.index, MultiIndex):
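Where join_names gets a None appended earlier in _get_merge_keys ("what to do?"), the merge key is an anonymous array with no label to carry into the result, and pandas falls back to a generated column name. A hedged sketch of that public behaviour (the generated name key_0 is observed behaviour, not a stable contract):

```python
import numpy as np
import pandas as pd

left = pd.DataFrame({"lval": [1, 2, 3]})
right = pd.DataFrame({"rval": [4, 5, 6]})

# Both join keys are anonymous arrays, so there is no label to reuse;
# the key values come back under a generated name such as "key_0".
result = pd.merge(
    left, right, left_on=np.array(["a", "b", "c"]), right_on=np.array(["a", "b", "c"])
)
print(result.columns.tolist())
```
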
50 changes: 30 additions & 20 deletions pandas/core/reshape/reshape.py
@@ -1,7 +1,10 @@
from __future__ import annotations

import itertools
from typing import TYPE_CHECKING
from typing import (
TYPE_CHECKING,
cast,
)
import warnings

import numpy as np
@@ -452,7 +455,7 @@ def _unstack_multiple(data, clocs, fill_value=None):
return unstacked


def unstack(obj, level, fill_value=None):
def unstack(obj: Series | DataFrame, level, fill_value=None):

if isinstance(level, (tuple, list)):
if len(level) != 1:
@@ -489,19 +492,20 @@ def unstack(obj, level, fill_value=None):
)


def _unstack_frame(obj, level, fill_value=None):
def _unstack_frame(obj: DataFrame, level, fill_value=None):
assert isinstance(obj.index, MultiIndex) # checked by caller
unstacker = _Unstacker(obj.index, level=level, constructor=obj._constructor)

if not obj._can_fast_transpose:
unstacker = _Unstacker(obj.index, level=level)
mgr = obj._mgr.unstack(unstacker, fill_value=fill_value)
return obj._constructor(mgr)
else:
unstacker = _Unstacker(obj.index, level=level, constructor=obj._constructor)
return unstacker.get_result(
obj._values, value_columns=obj.columns, fill_value=fill_value
)


def _unstack_extension_series(series, level, fill_value):
def _unstack_extension_series(series: Series, level, fill_value) -> DataFrame:
"""
Unstack an ExtensionArray-backed Series.

@@ -534,14 +538,14 @@ def _unstack_extension_series(series, level, fill_value):
return result


def stack(frame, level=-1, dropna=True):
def stack(frame: DataFrame, level=-1, dropna: bool = True):
"""
Convert DataFrame to Series with multi-level Index. Columns become the
second level of the resulting hierarchical index

Returns
-------
stacked : Series
stacked : Series or DataFrame
"""

def factorize(index):
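The docstring fix ("Series or DataFrame") reflects that stack only yields a Series when every column level is stacked away; likewise the typed unstack helpers above include the ExtensionArray-backed path. A toy illustration of both, using only public API:

```python
import pandas as pd

# stack: flat columns collapse to a Series; with MultiIndex columns a
# DataFrame remains when only one level is stacked.
flat = pd.DataFrame([[1, 2], [3, 4]], columns=["x", "y"])
print(type(flat.stack()))    # <class 'pandas.core.series.Series'>

cols = pd.MultiIndex.from_tuples([("A", "x"), ("A", "y"), ("B", "x")])
multi = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=cols)
print(type(multi.stack(0)))  # <class 'pandas.core.frame.DataFrame'>

# unstack: a level of the row index pivots into columns; an extension
# dtype such as nullable Int64 survives the trip.
idx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=["outer", "inner"])
ser = pd.Series([10, 20, 30, 40], index=idx, dtype="Int64")
wide = ser.unstack("inner")
print(wide.dtypes)           # Int64 for both columns
```
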
@@ -676,8 +680,10 @@ def _stack_multi_column_index(columns: MultiIndex) -> MultiIndex:
)


def _stack_multi_columns(frame, level_num=-1, dropna=True):
def _convert_level_number(level_num: int, columns):
def _stack_multi_columns(
frame: DataFrame, level_num: int = -1, dropna: bool = True
) -> DataFrame:
def _convert_level_number(level_num: int, columns: Index):
"""
Logic for converting the level number to something we can safely pass
to swaplevel.
@@ -690,32 +696,36 @@ def _convert_level_number(level_num: int, columns):

return level_num

this = frame.copy()
this = frame.copy(deep=False)
mi_cols = this.columns # cast(MultiIndex, this.columns)
assert isinstance(mi_cols, MultiIndex) # caller is responsible

# this makes life much simpler
if level_num != frame.columns.nlevels - 1:
if level_num != mi_cols.nlevels - 1:
# roll levels to put selected level at end
roll_columns = this.columns
for i in range(level_num, frame.columns.nlevels - 1):
roll_columns = mi_cols
for i in range(level_num, mi_cols.nlevels - 1):
# Need to check if the ints conflict with level names
lev1 = _convert_level_number(i, roll_columns)
lev2 = _convert_level_number(i + 1, roll_columns)
roll_columns = roll_columns.swaplevel(lev1, lev2)
this.columns = roll_columns
this.columns = mi_cols = roll_columns

if not this.columns._is_lexsorted():
if not mi_cols._is_lexsorted():
# Workaround the edge case where 0 is one of the column names,
# which interferes with trying to sort based on the first
# level
level_to_sort = _convert_level_number(0, this.columns)
level_to_sort = _convert_level_number(0, mi_cols)
this = this.sort_index(level=level_to_sort, axis=1)
mi_cols = this.columns

new_columns = _stack_multi_column_index(this.columns)
mi_cols = cast(MultiIndex, mi_cols)
new_columns = _stack_multi_column_index(mi_cols)

# time to ravel the values
new_data = {}
level_vals = this.columns.levels[-1]
level_codes = sorted(set(this.columns.codes[-1]))
level_vals = mi_cols.levels[-1]
level_codes = sorted(set(mi_cols.codes[-1]))
level_vals_nan = level_vals.insert(len(level_vals), None)

level_vals_used = np.take(level_vals_nan, level_codes)
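The switch to frame.copy(deep=False) avoids copying the block values: the function only reassigns this.columns and re-sorts, so sharing the underlying data is fine. A small sketch of the difference (exact sharing depends on the copy-on-write mode, but reassigning metadata on the copy never alters the original):

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})

shallow = df.copy(deep=False)  # new object, underlying values shared
deep = df.copy()               # underlying values copied

# Reassigning column labels on the shallow copy leaves the original intact.
shallow.columns = ["A", "B"]
print(df.columns.tolist())       # ['a', 'b']
print(shallow.columns.tolist())  # ['A', 'B']

# The data itself is shared only for the shallow copy.
print(np.shares_memory(df["a"].to_numpy(), shallow["A"].to_numpy()))  # True
print(np.shares_memory(df["a"].to_numpy(), deep["a"].to_numpy()))     # False
```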