|
6 | 6 | from collections import defaultdict
|
7 | 7 | from functools import reduce
|
8 | 8 | from itertools import zip_longest
|
9 |
| -from typing import Callable, Pattern |
| 9 | +from typing import Any, Callable, Pattern |
10 | 10 |
|
11 | 11 | import numpy as np
|
12 | 12 | import pandas as pd
|
13 | 13 | import pandas_flavor as pf
|
14 |
| -from pandas.api.types import is_extension_array_dtype, is_list_like |
| 14 | +from pandas.api.types import is_extension_array_dtype, is_scalar |
15 | 15 | from pandas.core.dtypes.concat import concat_compat
|
16 | 16 |
|
17 | 17 | from janitor.functions.select import (
|
|
25 | 25 | @pf.register_dataframe_method
|
26 | 26 | def pivot_longer(
|
27 | 27 | df: pd.DataFrame,
|
28 |
| - index: list | tuple | str | Pattern = None, |
29 |
| - column_names: list | tuple | str | Pattern = None, |
| 28 | + index: Any = None, |
| 29 | + column_names: Any = None, |
30 | 30 | names_to: list | tuple | str = None,
|
31 | 31 | values_to: str = "value",
|
32 | 32 | column_level: int | str = None,
|
@@ -919,8 +919,8 @@ def _data_checks_pivot_longer(
|
919 | 919 |
|
920 | 920 | def _computations_pivot_longer(
|
921 | 921 | df: pd.DataFrame,
|
922 |
| - index: list | tuple | str | Pattern | None, |
923 |
| - column_names: list | tuple | str | Pattern | None, |
| 922 | + index: Any, |
| 923 | + column_names: Any, |
924 | 924 | names_to: list | tuple | str | None,
|
925 | 925 | values_to: str,
|
926 | 926 | column_level: int | str,
|
@@ -1865,7 +1865,7 @@ def _names_transform(
|
1865 | 1865 | @pf.register_dataframe_method
|
1866 | 1866 | def pivot_wider(
|
1867 | 1867 | df: pd.DataFrame,
|
1868 |
| - index: list | str = None, |
| 1868 | + index: Any = None, |
1869 | 1869 | names_from: list | str = None,
|
1870 | 1870 | values_from: list | str = None,
|
1871 | 1871 | flatten_levels: bool = True,
|
@@ -2061,7 +2061,7 @@ def pivot_wider(
|
2061 | 2061 |
|
2062 | 2062 | def _computations_pivot_wider(
|
2063 | 2063 | df: pd.DataFrame,
|
2064 |
| - index: list | str | None, |
| 2064 | + index: Any, |
2065 | 2065 | names_from: list | str | None,
|
2066 | 2066 | values_from: list | str | None,
|
2067 | 2067 | flatten_levels: bool,
|
@@ -2217,71 +2217,32 @@ def _data_checks_pivot_wider(
|
2217 | 2217 | checking happens.
|
2218 | 2218 | """
|
2219 | 2219 |
|
2220 |
| - is_multi_index = isinstance(df.columns, pd.MultiIndex) |
2221 |
| - if index is not None: |
2222 |
| - if is_multi_index: |
2223 |
| - if not isinstance(index, list): |
2224 |
| - raise TypeError( |
2225 |
| - "For a MultiIndex column, pass a list of tuples " |
2226 |
| - "to the index argument." |
2227 |
| - ) |
2228 |
| - index = _check_tuples_multiindex(df.columns, index, "index") |
2229 |
| - else: |
2230 |
| - if is_list_like(index): |
2231 |
| - index = list(index) |
2232 |
| - index = get_index_labels(index, df, axis="columns") |
2233 |
| - if not is_list_like(index): |
2234 |
| - index = [index] |
2235 |
| - else: |
2236 |
| - index = list(index) |
2237 |
| - |
2238 | 2220 | if names_from is None:
|
2239 | 2221 | raise ValueError(
|
2240 | 2222 | "pivot_wider() is missing 1 required argument: 'names_from'"
|
2241 | 2223 | )
|
| 2224 | + names_from = get_index_labels([names_from], df, axis="columns") |
2242 | 2225 |
|
2243 |
| - if is_multi_index: |
2244 |
| - if not isinstance(names_from, list): |
2245 |
| - raise TypeError( |
2246 |
| - "For a MultiIndex column, pass a list of tuples " |
2247 |
| - "to the names_from argument." |
2248 |
| - ) |
2249 |
| - names_from = _check_tuples_multiindex( |
2250 |
| - df.columns, names_from, "names_from" |
2251 |
| - ) |
| 2226 | + if values_from is None: |
| 2227 | + values_from_ = df.columns.difference(names_from) |
2252 | 2228 | else:
|
2253 |
| - if is_list_like(names_from): |
2254 |
| - names_from = list(names_from) |
2255 |
| - names_from = get_index_labels(names_from, df, axis="columns") |
2256 |
| - if not is_list_like(names_from): |
2257 |
| - names_from = [names_from] |
2258 |
| - else: |
2259 |
| - names_from = list(names_from) |
2260 |
| - |
2261 |
| - if values_from is not None: |
2262 |
| - if is_multi_index: |
2263 |
| - if not isinstance(values_from, list): |
2264 |
| - raise TypeError( |
2265 |
| - "For a MultiIndex column, pass a list of tuples " |
2266 |
| - "to the values_from argument." |
2267 |
| - ) |
2268 |
| - out = _check_tuples_multiindex( |
2269 |
| - df.columns, values_from, "values_from" |
2270 |
| - ) |
2271 |
| - else: |
2272 |
| - if is_list_like(values_from): |
2273 |
| - values_from = list(values_from) |
2274 |
| - out = get_index_labels(values_from, df, axis="columns") |
2275 |
| - if not is_list_like(out): |
2276 |
| - out = [out] |
2277 |
| - else: |
2278 |
| - out = list(out) |
2279 |
| - # hack to align with pd.pivot |
2280 |
| - if values_from == out[0]: |
2281 |
| - values_from = out[0] |
2282 |
| - else: |
2283 |
| - values_from = out |
| 2229 | + values_from_ = get_index_labels([values_from], df, axis="columns") |
2284 | 2230 |
|
| 2231 | + if index is None: |
| 2232 | + index = df.columns.difference(names_from).difference(values_from) |
| 2233 | + if index.empty: |
| 2234 | + index = None |
| 2235 | + else: |
| 2236 | + index = list(index) |
| 2237 | + else: |
| 2238 | + index = get_index_labels([index], df, axis="columns") |
| 2239 | + index = list(index) |
| 2240 | + names_from = list(names_from) |
| 2241 | + if is_scalar(values_from) and (values_from is not None): |
| 2242 | + if values_from == values_from_[0]: |
| 2243 | + pass |
| 2244 | + else: |
| 2245 | + values_from = list(values_from_) |
2285 | 2246 | check("flatten_levels", flatten_levels, [bool])
|
2286 | 2247 |
|
2287 | 2248 | if names_sep is not None:
|
@@ -2354,30 +2315,6 @@ def _expand(indexer, retain_categories):
|
2354 | 2315 | return indexer
|
2355 | 2316 |
|
2356 | 2317 |
|
2357 |
| -def _check_tuples_multiindex(indexer, args, param): |
2358 |
| - """ |
2359 |
| - Check entries for tuples, |
2360 |
| - if indexer is a MultiIndex. |
2361 |
| -
|
2362 |
| - Returns a list of tuples. |
2363 |
| - """ |
2364 |
| - all_tuples = (isinstance(arg, tuple) for arg in args) |
2365 |
| - if not all(all_tuples): |
2366 |
| - raise TypeError( |
2367 |
| - f"{param} must be a list of tuples " |
2368 |
| - "when the columns are a MultiIndex." |
2369 |
| - ) |
2370 |
| - |
2371 |
| - not_found = set(args).difference(indexer) |
2372 |
| - if any(not_found): |
2373 |
| - raise KeyError( |
2374 |
| - f"Tuples {*not_found,} in the {param} " |
2375 |
| - "argument do not exist in the dataframe's columns." |
2376 |
| - ) |
2377 |
| - |
2378 |
| - return args |
2379 |
| - |
2380 |
| - |
2381 | 2318 | def pivot_wider_spec(
|
2382 | 2319 | df: pd.DataFrame,
|
2383 | 2320 | spec: pd.DataFrame,
|
|
0 commit comments