Skip to content

Commit 1f48d3d

Browse files
TYP: Add annotation for df.pivot (#32197)
1 parent b9e0935 commit 1f48d3d

File tree

2 files changed

+27
-13
lines changed

2 files changed

+27
-13
lines changed

pandas/core/indexes/multi.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def _verify_integrity(
387387
return new_codes
388388

389389
@classmethod
390-
def from_arrays(cls, arrays, sortorder=None, names=lib.no_default):
390+
def from_arrays(cls, arrays, sortorder=None, names=lib.no_default) -> "MultiIndex":
391391
"""
392392
Convert arrays to MultiIndex.
393393

pandas/core/reshape/pivot.py

+26-12
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
1-
from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Union
1+
from typing import (
2+
TYPE_CHECKING,
3+
Callable,
4+
Dict,
5+
List,
6+
Optional,
7+
Sequence,
8+
Tuple,
9+
Union,
10+
cast,
11+
)
212

313
import numpy as np
414

15+
from pandas._typing import Label
516
from pandas.util._decorators import Appender, Substitution
617

718
from pandas.core.dtypes.cast import maybe_downcast_to_dtype
@@ -424,37 +435,40 @@ def _convert_by(by):
424435

425436
@Substitution("\ndata : DataFrame")
426437
@Appender(_shared_docs["pivot"], indents=1)
427-
def pivot(data: "DataFrame", index=None, columns=None, values=None) -> "DataFrame":
438+
def pivot(
439+
data: "DataFrame",
440+
index: Optional[Union[Label, Sequence[Label]]] = None,
441+
columns: Optional[Union[Label, Sequence[Label]]] = None,
442+
values: Optional[Union[Label, Sequence[Label]]] = None,
443+
) -> "DataFrame":
428444
if columns is None:
429445
raise TypeError("pivot() missing 1 required argument: 'columns'")
430-
columns = columns if is_list_like(columns) else [columns]
446+
447+
columns = com.convert_to_list_like(columns)
431448

432449
if values is None:
433-
cols: List[str] = []
434-
if index is None:
435-
pass
436-
elif is_list_like(index):
437-
cols = list(index)
450+
if index is not None:
451+
cols = com.convert_to_list_like(index)
438452
else:
439-
cols = [index]
453+
cols = []
440454
cols.extend(columns)
441455

442456
append = index is None
443457
indexed = data.set_index(cols, append=append)
444458
else:
445459
if index is None:
446460
index = [Series(data.index, name=data.index.name)]
447-
elif is_list_like(index):
448-
index = [data[idx] for idx in index]
449461
else:
450-
index = [data[index]]
462+
index = com.convert_to_list_like(index)
463+
index = [data[idx] for idx in index]
451464

452465
data_columns = [data[col] for col in columns]
453466
index.extend(data_columns)
454467
index = MultiIndex.from_arrays(index)
455468

456469
if is_list_like(values) and not isinstance(values, tuple):
457470
# Exclude tuple because it is seen as a single column name
471+
values = cast(Sequence[Label], values)
458472
indexed = data._constructor(
459473
data[values]._values, index=index, columns=values
460474
)

0 commit comments

Comments
 (0)