diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index f80b65bd6..60ebbd06d 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -951,9 +951,9 @@ class DataFrame(NDFrame, OpsMixin): ) -> _DataFrameGroupByNonScalar: ... def pivot( self, - index=..., - columns=..., - values=..., + index: IndexLabel = ..., + columns: IndexLabel = ..., + values: IndexLabel = ..., ) -> DataFrame: ... def pivot_table( self, diff --git a/pandas-stubs/core/reshape/pivot.pyi b/pandas-stubs/core/reshape/pivot.pyi index dc29d5c61..bd32b8abe 100644 --- a/pandas-stubs/core/reshape/pivot.pyi +++ b/pandas-stubs/core/reshape/pivot.pyi @@ -27,9 +27,9 @@ def pivot_table( ) -> DataFrame: ... def pivot( data: DataFrame, - index: str | None = ..., - columns: str | None = ..., - values: IndexLabel | None = ..., + index: IndexLabel = ..., + columns: IndexLabel = ..., + values: IndexLabel = ..., ) -> DataFrame: ... def crosstab( index: Sequence | Series, diff --git a/tests/test_frame.py b/tests/test_frame.py index 8aed66021..f8074ba90 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -543,9 +543,26 @@ def test_types_pivot() -> None: "col4": [100, 102, 500, 600], } ) - df.pivot(index="col1", columns="col3", values="col2") - df.pivot(index="col1", columns="col3") - df.pivot(index="col1", columns="col3", values=["col2", "col4"]) + check( + assert_type( + df.pivot(index="col1", columns="col3", values="col2"), pd.DataFrame + ), + pd.DataFrame, + ) + check( + assert_type(df.pivot(index="col1", columns="col3"), pd.DataFrame), pd.DataFrame + ) + check( + assert_type( + df.pivot(index="col1", columns="col3", values=["col2", "col4"]), + pd.DataFrame, + ), + pd.DataFrame, + ) + check(assert_type(df.pivot(columns="col3"), pd.DataFrame), pd.DataFrame) + check( + assert_type(df.pivot(columns="col3", values="col2"), pd.DataFrame), pd.DataFrame + ) def test_types_groupby() -> None: