Skip to content

Commit 54b15c3

Browse files
authored
remove HashableT in frame.pyi where possible (pandas-dev#1104)
* remove HashableT in frame.pyi where possible * fix to_records, update pyright version * fix up Hashable refs and add tests
1 parent 583d198 commit 54b15c3

File tree

3 files changed

+87
-24
lines changed

3 files changed

+87
-24
lines changed

pandas-stubs/core/frame.pyi

+50-23
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ from pandas._typing import (
8888
HashableT,
8989
HashableT1,
9090
HashableT2,
91-
HashableT3,
9291
IgnoreRaise,
9392
IndexingInt,
9493
IndexLabel,
@@ -175,13 +174,13 @@ class _LocIndexerFrame(_LocIndexer, Generic[_T]):
175174
@overload
176175
def __getitem__(self, idx: Scalar) -> Series | _T: ...
177176
@overload
178-
def __getitem__(
177+
def __getitem__( # type: ignore[overload-overlap]
179178
self,
180179
idx: (
181180
IndexType
182181
| MaskType
183-
| Callable[[DataFrame], IndexType | MaskType | list[HashableT]]
184-
| list[HashableT]
182+
| Callable[[DataFrame], IndexType | MaskType | Sequence[Hashable]]
183+
| list[Hashable]
185184
| tuple[
186185
IndexType
187186
| MaskType
@@ -236,7 +235,7 @@ class _LocIndexerFrame(_LocIndexer, Generic[_T]):
236235
@overload
237236
def __setitem__(
238237
self,
239-
idx: tuple[_IndexSliceTuple, HashableT],
238+
idx: tuple[_IndexSliceTuple, Hashable],
240239
value: Scalar | NAType | NaTType | ArrayLike | Series | list | None,
241240
) -> None: ...
242241

@@ -438,6 +437,24 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
438437
_str | npt.DTypeLike | Mapping[HashableT2, npt.DTypeLike] | None
439438
) = ...,
440439
) -> np.recarray: ...
440+
@overload
441+
def to_stata(
442+
self,
443+
path: FilePath | WriteBuffer[bytes],
444+
*,
445+
convert_dates: dict[HashableT1, StataDateFormat] | None = ...,
446+
write_index: _bool = ...,
447+
byteorder: Literal["<", ">", "little", "big"] | None = ...,
448+
time_stamp: dt.datetime | None = ...,
449+
data_label: _str | None = ...,
450+
variable_labels: dict[HashableT2, str] | None = ...,
451+
version: Literal[117, 118, 119],
452+
convert_strl: SequenceNotStr[Hashable] | None = ...,
453+
compression: CompressionOptions = ...,
454+
storage_options: StorageOptions = ...,
455+
value_labels: dict[Hashable, dict[float, str]] | None = ...,
456+
) -> None: ...
457+
@overload
441458
def to_stata(
442459
self,
443460
path: FilePath | WriteBuffer[bytes],
@@ -449,7 +466,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
449466
data_label: _str | None = ...,
450467
variable_labels: dict[HashableT2, str] | None = ...,
451468
version: Literal[114, 117, 118, 119] | None = ...,
452-
convert_strl: list[HashableT3] | None = ...,
469+
convert_strl: None = ...,
453470
compression: CompressionOptions = ...,
454471
storage_options: StorageOptions = ...,
455472
value_labels: dict[Hashable, dict[float, str]] | None = ...,
@@ -462,7 +479,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
462479
engine: ParquetEngine = ...,
463480
compression: Literal["snappy", "gzip", "brotli", "lz4", "zstd"] | None = ...,
464481
index: bool | None = ...,
465-
partition_cols: list[HashableT] | None = ...,
482+
partition_cols: Sequence[Hashable] | None = ...,
466483
storage_options: StorageOptions = ...,
467484
**kwargs: Any,
468485
) -> None: ...
@@ -473,7 +490,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
473490
engine: ParquetEngine = ...,
474491
compression: Literal["snappy", "gzip", "brotli", "lz4", "zstd"] | None = ...,
475492
index: bool | None = ...,
476-
partition_cols: list[HashableT] | None = ...,
493+
partition_cols: Sequence[Hashable] | None = ...,
477494
storage_options: StorageOptions = ...,
478495
**kwargs: Any,
479496
) -> bytes: ...
@@ -499,7 +516,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
499516
def to_html(
500517
self,
501518
buf: FilePath | WriteBuffer[str],
502-
columns: list[HashableT] | Index | Series | None = ...,
519+
columns: SequenceNotStr[Hashable] | Index | Series | None = ...,
503520
col_space: ColspaceArgType | None = ...,
504521
header: _bool = ...,
505522
index: _bool = ...,
@@ -546,7 +563,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
546563
def to_html(
547564
self,
548565
buf: None = ...,
549-
columns: Sequence[HashableT] | None = ...,
566+
columns: Sequence[Hashable] | None = ...,
550567
col_space: ColspaceArgType | None = ...,
551568
header: _bool = ...,
552569
index: _bool = ...,
@@ -597,8 +614,8 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
597614
root_name: str = ...,
598615
row_name: str = ...,
599616
na_rep: str | None = ...,
600-
attr_cols: list[HashableT1] | None = ...,
601-
elem_cols: list[HashableT2] | None = ...,
617+
attr_cols: SequenceNotStr[Hashable] | None = ...,
618+
elem_cols: SequenceNotStr[Hashable] | None = ...,
602619
namespaces: dict[str | None, str] | None = ...,
603620
prefix: str | None = ...,
604621
encoding: str = ...,
@@ -617,8 +634,8 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
617634
root_name: str | None = ...,
618635
row_name: str | None = ...,
619636
na_rep: str | None = ...,
620-
attr_cols: list[HashableT1] | None = ...,
621-
elem_cols: list[HashableT2] | None = ...,
637+
attr_cols: list[Hashable] | None = ...,
638+
elem_cols: list[Hashable] | None = ...,
622639
namespaces: dict[str | None, str] | None = ...,
623640
prefix: str | None = ...,
624641
encoding: str = ...,
@@ -846,7 +863,12 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
846863
def set_index(
847864
self,
848865
keys: (
849-
Label | Series | Index | np.ndarray | Iterator[HashableT] | list[HashableT]
866+
Label
867+
| Series
868+
| Index
869+
| np.ndarray
870+
| Iterator[Hashable]
871+
| Sequence[Hashable]
850872
),
851873
*,
852874
drop: _bool = ...,
@@ -858,7 +880,12 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
858880
def set_index(
859881
self,
860882
keys: (
861-
Label | Series | Index | np.ndarray | Iterator[HashableT] | list[HashableT]
883+
Label
884+
| Series
885+
| Index
886+
| np.ndarray
887+
| Iterator[Hashable]
888+
| Sequence[Hashable]
862889
),
863890
*,
864891
drop: _bool = ...,
@@ -876,7 +903,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
876903
col_fill: Hashable = ...,
877904
inplace: Literal[True],
878905
allow_duplicates: _bool = ...,
879-
names: Hashable | list[HashableT] = ...,
906+
names: Hashable | Sequence[Hashable] = ...,
880907
) -> None: ...
881908
@overload
882909
def reset_index(
@@ -888,7 +915,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
888915
drop: _bool = ...,
889916
inplace: Literal[False] = ...,
890917
allow_duplicates: _bool = ...,
891-
names: Hashable | list[HashableT] = ...,
918+
names: Hashable | Sequence[Hashable] = ...,
892919
) -> Self: ...
893920
@overload
894921
def reset_index(
@@ -900,7 +927,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
900927
col_level: int | _str = ...,
901928
col_fill: Hashable = ...,
902929
allow_duplicates: _bool = ...,
903-
names: Hashable | list[HashableT] = ...,
930+
names: Hashable | Sequence[Hashable] = ...,
904931
) -> Self | None: ...
905932
def isna(self) -> Self: ...
906933
def isnull(self) -> Self: ...
@@ -1681,7 +1708,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
16811708
def columns(self) -> Index[str]: ...
16821709
@columns.setter # setter needs to be right next to getter; otherwise mypy complains
16831710
def columns(
1684-
self, cols: AnyArrayLike | list[HashableT] | tuple[HashableT, ...]
1711+
self, cols: AnyArrayLike | SequenceNotStr[Hashable] | tuple[Hashable, ...]
16851712
) -> None: ...
16861713
@property
16871714
def dtypes(self) -> Series: ...
@@ -2359,8 +2386,8 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
23592386
def to_string(
23602387
self,
23612388
buf: FilePath | WriteBuffer[str],
2362-
columns: list[HashableT1] | Index | Series | None = ...,
2363-
col_space: int | list[int] | dict[HashableT2, int] | None = ...,
2389+
columns: SequenceNotStr[Hashable] | Index | Series | None = ...,
2390+
col_space: int | list[int] | dict[HashableT, int] | None = ...,
23642391
header: _bool | list[_str] | tuple[str, ...] = ...,
23652392
index: _bool = ...,
23662393
na_rep: _str = ...,
@@ -2382,7 +2409,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
23822409
def to_string(
23832410
self,
23842411
buf: None = ...,
2385-
columns: list[HashableT] | Index | Series | None = ...,
2412+
columns: Sequence[Hashable] | Index | Series | None = ...,
23862413
col_space: int | list[int] | dict[Hashable, int] | None = ...,
23872414
header: _bool | Sequence[_str] = ...,
23882415
index: _bool = ...,

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ mypy = "1.14.1"
3838
pandas = "2.2.3"
3939
pyarrow = ">=10.0.1"
4040
pytest = ">=7.1.2"
41-
pyright = ">= 1.1.391"
41+
pyright = ">= 1.1.393"
4242
poethepoet = ">=0.16.5"
4343
loguru = ">=0.6.0"
4444
typing-extensions = ">=4.4.0"

tests/test_frame.py

+36
Original file line numberDiff line numberDiff line change
@@ -3091,6 +3091,11 @@ def test_to_records() -> None:
30913091
),
30923092
np.recarray,
30933093
)
3094+
dtypes = {"col1": np.int8, "col2": np.int16}
3095+
check(
3096+
assert_type(DF.to_records(False, dtypes), np.recarray),
3097+
np.recarray,
3098+
)
30943099

30953100

30963101
def test_to_dict() -> None:
@@ -3815,6 +3820,37 @@ def _constructor(self) -> type[MyClass]:
38153820
check(assert_type(df[["a", "b"]], MyClass), MyClass)
38163821

38173822

3823+
def test_hashable_args() -> None:
3824+
# GH 1104
3825+
df = pd.DataFrame([["abc"]], columns=["test"], index=["ind"])
3826+
test = ["test"]
3827+
3828+
with ensure_clean() as path:
3829+
3830+
df.to_stata(path, version=117, convert_strl=test)
3831+
df.to_stata(path, version=117, convert_strl=["test"])
3832+
3833+
df.to_html(path, columns=test)
3834+
df.to_html(path, columns=["test"])
3835+
3836+
df.to_xml(path, attr_cols=test)
3837+
df.to_xml(path, attr_cols=["test"])
3838+
3839+
df.to_xml(path, elem_cols=test)
3840+
df.to_xml(path, elem_cols=["test"])
3841+
3842+
# Next lines should work, but it is a mypy bug
3843+
# https://github.com/python/mypy/issues/3004
3844+
# pyright accepts this, so we only type check for pyright,
3845+
# and also test the code with pytest
3846+
df.columns = test # type: ignore[assignment]
3847+
df.columns = ["test"] # type: ignore[assignment]
3848+
3849+
testDict = {"test": 1}
3850+
df.to_string("test", col_space=testDict)
3851+
df.to_string("test", col_space={"test": 1})
3852+
3853+
38183854
# GH 906
38193855
@pd.api.extensions.register_dataframe_accessor("geo")
38203856
class GeoAccessor: ...

0 commit comments

Comments
 (0)