Skip to content

Commit e09a193

Browse files
TYP: Narrow down types of arguments (DataFrame) (#52752)
* Specify method in reindex for class dataframe Specify parser in to_xml of class dataframe Update doc string to_orc in class dataframe Update doc string to_orc in class dataframe Specify engine in to_parquet in class dataframe * undo changes and adding None as an optional argument type for validate argument of join and merge method Change byteorder argument typing for to_stata method to literal, added definition in pandas/_typing.py Change if_exists argument typing for to_gbq method to literal, added definition in pandas/_typing.py Change orient argument typing for from_dict method to literal, added definition in pandas/_typing.py Change how argument typing for to_timestamp method to literal, added definition in pandas/_typing.py Change validate argument typing for merge and join methods to literal, added definition in pandas/_typing.py Change na_action arguments typing for applymap method to literal, added definition in pandas/_typing.py Change join and errors arguments typing for update method to literal, added definition in pandas/_typing.py Change keep argument typing for nlargest and nsmallest to literal, added definition in pandas/_typing.py Specify the kind and na_position more precisely in sort_values, reusing type definitions in pandas/_typing.py * removing none from literal and adding it to the argument of applymap * adding reindex literal to super class NDFrame as it violates the Liskov substitution principle otherwise * adding reindex literal to super class NDFrame as it violates the Liskov substitution principle otherwise * adding literal to missing.py * ignore type for orient in from_dict method of frame due to mypy error * pulling main and resolving merge conflict --------- Co-authored-by: Patrick Schleiter <[email protected]>
1 parent dd2f0d2 commit e09a193

File tree

5 files changed

+84
-23
lines changed

5 files changed

+84
-23
lines changed

pandas/_typing.py

+43
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@
132132
]
133133
Timezone = Union[str, tzinfo]
134134

135+
ToTimestampHow = Literal["s", "e", "start", "end"]
136+
135137
# NDFrameT is stricter and ensures that the same subclass of NDFrame always is
136138
# used. E.g. `def func(a: NDFrameT) -> NDFrameT: ...` means that if a
137139
# Series is passed into a function, a Series is always returned and if a DataFrame is
@@ -361,6 +363,9 @@ def closed(self) -> bool:
361363
SortKind = Literal["quicksort", "mergesort", "heapsort", "stable"]
362364
NaPosition = Literal["first", "last"]
363365

366+
# Arguments for nsmallest and nlargest
367+
NsmallestNlargestKeep = Literal["first", "last", "all"]
368+
364369
# quantile interpolation
365370
QuantileInterpolation = Literal["linear", "lower", "higher", "midpoint", "nearest"]
366371

@@ -372,9 +377,32 @@ def closed(self) -> bool:
372377

373378
# merge
374379
MergeHow = Literal["left", "right", "inner", "outer", "cross"]
380+
MergeValidate = Literal[
381+
"one_to_one",
382+
"1:1",
383+
"one_to_many",
384+
"1:m",
385+
"many_to_one",
386+
"m:1",
387+
"many_to_many",
388+
"m:m",
389+
]
375390

376391
# join
377392
JoinHow = Literal["left", "right", "inner", "outer"]
393+
JoinValidate = Literal[
394+
"one_to_one",
395+
"1:1",
396+
"one_to_many",
397+
"1:m",
398+
"many_to_one",
399+
"m:1",
400+
"many_to_many",
401+
"m:m",
402+
]
403+
404+
# reindex
405+
ReindexMethod = Union[FillnaOptions, Literal["nearest"]]
378406

379407
MatplotlibColor = Union[str, Sequence[float]]
380408
TimeGrouperOrigin = Union[
@@ -400,3 +428,18 @@ def closed(self) -> bool:
400428
"backslashreplace",
401429
"namereplace",
402430
]
431+
432+
# update
433+
UpdateJoin = Literal["left"]
434+
435+
# applymap
436+
NaAction = Literal["ignore"]
437+
438+
# from_dict
439+
FromDictOrient = Literal["columns", "index", "tight"]
440+
441+
# to_gbq
442+
ToGbqIfexist = Literal["fail", "replace", "append"]
443+
444+
# to_stata
445+
ToStataByteorder = Literal[">", "<", "little", "big"]

pandas/core/frame.py

+35-20
Original file line numberDiff line numberDiff line change
@@ -219,23 +219,34 @@
219219
FloatFormatType,
220220
FormattersType,
221221
Frequency,
222+
FromDictOrient,
222223
IgnoreRaise,
223224
IndexKeyFunc,
224225
IndexLabel,
226+
JoinValidate,
225227
Level,
226228
MergeHow,
229+
MergeValidate,
230+
NaAction,
227231
NaPosition,
232+
NsmallestNlargestKeep,
228233
PythonFuncType,
229234
QuantileInterpolation,
230235
ReadBuffer,
236+
ReindexMethod,
231237
Renamer,
232238
Scalar,
233239
Self,
234240
SortKind,
235241
StorageOptions,
236242
Suffixes,
243+
ToGbqIfexist,
244+
ToStataByteorder,
245+
ToTimestampHow,
246+
UpdateJoin,
237247
ValueKeyFunc,
238248
WriteBuffer,
249+
XMLParsers,
239250
npt,
240251
)
241252

@@ -1637,7 +1648,7 @@ def __rmatmul__(self, other) -> DataFrame:
16371648
def from_dict(
16381649
cls,
16391650
data: dict,
1640-
orient: str = "columns",
1651+
orient: FromDictOrient = "columns",
16411652
dtype: Dtype | None = None,
16421653
columns: Axes | None = None,
16431654
) -> DataFrame:
@@ -1724,7 +1735,7 @@ def from_dict(
17241735
c 2 4
17251736
"""
17261737
index = None
1727-
orient = orient.lower()
1738+
orient = orient.lower() # type: ignore[assignment]
17281739
if orient == "index":
17291740
if len(data) > 0:
17301741
# TODO speed up Series case
@@ -1981,7 +1992,7 @@ def to_gbq(
19811992
project_id: str | None = None,
19821993
chunksize: int | None = None,
19831994
reauth: bool = False,
1984-
if_exists: str = "fail",
1995+
if_exists: ToGbqIfexist = "fail",
19851996
auth_local_webserver: bool = True,
19861997
table_schema: list[dict[str, str]] | None = None,
19871998
location: str | None = None,
@@ -2535,7 +2546,7 @@ def to_stata(
25352546
*,
25362547
convert_dates: dict[Hashable, str] | None = None,
25372548
write_index: bool = True,
2538-
byteorder: str | None = None,
2549+
byteorder: ToStataByteorder | None = None,
25392550
time_stamp: datetime.datetime | None = None,
25402551
data_label: str | None = None,
25412552
variable_labels: dict[Hashable, str] | None = None,
@@ -2763,7 +2774,7 @@ def to_markdown(
27632774
def to_parquet(
27642775
self,
27652776
path: None = ...,
2766-
engine: str = ...,
2777+
engine: Literal["auto", "pyarrow", "fastparquet"] = ...,
27672778
compression: str | None = ...,
27682779
index: bool | None = ...,
27692780
partition_cols: list[str] | None = ...,
@@ -2776,7 +2787,7 @@ def to_parquet(
27762787
def to_parquet(
27772788
self,
27782789
path: FilePath | WriteBuffer[bytes],
2779-
engine: str = ...,
2790+
engine: Literal["auto", "pyarrow", "fastparquet"] = ...,
27802791
compression: str | None = ...,
27812792
index: bool | None = ...,
27822793
partition_cols: list[str] | None = ...,
@@ -2789,7 +2800,7 @@ def to_parquet(
27892800
def to_parquet(
27902801
self,
27912802
path: FilePath | WriteBuffer[bytes] | None = None,
2792-
engine: str = "auto",
2803+
engine: Literal["auto", "pyarrow", "fastparquet"] = "auto",
27932804
compression: str | None = "snappy",
27942805
index: bool | None = None,
27952806
partition_cols: list[str] | None = None,
@@ -2919,7 +2930,7 @@ def to_orc(
29192930
we refer to objects with a write() method, such as a file handle
29202931
(e.g. via builtin open function). If path is None,
29212932
a bytes object is returned.
2922-
engine : str, default 'pyarrow'
2933+
engine : {'pyarrow'}, default 'pyarrow'
29232934
ORC library to use. Pyarrow must be >= 7.0.0.
29242935
index : bool, optional
29252936
If ``True``, include the dataframe's index(es) in the file output.
@@ -3155,7 +3166,7 @@ def to_xml(
31553166
encoding: str = "utf-8",
31563167
xml_declaration: bool | None = True,
31573168
pretty_print: bool | None = True,
3158-
parser: str | None = "lxml",
3169+
parser: XMLParsers | None = "lxml",
31593170
stylesheet: FilePath | ReadBuffer[str] | ReadBuffer[bytes] | None = None,
31603171
compression: CompressionOptions = "infer",
31613172
storage_options: StorageOptions = None,
@@ -4988,7 +4999,7 @@ def reindex(
49884999
index=None,
49895000
columns=None,
49905001
axis: Axis | None = None,
4991-
method: str | None = None,
5002+
method: ReindexMethod | None = None,
49925003
copy: bool | None = None,
49935004
level: Level | None = None,
49945005
fill_value: Scalar | None = np.nan,
@@ -6521,8 +6532,8 @@ def sort_values(
65216532
axis: Axis = ...,
65226533
ascending=...,
65236534
inplace: Literal[False] = ...,
6524-
kind: str = ...,
6525-
na_position: str = ...,
6535+
kind: SortKind = ...,
6536+
na_position: NaPosition = ...,
65266537
ignore_index: bool = ...,
65276538
key: ValueKeyFunc = ...,
65286539
) -> DataFrame:
@@ -7077,7 +7088,9 @@ def value_counts(
70777088

70787089
return counts
70797090

7080-
def nlargest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFrame:
7091+
def nlargest(
7092+
self, n: int, columns: IndexLabel, keep: NsmallestNlargestKeep = "first"
7093+
) -> DataFrame:
70817094
"""
70827095
Return the first `n` rows ordered by `columns` in descending order.
70837096
@@ -7184,7 +7197,9 @@ def nlargest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFram
71847197
"""
71857198
return selectn.SelectNFrame(self, n=n, keep=keep, columns=columns).nlargest()
71867199

7187-
def nsmallest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFrame:
7200+
def nsmallest(
7201+
self, n: int, columns: IndexLabel, keep: NsmallestNlargestKeep = "first"
7202+
) -> DataFrame:
71887203
"""
71897204
Return the first `n` rows ordered by `columns` in ascending order.
71907205
@@ -8348,10 +8363,10 @@ def combiner(x, y):
83488363
def update(
83498364
self,
83508365
other,
8351-
join: str = "left",
8366+
join: UpdateJoin = "left",
83528367
overwrite: bool = True,
83538368
filter_func=None,
8354-
errors: str = "ignore",
8369+
errors: IgnoreRaise = "ignore",
83558370
) -> None:
83568371
"""
83578372
Modify in place using non-NA values from another DataFrame.
@@ -9857,7 +9872,7 @@ def infer(x):
98579872
return self.apply(infer).__finalize__(self, "map")
98589873

98599874
def applymap(
9860-
self, func: PythonFuncType, na_action: str | None = None, **kwargs
9875+
self, func: PythonFuncType, na_action: NaAction | None = None, **kwargs
98619876
) -> DataFrame:
98629877
"""
98639878
Apply a function to a Dataframe elementwise.
@@ -9969,7 +9984,7 @@ def join(
99699984
lsuffix: str = "",
99709985
rsuffix: str = "",
99719986
sort: bool = False,
9972-
validate: str | None = None,
9987+
validate: JoinValidate | None = None,
99739988
) -> DataFrame:
99749989
"""
99759990
Join columns of another DataFrame.
@@ -10211,7 +10226,7 @@ def merge(
1021110226
suffixes: Suffixes = ("_x", "_y"),
1021210227
copy: bool | None = None,
1021310228
indicator: str | bool = False,
10214-
validate: str | None = None,
10229+
validate: MergeValidate | None = None,
1021510230
) -> DataFrame:
1021610231
from pandas.core.reshape.merge import merge
1021710232

@@ -11506,7 +11521,7 @@ def quantile(
1150611521
def to_timestamp(
1150711522
self,
1150811523
freq: Frequency | None = None,
11509-
how: str = "start",
11524+
how: ToTimestampHow = "start",
1151011525
axis: Axis = 0,
1151111526
copy: bool | None = None,
1151211527
) -> DataFrame:

pandas/core/generic.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
NDFrameT,
7171
OpenFileErrors,
7272
RandomState,
73+
ReindexMethod,
7374
Renamer,
7475
Scalar,
7576
Self,
@@ -5154,7 +5155,7 @@ def reindex(
51545155
index=None,
51555156
columns=None,
51565157
axis: Axis | None = None,
5157-
method: str | None = None,
5158+
method: ReindexMethod | None = None,
51585159
copy: bool_t | None = None,
51595160
level: Level | None = None,
51605161
fill_value: Scalar | None = np.nan,

pandas/core/missing.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Axis,
2626
AxisInt,
2727
F,
28+
ReindexMethod,
2829
npt,
2930
)
3031
from pandas.compat._optional import import_optional_dependency
@@ -949,7 +950,7 @@ def get_fill_func(method, ndim: int = 1):
949950
return {"pad": _pad_2d, "backfill": _backfill_2d}[method]
950951

951952

952-
def clean_reindex_fill_method(method) -> str | None:
953+
def clean_reindex_fill_method(method) -> ReindexMethod | None:
953954
return clean_fill_method(method, allow_nearest=True)
954955

955956

pandas/core/series.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@
167167
NumpySorter,
168168
NumpyValueArrayLike,
169169
QuantileInterpolation,
170+
ReindexMethod,
170171
Renamer,
171172
Scalar,
172173
Self,
@@ -4718,7 +4719,7 @@ def reindex( # type: ignore[override]
47184719
index=None,
47194720
*,
47204721
axis: Axis | None = None,
4721-
method: str | None = None,
4722+
method: ReindexMethod | None = None,
47224723
copy: bool | None = None,
47234724
level: Level | None = None,
47244725
fill_value: Scalar | None = None,

0 commit comments

Comments
 (0)