From 3f2de6b423a9e5907425d06ff6080397fcc67560 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 15 Aug 2022 15:29:35 -0700 Subject: [PATCH] TYP: reshape --- pandas/core/frame.py | 2 +- pandas/core/generic.py | 25 +++++++------- pandas/core/reshape/merge.py | 61 +++++++++++++++++++++++++++++----- pandas/core/reshape/reshape.py | 50 +++++++++++++++++----------- 4 files changed, 97 insertions(+), 41 deletions(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 8c4924a2483be..6444548596b75 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -9008,7 +9008,7 @@ def explode( if is_scalar(column) or isinstance(column, tuple): columns = [column] elif isinstance(column, list) and all( - map(lambda c: is_scalar(c) or isinstance(c, tuple), column) + is_scalar(c) or isinstance(c, tuple) for c in column ): if not column: raise ValueError("column must be nonempty") diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 8096b57168d8c..7900ffb8e9074 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -1611,7 +1611,7 @@ def __round__(self: NDFrameT, decimals: int = 0) -> NDFrameT: # have consistent precedence and validation logic throughout the library. @final - def _is_level_reference(self, key, axis=0): + def _is_level_reference(self, key: Hashable, axis=0) -> bool_t: """ Test whether a key is a level reference for a given axis. @@ -1623,7 +1623,7 @@ def _is_level_reference(self, key, axis=0): Parameters ---------- - key : str + key : Hashable Potential level name for the given axis axis : int, default 0 Axis that levels are associated with (0 for index, 1 for columns) @@ -1642,7 +1642,7 @@ def _is_level_reference(self, key, axis=0): ) @final - def _is_label_reference(self, key, axis=0) -> bool_t: + def _is_label_reference(self, key: Hashable, axis=0) -> bool_t: """ Test whether a key is a label reference for a given axis. @@ -1652,8 +1652,8 @@ def _is_label_reference(self, key, axis=0) -> bool_t: Parameters ---------- - key : str - Potential label name + key : Hashable + Potential label name, i.e. Index entry. axis : int, default 0 Axis perpendicular to the axis that labels are associated with (0 means search for column labels, 1 means search for index labels) @@ -1672,7 +1672,7 @@ def _is_label_reference(self, key, axis=0) -> bool_t: ) @final - def _is_label_or_level_reference(self, key: str, axis: int = 0) -> bool_t: + def _is_label_or_level_reference(self, key: Hashable, axis: int = 0) -> bool_t: """ Test whether a key is a label or level reference for a given axis. @@ -1683,7 +1683,7 @@ def _is_label_or_level_reference(self, key: str, axis: int = 0) -> bool_t: Parameters ---------- - key : str + key : Hashable Potential label or level name axis : int, default 0 Axis that levels are associated with (0 for index, 1 for columns) @@ -1697,7 +1697,7 @@ def _is_label_or_level_reference(self, key: str, axis: int = 0) -> bool_t: ) @final - def _check_label_or_level_ambiguity(self, key, axis: int = 0) -> None: + def _check_label_or_level_ambiguity(self, key: Hashable, axis: int = 0) -> None: """ Check whether `key` is ambiguous. @@ -1706,7 +1706,7 @@ def _check_label_or_level_ambiguity(self, key, axis: int = 0) -> None: Parameters ---------- - key : str or object + key : Hashable Label or level name. axis : int, default 0 Axis that levels are associated with (0 for index, 1 for columns). @@ -1715,6 +1715,7 @@ def _check_label_or_level_ambiguity(self, key, axis: int = 0) -> None: ------ ValueError: `key` is ambiguous """ + axis = self._get_axis_number(axis) other_axes = (ax for ax in range(self._AXIS_LEN) if ax != axis) @@ -1741,7 +1742,7 @@ def _check_label_or_level_ambiguity(self, key, axis: int = 0) -> None: raise ValueError(msg) @final - def _get_label_or_level_values(self, key: str, axis: int = 0) -> np.ndarray: + def _get_label_or_level_values(self, key: Hashable, axis: int = 0) -> ArrayLike: """ Return a 1-D array of values associated with `key`, a label or level from the given `axis`. @@ -1756,14 +1757,14 @@ def _get_label_or_level_values(self, key: str, axis: int = 0) -> np.ndarray: Parameters ---------- - key : str + key : Hashable Label or level name. axis : int, default 0 Axis that levels are associated with (0 for index, 1 for columns) Returns ------- - values : np.ndarray + np.ndarray or ExtensionArray Raises ------ diff --git a/pandas/core/reshape/merge.py b/pandas/core/reshape/merge.py index 9298142d0ec75..42d4009f2d8f2 100644 --- a/pandas/core/reshape/merge.py +++ b/pandas/core/reshape/merge.py @@ -11,6 +11,7 @@ from typing import ( TYPE_CHECKING, Hashable, + Sequence, cast, ) import uuid @@ -25,6 +26,7 @@ lib, ) from pandas._typing import ( + AnyArrayLike, ArrayLike, DtypeObj, IndexLabel, @@ -610,6 +612,21 @@ class _MergeOperation: """ _merge_type = "merge" + how: str + on: IndexLabel | None + # left_on/right_on may be None when passed, but in validate_specification + # get replaced with non-None. + left_on: Sequence[Hashable | AnyArrayLike] + right_on: Sequence[Hashable | AnyArrayLike] + left_index: bool + right_index: bool + axis: int + bm_axis: int + sort: bool + suffixes: Suffixes + copy: bool + indicator: bool + validate: str | None def __init__( self, @@ -822,8 +839,16 @@ def _maybe_restore_index_levels(self, result: DataFrame) -> None: self.join_names, self.left_on, self.right_on ): if ( - self.orig_left._is_level_reference(left_key) - and self.orig_right._is_level_reference(right_key) + # Argument 1 to "_is_level_reference" of "NDFrame" has incompatible + # type "Union[Hashable, ExtensionArray, Index, Series]"; expected + # "Hashable" + self.orig_left._is_level_reference(left_key) # type: ignore[arg-type] + # Argument 1 to "_is_level_reference" of "NDFrame" has incompatible + # type "Union[Hashable, ExtensionArray, Index, Series]"; expected + # "Hashable" + and self.orig_right._is_level_reference( + right_key # type: ignore[arg-type] + ) and left_key == right_key and name not in result.index.names ): @@ -1052,13 +1077,13 @@ def _get_merge_keys(self): Returns ------- - left_keys, right_keys + left_keys, right_keys, join_names """ - left_keys = [] - right_keys = [] - # error: Need type annotation for 'join_names' (hint: "join_names: List[] - # = ...") - join_names = [] # type: ignore[var-annotated] + # left_keys, right_keys entries can actually be anything listlike + # with a 'dtype' attr + left_keys: list[AnyArrayLike] = [] + right_keys: list[AnyArrayLike] = [] + join_names: list[Hashable] = [] right_drop = [] left_drop = [] @@ -1081,11 +1106,16 @@ def _get_merge_keys(self): if _any(self.left_on) and _any(self.right_on): for lk, rk in zip(self.left_on, self.right_on): if is_lkey(lk): + lk = cast(AnyArrayLike, lk) left_keys.append(lk) if is_rkey(rk): + rk = cast(AnyArrayLike, rk) right_keys.append(rk) join_names.append(None) # what to do? else: + # Then we're either Hashable or a wrong-length arraylike, + # the latter of which will raise + rk = cast(Hashable, rk) if rk is not None: right_keys.append(right._get_label_or_level_values(rk)) join_names.append(rk) @@ -1095,6 +1125,9 @@ def _get_merge_keys(self): join_names.append(right.index.name) else: if not is_rkey(rk): + # Then we're either Hashable or a wrong-length arraylike, + # the latter of which will raise + rk = cast(Hashable, rk) if rk is not None: right_keys.append(right._get_label_or_level_values(rk)) else: @@ -1107,8 +1140,12 @@ def _get_merge_keys(self): else: left_drop.append(lk) else: + rk = cast(AnyArrayLike, rk) right_keys.append(rk) if lk is not None: + # Then we're either Hashable or a wrong-length arraylike, + # the latter of which will raise + lk = cast(Hashable, lk) left_keys.append(left._get_label_or_level_values(lk)) join_names.append(lk) else: @@ -1118,9 +1155,13 @@ def _get_merge_keys(self): elif _any(self.left_on): for k in self.left_on: if is_lkey(k): + k = cast(AnyArrayLike, k) left_keys.append(k) join_names.append(None) else: + # Then we're either Hashable or a wrong-length arraylike, + # the latter of which will raise + k = cast(Hashable, k) left_keys.append(left._get_label_or_level_values(k)) join_names.append(k) if isinstance(self.right.index, MultiIndex): @@ -1135,9 +1176,13 @@ def _get_merge_keys(self): elif _any(self.right_on): for k in self.right_on: if is_rkey(k): + k = cast(AnyArrayLike, k) right_keys.append(k) join_names.append(None) else: + # Then we're either Hashable or a wrong-length arraylike, + # the latter of which will raise + k = cast(Hashable, k) right_keys.append(right._get_label_or_level_values(k)) join_names.append(k) if isinstance(self.left.index, MultiIndex): diff --git a/pandas/core/reshape/reshape.py b/pandas/core/reshape/reshape.py index 5039a29b74f1b..52b059f6b92af 100644 --- a/pandas/core/reshape/reshape.py +++ b/pandas/core/reshape/reshape.py @@ -1,7 +1,10 @@ from __future__ import annotations import itertools -from typing import TYPE_CHECKING +from typing import ( + TYPE_CHECKING, + cast, +) import warnings import numpy as np @@ -452,7 +455,7 @@ def _unstack_multiple(data, clocs, fill_value=None): return unstacked -def unstack(obj, level, fill_value=None): +def unstack(obj: Series | DataFrame, level, fill_value=None): if isinstance(level, (tuple, list)): if len(level) != 1: @@ -489,19 +492,20 @@ def unstack(obj, level, fill_value=None): ) -def _unstack_frame(obj, level, fill_value=None): +def _unstack_frame(obj: DataFrame, level, fill_value=None): + assert isinstance(obj.index, MultiIndex) # checked by caller + unstacker = _Unstacker(obj.index, level=level, constructor=obj._constructor) + if not obj._can_fast_transpose: - unstacker = _Unstacker(obj.index, level=level) mgr = obj._mgr.unstack(unstacker, fill_value=fill_value) return obj._constructor(mgr) else: - unstacker = _Unstacker(obj.index, level=level, constructor=obj._constructor) return unstacker.get_result( obj._values, value_columns=obj.columns, fill_value=fill_value ) -def _unstack_extension_series(series, level, fill_value): +def _unstack_extension_series(series: Series, level, fill_value) -> DataFrame: """ Unstack an ExtensionArray-backed Series. @@ -534,14 +538,14 @@ def _unstack_extension_series(series, level, fill_value): return result -def stack(frame, level=-1, dropna=True): +def stack(frame: DataFrame, level=-1, dropna: bool = True): """ Convert DataFrame to Series with multi-level Index. Columns become the second level of the resulting hierarchical index Returns ------- - stacked : Series + stacked : Series or DataFrame """ def factorize(index): @@ -676,8 +680,10 @@ def _stack_multi_column_index(columns: MultiIndex) -> MultiIndex: ) -def _stack_multi_columns(frame, level_num=-1, dropna=True): - def _convert_level_number(level_num: int, columns): +def _stack_multi_columns( + frame: DataFrame, level_num: int = -1, dropna: bool = True +) -> DataFrame: + def _convert_level_number(level_num: int, columns: Index): """ Logic for converting the level number to something we can safely pass to swaplevel. @@ -690,32 +696,36 @@ def _convert_level_number(level_num: int, columns): return level_num - this = frame.copy() + this = frame.copy(deep=False) + mi_cols = this.columns # cast(MultiIndex, this.columns) + assert isinstance(mi_cols, MultiIndex) # caller is responsible # this makes life much simpler - if level_num != frame.columns.nlevels - 1: + if level_num != mi_cols.nlevels - 1: # roll levels to put selected level at end - roll_columns = this.columns - for i in range(level_num, frame.columns.nlevels - 1): + roll_columns = mi_cols + for i in range(level_num, mi_cols.nlevels - 1): # Need to check if the ints conflict with level names lev1 = _convert_level_number(i, roll_columns) lev2 = _convert_level_number(i + 1, roll_columns) roll_columns = roll_columns.swaplevel(lev1, lev2) - this.columns = roll_columns + this.columns = mi_cols = roll_columns - if not this.columns._is_lexsorted(): + if not mi_cols._is_lexsorted(): # Workaround the edge case where 0 is one of the column names, # which interferes with trying to sort based on the first # level - level_to_sort = _convert_level_number(0, this.columns) + level_to_sort = _convert_level_number(0, mi_cols) this = this.sort_index(level=level_to_sort, axis=1) + mi_cols = this.columns - new_columns = _stack_multi_column_index(this.columns) + mi_cols = cast(MultiIndex, mi_cols) + new_columns = _stack_multi_column_index(mi_cols) # time to ravel the values new_data = {} - level_vals = this.columns.levels[-1] - level_codes = sorted(set(this.columns.codes[-1])) + level_vals = mi_cols.levels[-1] + level_codes = sorted(set(mi_cols.codes[-1])) level_vals_nan = level_vals.insert(len(level_vals), None) level_vals_used = np.take(level_vals_nan, level_codes)