Skip to content

ENH: add dict to return type of func argument of DataFrame#apply #470

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 19, 2022
30 changes: 21 additions & 9 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1096,7 +1096,7 @@ class DataFrame(NDFrame, OpsMixin):
f: Callable[..., ListLikeExceptSeriesAndStr | Series],
axis: AxisTypeIndex = ...,
raw: _bool = ...,
result_type: Literal[None] = ...,
result_type: None = ...,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not necessarily have to be Literal[None], just None is OK, and so are others.

args=...,
**kwargs,
) -> DataFrame: ...
Expand All @@ -1106,10 +1106,22 @@ class DataFrame(NDFrame, OpsMixin):
f: Callable[..., S1],
axis: AxisTypeIndex = ...,
raw: _bool = ...,
result_type: Literal[None] = ...,
result_type: None = ...,
args=...,
**kwargs,
) -> Series[S1]: ...
# Since non-scalar type T is not supported in Series[T],
# we separate this overload from the above one
@overload
def apply(
self,
f: Callable[..., Mapping],
axis: AxisTypeIndex = ...,
raw: _bool = ...,
result_type: None = ...,
args=...,
**kwargs,
) -> Series: ...

# apply() overloads with keyword result_type, and axis does not matter
@overload
Expand All @@ -1126,7 +1138,7 @@ class DataFrame(NDFrame, OpsMixin):
@overload
def apply(
self,
f: Callable[..., ListLikeExceptSeriesAndStr | Series],
f: Callable[..., ListLikeExceptSeriesAndStr | Series | Mapping],
axis: AxisType = ...,
raw: _bool = ...,
args=...,
Expand All @@ -1137,7 +1149,7 @@ class DataFrame(NDFrame, OpsMixin):
@overload
def apply(
self,
f: Callable[..., ListLikeExceptSeriesAndStr],
f: Callable[..., ListLikeExceptSeriesAndStr | Mapping],
axis: AxisType = ...,
raw: _bool = ...,
args=...,
Expand All @@ -1148,7 +1160,7 @@ class DataFrame(NDFrame, OpsMixin):
@overload
def apply(
self,
f: Callable[..., ListLikeExceptSeriesAndStr | Series | Scalar],
f: Callable[..., ListLikeExceptSeriesAndStr | Series | Scalar | Mapping],
axis: AxisType = ...,
raw: _bool = ...,
args=...,
Expand Down Expand Up @@ -1176,7 +1188,7 @@ class DataFrame(NDFrame, OpsMixin):
self,
f: Callable[..., S1],
raw: _bool = ...,
result_type: Literal[None] = ...,
result_type: None = ...,
args=...,
*,
axis: AxisTypeColumn,
Expand All @@ -1185,9 +1197,9 @@ class DataFrame(NDFrame, OpsMixin):
@overload
def apply(
self,
f: Callable[..., ListLikeExceptSeriesAndStr],
f: Callable[..., ListLikeExceptSeriesAndStr | Mapping],
raw: _bool = ...,
result_type: Literal[None] = ...,
result_type: None = ...,
args=...,
*,
axis: AxisTypeColumn,
Expand All @@ -1198,7 +1210,7 @@ class DataFrame(NDFrame, OpsMixin):
self,
f: Callable[..., Series],
raw: _bool = ...,
result_type: Literal[None] = ...,
result_type: None = ...,
args=...,
*,
axis: AxisTypeColumn,
Expand Down
144 changes: 78 additions & 66 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,9 @@ def returns_listlike_of_2(x: pd.Series) -> tuple[int, int]:
def returns_listlike_of_3(x: pd.Series) -> tuple[int, int, int]:
return (7, 8, 9)

def returns_dict(x: pd.Series) -> dict[str, int]:
return {"col4": 7, "col5": 8}

# Misc checks
check(assert_type(df.apply(np.exp), pd.DataFrame), pd.DataFrame)
check(assert_type(df.apply(str), "pd.Series[str]"), pd.Series, str)
Expand All @@ -489,22 +492,15 @@ def gethead(s: pd.Series, y: int) -> pd.Series:

# Check various return types for default result_type (None) with default axis (0)
check(assert_type(df.apply(returns_scalar), "pd.Series[int]"), pd.Series, int)
check(
assert_type(df.apply(returns_series), pd.DataFrame),
pd.DataFrame,
)
check(
assert_type(df.apply(returns_listlike_of_3), pd.DataFrame),
pd.DataFrame,
)
check(assert_type(df.apply(returns_series), pd.DataFrame), pd.DataFrame)
check(assert_type(df.apply(returns_listlike_of_3), pd.DataFrame), pd.DataFrame)
check(assert_type(df.apply(returns_dict), pd.Series), pd.Series)

# Check various return types for result_type="expand" with default axis (0)
check(
assert_type(
# Note that technically it does not make sense to pass a result_type of "expand" to a scalar return
df.apply(returns_scalar, result_type="expand"),
"pd.Series[int]",
),
# Note that technically it does not make sense
# to pass a result_type of "expand" to a scalar return
assert_type(df.apply(returns_scalar, result_type="expand"), "pd.Series[int]"),
pd.Series,
int,
)
Expand All @@ -518,71 +514,47 @@ def gethead(s: pd.Series, y: int) -> pd.Series:
),
pd.DataFrame,
)
check(
assert_type(df.apply(returns_dict, result_type="expand"), pd.DataFrame),
pd.DataFrame,
)

# Check various return types for result_type="reduce" with default axis (0)
check(
assert_type(
# Note that technically it does not make sense to pass a result_type of "reduce" to a scalar return
df.apply(returns_scalar, result_type="reduce"),
"pd.Series[int]",
),
# Note that technically it does not make sense
# to pass a result_type of "reduce" to a scalar return
assert_type(df.apply(returns_scalar, result_type="reduce"), "pd.Series[int]"),
pd.Series,
int,
)
check(
assert_type(
# Note that technically it does not make sense to pass a result_type of "reduce" to a series return
df.apply(returns_series, result_type="reduce"),
pd.Series,
),
# Note that technically it does not make sense
# to pass a result_type of "reduce" to a series return
assert_type(df.apply(returns_series, result_type="reduce"), pd.Series),
pd.Series, # This technically returns a pd.Series[pd.Series], but typing does not support that
)
check(
assert_type(df.apply(returns_listlike_of_3, result_type="reduce"), pd.Series),
pd.Series,
)

# Check various return types for result_type="broadcast" with default axis (0)
check(
assert_type(
# Note that technically it does not make sense to pass a result_type of "broadcast" to a scalar return
df.apply(returns_scalar, result_type="broadcast"),
pd.DataFrame,
),
pd.DataFrame,
)
check(
assert_type(df.apply(returns_series, result_type="broadcast"), pd.DataFrame),
pd.DataFrame,
)
check(
assert_type(
# Can only broadcast a list-like of 2 elements, not 3, because there are 2 rows
df.apply(returns_listlike_of_2, result_type="broadcast"),
pd.DataFrame,
),
pd.DataFrame,
Comment on lines -547 to -564
Copy link
Contributor Author

@skatsuta skatsuta Dec 18, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These assertions are moved to the bottom so that test cases for result_type="broadcast" are grouped together.

assert_type(df.apply(returns_dict, result_type="reduce"), pd.Series), pd.Series
)

# Check various return types for default result_type (None) with axis=1
check(
assert_type(df.apply(returns_scalar, axis=1), "pd.Series[int]"), pd.Series, int
)
check(
assert_type(df.apply(returns_series, axis=1), pd.DataFrame),
pd.DataFrame,
)
check(
assert_type(df.apply(returns_listlike_of_3, axis=1), pd.Series),
pd.Series,
)
check(assert_type(df.apply(returns_series, axis=1), pd.DataFrame), pd.DataFrame)
check(assert_type(df.apply(returns_listlike_of_3, axis=1), pd.Series), pd.Series)
check(assert_type(df.apply(returns_dict, axis=1), pd.Series), pd.Series)

# Check various return types for result_type="expand" with axis=1
check(
# Note that technically it does not make sense
# to pass a result_type of "expand" to a scalar return
assert_type(
# Note that technically it does not make sense to pass a result_type of "expand" to a scalar return
df.apply(returns_scalar, axis=1, result_type="expand"),
"pd.Series[int]",
df.apply(returns_scalar, axis=1, result_type="expand"), "pd.Series[int]"
),
pd.Series,
int,
Expand All @@ -599,22 +571,26 @@ def gethead(s: pd.Series, y: int) -> pd.Series:
),
pd.DataFrame,
)
check(
assert_type(df.apply(returns_dict, axis=1, result_type="expand"), pd.DataFrame),
pd.DataFrame,
)

# Check various return types for result_type="reduce" with axis=1
check(
# Note that technically it does not make sense
# to pass a result_type of "reduce" to a scalar return
assert_type(
# Note that technically it does not make sense to pass a result_type of "reduce" to a scalar return
df.apply(returns_scalar, axis=1, result_type="reduce"),
"pd.Series[int]",
df.apply(returns_scalar, axis=1, result_type="reduce"), "pd.Series[int]"
),
pd.Series,
int,
)
check(
# Note that technically it does not make sense
# to pass a result_type of "reduce" to a series return
assert_type(
# Note that technically it does not make sense to pass a result_type of "reduce" to a series return
df.apply(returns_series, axis=1, result_type="reduce"),
pd.DataFrame,
df.apply(returns_series, axis=1, result_type="reduce"), pd.DataFrame
),
pd.DataFrame,
)
Expand All @@ -624,13 +600,34 @@ def gethead(s: pd.Series, y: int) -> pd.Series:
),
pd.Series,
)
check(
assert_type(df.apply(returns_dict, axis=1, result_type="reduce"), pd.Series),
pd.Series,
)

# Check various return types for result_type="broadcast" with axis=1
# Check various return types for result_type="broadcast" with axis=0 and axis=1
check(
# Note that technically it does not make sense
# to pass a result_type of "broadcast" to a scalar return
assert_type(df.apply(returns_scalar, result_type="broadcast"), pd.DataFrame),
pd.DataFrame,
)
check(
assert_type(df.apply(returns_series, result_type="broadcast"), pd.DataFrame),
pd.DataFrame,
)
check(
# Can only broadcast a list-like of 2 elements, not 3, because there are 2 rows
assert_type(
# Note that technicaly it does not make sense to pass a result_type of "broadcast" to a scalar return
df.apply(returns_scalar, axis=1, result_type="broadcast"),
pd.DataFrame,
df.apply(returns_listlike_of_2, result_type="broadcast"), pd.DataFrame
),
pd.DataFrame,
)
check(
# Note that technicaly it does not make sense
# to pass a result_type of "broadcast" to a scalar return
assert_type(
df.apply(returns_scalar, axis=1, result_type="broadcast"), pd.DataFrame
),
pd.DataFrame,
)
Expand All @@ -642,14 +639,29 @@ def gethead(s: pd.Series, y: int) -> pd.Series:
)
check(
assert_type(
# Can only broadcast a list-like of 3 elements, not 2, as there are 3 columns
# Can only broadcast a list-like of 3 elements, not 2,
# as there are 3 columns
df.apply(returns_listlike_of_3, axis=1, result_type="broadcast"),
pd.DataFrame,
),
pd.DataFrame,
)
# Since dicts will be assigned to elements of np.ndarray inside broadcasting,
# we need to use a DataFrame with object dtype to make the assignment possible.
df2 = pd.DataFrame({"col1": ["a", "b"], "col2": ["c", "d"]})
check(
assert_type(df2.apply(returns_dict, result_type="broadcast"), pd.DataFrame),
pd.DataFrame,
)
check(
assert_type(
df2.apply(returns_dict, axis=1, result_type="broadcast"), pd.DataFrame
),
pd.DataFrame,
)

# Test various other positional/keyword argument combinations to ensure all overloads are supported
# Test various other positional/keyword argument combinations
# to ensure all overloads are supported
check(
assert_type(df.apply(returns_scalar, axis=0), "pd.Series[int]"), pd.Series, int
)
Expand Down