Skip to content

TYP: Narrow down types of arguments (DataFrame) #52752

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
43 changes: 43 additions & 0 deletions pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@
]
Timezone = Union[str, tzinfo]

# to_timestamp
ToTimestampHow = Literal["s", "e", "start", "end"]

# NDFrameT is stricter and ensures that the same subclass of NDFrame always is
# used. E.g. `def func(a: NDFrameT) -> NDFrameT: ...` means that if a
# Series is passed into a function, a Series is always returned and if a DataFrame is
Expand Down Expand Up @@ -303,6 +305,9 @@ def closed(self) -> bool:
str, int, Sequence[Union[str, int]], Mapping[Hashable, Union[str, int]]
]

# Arguments for nsmallest and nlargest
NsmallestNlargestKeep = Literal["first", "last", "all"]

# Arguments for fillna()
FillnaOptions = Literal["backfill", "bfill", "ffill", "pad"]

Expand Down Expand Up @@ -372,9 +377,32 @@ def closed(self) -> bool:

# merge
MergeHow = Literal["left", "right", "inner", "outer", "cross"]
# Each validation mode has a long spelling and a short alias; pairs are
# kept together so the mapping stays obvious.
MergeValidate = Literal[
    "one_to_one", "1:1",
    "one_to_many", "1:m",
    "many_to_one", "m:1",
    "many_to_many", "m:m",
]

# join
JoinHow = Literal["left", "right", "inner", "outer"]
# Allowed values for DataFrame.join(validate=...); mirrors MergeValidate.
JoinValidate = Literal[
    "one_to_one",
    "1:1",
    "one_to_many",
    "1:m",
    "many_to_one",
    "m:1",
    "many_to_many",
    "m:m",
]

# reindex
# Fix: normalize "FillnaOptions ,Literal" spacing to PEP 8.
ReindexMethod = Union[FillnaOptions, Literal["nearest"]]

MatplotlibColor = Union[str, Sequence[float]]
TimeGrouperOrigin = Union[
Expand All @@ -390,3 +418,18 @@ def closed(self) -> bool:
]
AlignJoin = Literal["outer", "inner", "left", "right"]
DtypeBackend = Literal["pyarrow", "numpy_nullable"]

# update
UpdateJoin = Literal["left"]

# applymap
# Only the string "ignore" is a valid na_action value; callers disable the
# behavior by passing the None object (annotated at use sites as
# `NaAction | None`), not the string "None".
NaAction = Literal["ignore"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for your pr

this isn't quite right, you can't pass "None"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry for that, resolved in 7e1c7c9


# from_dict
FromDictOrient = Literal["columns", "index", "tight"]

# to_gbq
ToGbqIfexist = Literal["fail", "replace", "append"]

# to_stata
ToStataByteorder = Literal[">", "<", "little", "big"]
53 changes: 34 additions & 19 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,17 @@
FloatFormatType,
FormattersType,
Frequency,
FromDictOrient,
IgnoreRaise,
IndexKeyFunc,
IndexLabel,
JoinValidate,
Level,
MergeHow,
MergeValidate,
NaAction,
NaPosition,
NsmallestNlargestKeep,
PythonFuncType,
QuantileInterpolation,
ReadBuffer,
Expand All @@ -234,9 +239,15 @@
SortKind,
StorageOptions,
Suffixes,
ToGbqIfexist,
ToStataByteorder,
ToTimestampHow,
UpdateJoin,
ValueKeyFunc,
WriteBuffer,
npt,
ReindexMethod,
XMLParsers,
)

from pandas.core.groupby.generic import DataFrameGroupBy
Expand Down Expand Up @@ -1637,7 +1648,7 @@ def __rmatmul__(self, other) -> DataFrame:
def from_dict(
cls,
data: dict,
orient: str = "columns",
orient: FromDictOrient = "columns",
dtype: Dtype | None = None,
columns: Axes | None = None,
) -> DataFrame:
Expand Down Expand Up @@ -1981,7 +1992,7 @@ def to_gbq(
project_id: str | None = None,
chunksize: int | None = None,
reauth: bool = False,
if_exists: str = "fail",
if_exists: ToGbqIfexist = "fail",
auth_local_webserver: bool = True,
table_schema: list[dict[str, str]] | None = None,
location: str | None = None,
Expand Down Expand Up @@ -2535,7 +2546,7 @@ def to_stata(
*,
convert_dates: dict[Hashable, str] | None = None,
write_index: bool = True,
byteorder: str | None = None,
byteorder: ToStataByteorder | None = None,
time_stamp: datetime.datetime | None = None,
data_label: str | None = None,
variable_labels: dict[Hashable, str] | None = None,
Expand Down Expand Up @@ -2763,7 +2774,7 @@ def to_markdown(
def to_parquet(
self,
path: None = ...,
engine: str = ...,
engine: Literal["auto", "pyarrow", "fastparquet"] = ...,
compression: str | None = ...,
index: bool | None = ...,
partition_cols: list[str] | None = ...,
Expand All @@ -2776,7 +2787,7 @@ def to_parquet(
def to_parquet(
self,
path: FilePath | WriteBuffer[bytes],
engine: str = ...,
engine: Literal["auto", "pyarrow", "fastparquet"] = ...,
compression: str | None = ...,
index: bool | None = ...,
partition_cols: list[str] | None = ...,
Expand All @@ -2789,7 +2800,7 @@ def to_parquet(
def to_parquet(
self,
path: FilePath | WriteBuffer[bytes] | None = None,
engine: str = "auto",
engine: Literal["auto", "pyarrow", "fastparquet"] = "auto",
compression: str | None = "snappy",
index: bool | None = None,
partition_cols: list[str] | None = None,
Expand Down Expand Up @@ -2919,7 +2930,7 @@ def to_orc(
we refer to objects with a write() method, such as a file handle
(e.g. via builtin open function). If path is None,
a bytes object is returned.
engine : str, default 'pyarrow'
engine : {'pyarrow'}, default 'pyarrow'
ORC library to use. Pyarrow must be >= 7.0.0.
index : bool, optional
If ``True``, include the dataframe's index(es) in the file output.
Expand Down Expand Up @@ -3155,7 +3166,7 @@ def to_xml(
encoding: str = "utf-8",
xml_declaration: bool | None = True,
pretty_print: bool | None = True,
parser: str | None = "lxml",
parser: XMLParsers | None = "lxml",
stylesheet: FilePath | ReadBuffer[str] | ReadBuffer[bytes] | None = None,
compression: CompressionOptions = "infer",
storage_options: StorageOptions = None,
Expand Down Expand Up @@ -4988,7 +4999,7 @@ def reindex(
index=None,
columns=None,
axis: Axis | None = None,
method: str | None = None,
method: ReindexMethod | None = None,
copy: bool | None = None,
level: Level | None = None,
fill_value: Scalar | None = np.nan,
Expand Down Expand Up @@ -6521,8 +6532,8 @@ def sort_values(
axis: Axis = ...,
ascending=...,
inplace: Literal[False] = ...,
kind: str = ...,
na_position: str = ...,
kind: SortKind = ...,
na_position: NaPosition = ...,
ignore_index: bool = ...,
key: ValueKeyFunc = ...,
) -> DataFrame:
Expand Down Expand Up @@ -7077,7 +7088,9 @@ def value_counts(

return counts

def nlargest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFrame:
def nlargest(
self, n: int, columns: IndexLabel, keep: NsmallestNlargestKeep = "first"
) -> DataFrame:
"""
Return the first `n` rows ordered by `columns` in descending order.

Expand Down Expand Up @@ -7184,7 +7197,9 @@ def nlargest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFram
"""
return selectn.SelectNFrame(self, n=n, keep=keep, columns=columns).nlargest()

def nsmallest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFrame:
def nsmallest(
self, n: int, columns: IndexLabel, keep: NsmallestNlargestKeep = "first"
) -> DataFrame:
"""
Return the first `n` rows ordered by `columns` in ascending order.

Expand Down Expand Up @@ -8348,10 +8363,10 @@ def combiner(x, y):
def update(
self,
other,
join: str = "left",
join: UpdateJoin = "left",
overwrite: bool = True,
filter_func=None,
errors: str = "ignore",
errors: IgnoreRaise = "ignore",
) -> None:
"""
Modify in place using non-NA values from another DataFrame.
Expand Down Expand Up @@ -9857,7 +9872,7 @@ def infer(x):
return self.apply(infer).__finalize__(self, "map")

def applymap(
self, func: PythonFuncType, na_action: str | None = None, **kwargs
self, func: PythonFuncType, na_action: NaAction = None, **kwargs
) -> DataFrame:
"""
Apply a function to a Dataframe elementwise.
Expand Down Expand Up @@ -9969,7 +9984,7 @@ def join(
lsuffix: str = "",
rsuffix: str = "",
sort: bool = False,
validate: str | None = None,
validate: JoinValidate | None = None,
) -> DataFrame:
"""
Join columns of another DataFrame.
Expand Down Expand Up @@ -10211,7 +10226,7 @@ def merge(
suffixes: Suffixes = ("_x", "_y"),
copy: bool | None = None,
indicator: str | bool = False,
validate: str | None = None,
validate: MergeValidate | None = None,
) -> DataFrame:
from pandas.core.reshape.merge import merge

Expand Down Expand Up @@ -11506,7 +11521,7 @@ def quantile(
def to_timestamp(
self,
freq: Frequency | None = None,
how: str = "start",
how: ToTimestampHow = "start",
axis: Axis = 0,
copy: bool | None = None,
) -> DataFrame:
Expand Down