|
1 |
| -from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Union |
| 1 | +from typing import ( |
| 2 | + TYPE_CHECKING, |
| 3 | + Callable, |
| 4 | + Dict, |
| 5 | + List, |
| 6 | + Optional, |
| 7 | + Sequence, |
| 8 | + Tuple, |
| 9 | + Union, |
| 10 | + cast, |
| 11 | +) |
2 | 12 |
|
3 | 13 | import numpy as np
|
4 | 14 |
|
| 15 | +from pandas._typing import Label |
5 | 16 | from pandas.util._decorators import Appender, Substitution
|
6 | 17 |
|
7 | 18 | from pandas.core.dtypes.cast import maybe_downcast_to_dtype
|
@@ -424,37 +435,40 @@ def _convert_by(by):
|
424 | 435 |
|
425 | 436 | @Substitution("\ndata : DataFrame")
|
426 | 437 | @Appender(_shared_docs["pivot"], indents=1)
|
427 |
| -def pivot(data: "DataFrame", index=None, columns=None, values=None) -> "DataFrame": |
| 438 | +def pivot( |
| 439 | + data: "DataFrame", |
| 440 | + index: Optional[Union[Label, Sequence[Label]]] = None, |
| 441 | + columns: Optional[Union[Label, Sequence[Label]]] = None, |
| 442 | + values: Optional[Union[Label, Sequence[Label]]] = None, |
| 443 | +) -> "DataFrame": |
428 | 444 | if columns is None:
|
429 | 445 | raise TypeError("pivot() missing 1 required argument: 'columns'")
|
430 |
| - columns = columns if is_list_like(columns) else [columns] |
| 446 | + |
| 447 | + columns = com.convert_to_list_like(columns) |
431 | 448 |
|
432 | 449 | if values is None:
|
433 |
| - cols: List[str] = [] |
434 |
| - if index is None: |
435 |
| - pass |
436 |
| - elif is_list_like(index): |
437 |
| - cols = list(index) |
| 450 | + if index is not None: |
| 451 | + cols = com.convert_to_list_like(index) |
438 | 452 | else:
|
439 |
| - cols = [index] |
| 453 | + cols = [] |
440 | 454 | cols.extend(columns)
|
441 | 455 |
|
442 | 456 | append = index is None
|
443 | 457 | indexed = data.set_index(cols, append=append)
|
444 | 458 | else:
|
445 | 459 | if index is None:
|
446 | 460 | index = [Series(data.index, name=data.index.name)]
|
447 |
| - elif is_list_like(index): |
448 |
| - index = [data[idx] for idx in index] |
449 | 461 | else:
|
450 |
| - index = [data[index]] |
| 462 | + index = com.convert_to_list_like(index) |
| 463 | + index = [data[idx] for idx in index] |
451 | 464 |
|
452 | 465 | data_columns = [data[col] for col in columns]
|
453 | 466 | index.extend(data_columns)
|
454 | 467 | index = MultiIndex.from_arrays(index)
|
455 | 468 |
|
456 | 469 | if is_list_like(values) and not isinstance(values, tuple):
|
457 | 470 | # Exclude tuple because it is seen as a single column name
|
| 471 | + values = cast(Sequence[Label], values) |
458 | 472 | indexed = data._constructor(
|
459 | 473 | data[values]._values, index=index, columns=values
|
460 | 474 | )
|
|
0 commit comments