diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index aa69c5d80..9528755a2 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -88,7 +88,6 @@ from pandas._typing import ( HashableT, HashableT1, HashableT2, - HashableT3, IgnoreRaise, IndexingInt, IndexLabel, @@ -175,13 +174,13 @@ class _LocIndexerFrame(_LocIndexer, Generic[_T]): @overload def __getitem__(self, idx: Scalar) -> Series | _T: ... @overload - def __getitem__( + def __getitem__( # type: ignore[overload-overlap] self, idx: ( IndexType | MaskType - | Callable[[DataFrame], IndexType | MaskType | list[HashableT]] - | list[HashableT] + | Callable[[DataFrame], IndexType | MaskType | Sequence[Hashable]] + | list[Hashable] | tuple[ IndexType | MaskType @@ -236,7 +235,7 @@ class _LocIndexerFrame(_LocIndexer, Generic[_T]): @overload def __setitem__( self, - idx: tuple[_IndexSliceTuple, HashableT], + idx: tuple[_IndexSliceTuple, Hashable], value: Scalar | NAType | NaTType | ArrayLike | Series | list | None, ) -> None: ... @@ -438,6 +437,24 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): _str | npt.DTypeLike | Mapping[HashableT2, npt.DTypeLike] | None ) = ..., ) -> np.recarray: ... + @overload + def to_stata( + self, + path: FilePath | WriteBuffer[bytes], + *, + convert_dates: dict[HashableT1, StataDateFormat] | None = ..., + write_index: _bool = ..., + byteorder: Literal["<", ">", "little", "big"] | None = ..., + time_stamp: dt.datetime | None = ..., + data_label: _str | None = ..., + variable_labels: dict[HashableT2, str] | None = ..., + version: Literal[117, 118, 119], + convert_strl: SequenceNotStr[Hashable] | None = ..., + compression: CompressionOptions = ..., + storage_options: StorageOptions = ..., + value_labels: dict[Hashable, dict[float, str]] | None = ..., + ) -> None: ... + @overload def to_stata( self, path: FilePath | WriteBuffer[bytes], @@ -449,7 +466,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): data_label: _str | None = ..., variable_labels: dict[HashableT2, str] | None = ..., version: Literal[114, 117, 118, 119] | None = ..., - convert_strl: list[HashableT3] | None = ..., + convert_strl: None = ..., compression: CompressionOptions = ..., storage_options: StorageOptions = ..., value_labels: dict[Hashable, dict[float, str]] | None = ..., @@ -462,7 +479,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): engine: ParquetEngine = ..., compression: Literal["snappy", "gzip", "brotli", "lz4", "zstd"] | None = ..., index: bool | None = ..., - partition_cols: list[HashableT] | None = ..., + partition_cols: Sequence[Hashable] | None = ..., storage_options: StorageOptions = ..., **kwargs: Any, ) -> None: ... @@ -473,7 +490,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): engine: ParquetEngine = ..., compression: Literal["snappy", "gzip", "brotli", "lz4", "zstd"] | None = ..., index: bool | None = ..., - partition_cols: list[HashableT] | None = ..., + partition_cols: Sequence[Hashable] | None = ..., storage_options: StorageOptions = ..., **kwargs: Any, ) -> bytes: ... @@ -499,7 +516,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): def to_html( self, buf: FilePath | WriteBuffer[str], - columns: list[HashableT] | Index | Series | None = ..., + columns: SequenceNotStr[Hashable] | Index | Series | None = ..., col_space: ColspaceArgType | None = ..., header: _bool = ..., index: _bool = ..., @@ -546,7 +563,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): def to_html( self, buf: None = ..., - columns: Sequence[HashableT] | None = ..., + columns: Sequence[Hashable] | None = ..., col_space: ColspaceArgType | None = ..., header: _bool = ..., index: _bool = ..., @@ -597,8 +614,8 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): root_name: str = ..., row_name: str = ..., na_rep: str | None = ..., - attr_cols: list[HashableT1] | None = ..., - elem_cols: list[HashableT2] | None = ..., + attr_cols: SequenceNotStr[Hashable] | None = ..., + elem_cols: SequenceNotStr[Hashable] | None = ..., namespaces: dict[str | None, str] | None = ..., prefix: str | None = ..., encoding: str = ..., @@ -617,8 +634,8 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): root_name: str | None = ..., row_name: str | None = ..., na_rep: str | None = ..., - attr_cols: list[HashableT1] | None = ..., - elem_cols: list[HashableT2] | None = ..., + attr_cols: list[Hashable] | None = ..., + elem_cols: list[Hashable] | None = ..., namespaces: dict[str | None, str] | None = ..., prefix: str | None = ..., encoding: str = ..., @@ -846,7 +863,12 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): def set_index( self, keys: ( - Label | Series | Index | np.ndarray | Iterator[HashableT] | list[HashableT] + Label + | Series + | Index + | np.ndarray + | Iterator[Hashable] + | Sequence[Hashable] ), *, drop: _bool = ..., @@ -858,7 +880,12 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): def set_index( self, keys: ( - Label | Series | Index | np.ndarray | Iterator[HashableT] | list[HashableT] + Label + | Series + | Index + | np.ndarray + | Iterator[Hashable] + | Sequence[Hashable] ), *, drop: _bool = ..., @@ -876,7 +903,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): col_fill: Hashable = ..., inplace: Literal[True], allow_duplicates: _bool = ..., - names: Hashable | list[HashableT] = ..., + names: Hashable | Sequence[Hashable] = ..., ) -> None: ... @overload def reset_index( @@ -888,7 +915,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): drop: _bool = ..., inplace: Literal[False] = ..., allow_duplicates: _bool = ..., - names: Hashable | list[HashableT] = ..., + names: Hashable | Sequence[Hashable] = ..., ) -> Self: ... @overload def reset_index( @@ -900,7 +927,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): col_level: int | _str = ..., col_fill: Hashable = ..., allow_duplicates: _bool = ..., - names: Hashable | list[HashableT] = ..., + names: Hashable | Sequence[Hashable] = ..., ) -> Self | None: ... def isna(self) -> Self: ... def isnull(self) -> Self: ... @@ -1681,7 +1708,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): def columns(self) -> Index[str]: ... @columns.setter # setter needs to be right next to getter; otherwise mypy complains def columns( - self, cols: AnyArrayLike | list[HashableT] | tuple[HashableT, ...] + self, cols: AnyArrayLike | SequenceNotStr[Hashable] | tuple[Hashable, ...] ) -> None: ... @property def dtypes(self) -> Series: ... @@ -2359,8 +2386,8 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): def to_string( self, buf: FilePath | WriteBuffer[str], - columns: list[HashableT1] | Index | Series | None = ..., - col_space: int | list[int] | dict[HashableT2, int] | None = ..., + columns: SequenceNotStr[Hashable] | Index | Series | None = ..., + col_space: int | list[int] | dict[HashableT, int] | None = ..., header: _bool | list[_str] | tuple[str, ...] = ..., index: _bool = ..., na_rep: _str = ..., @@ -2382,7 +2409,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): def to_string( self, buf: None = ..., - columns: list[HashableT] | Index | Series | None = ..., + columns: Sequence[Hashable] | Index | Series | None = ..., col_space: int | list[int] | dict[Hashable, int] | None = ..., header: _bool | Sequence[_str] = ..., index: _bool = ..., diff --git a/pyproject.toml b/pyproject.toml index 20ba168cb..451dafaa9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ mypy = "1.14.1" pandas = "2.2.3" pyarrow = ">=10.0.1" pytest = ">=7.1.2" -pyright = ">= 1.1.391" +pyright = ">= 1.1.393" poethepoet = ">=0.16.5" loguru = ">=0.6.0" typing-extensions = ">=4.4.0" diff --git a/tests/test_frame.py b/tests/test_frame.py index ad7bb6450..fe2b6be1b 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -3091,6 +3091,11 @@ def test_to_records() -> None: ), np.recarray, ) + dtypes = {"col1": np.int8, "col2": np.int16} + check( + assert_type(DF.to_records(False, dtypes), np.recarray), + np.recarray, + ) def test_to_dict() -> None: @@ -3815,6 +3820,37 @@ def _constructor(self) -> type[MyClass]: check(assert_type(df[["a", "b"]], MyClass), MyClass) +def test_hashable_args() -> None: + # GH 1104 + df = pd.DataFrame([["abc"]], columns=["test"], index=["ind"]) + test = ["test"] + + with ensure_clean() as path: + + df.to_stata(path, version=117, convert_strl=test) + df.to_stata(path, version=117, convert_strl=["test"]) + + df.to_html(path, columns=test) + df.to_html(path, columns=["test"]) + + df.to_xml(path, attr_cols=test) + df.to_xml(path, attr_cols=["test"]) + + df.to_xml(path, elem_cols=test) + df.to_xml(path, elem_cols=["test"]) + + # Next lines should work, but it is a mypy bug + # https://github.com/python/mypy/issues/3004 + # pyright accepts this, so we only type check for pyright, + # and also test the code with pytest + df.columns = test # type: ignore[assignment] + df.columns = ["test"] # type: ignore[assignment] + + testDict = {"test": 1} + df.to_string("test", col_space=testDict) + df.to_string("test", col_space={"test": 1}) + + # GH 906 @pd.api.extensions.register_dataframe_accessor("geo") class GeoAccessor: ...