Skip to content

Commit b83f687

Browse files
authored
ENH: add dict to return type of func argument of DataFrame#apply (#470)
* ENH: add dict to return type of func argument of DataFrame#apply when result_type="expand"
* tests/test_frame: replace Literal[None] with just None
* Add support for apply func that returns Mapping with axis=0 and result_type=None
* Replace dict with Mapping in return type of f for DataFrame#apply
* Add support for apply func that returns Mapping with result_type="reduce"
* Add support for apply func that returns Mapping with axis=1 and result_type=None
* Add support for apply func that returns Mapping with result_type="broadcast"
1 parent 328f623 commit b83f687

File tree

2 files changed

+99
-75
lines changed

2 files changed

+99
-75
lines changed

pandas-stubs/core/frame.pyi

+21-9
Original file line numberDiff line numberDiff line change
@@ -1096,7 +1096,7 @@ class DataFrame(NDFrame, OpsMixin):
10961096
f: Callable[..., ListLikeExceptSeriesAndStr | Series],
10971097
axis: AxisTypeIndex = ...,
10981098
raw: _bool = ...,
1099-
result_type: Literal[None] = ...,
1099+
result_type: None = ...,
11001100
args=...,
11011101
**kwargs,
11021102
) -> DataFrame: ...
@@ -1106,10 +1106,22 @@ class DataFrame(NDFrame, OpsMixin):
11061106
f: Callable[..., S1],
11071107
axis: AxisTypeIndex = ...,
11081108
raw: _bool = ...,
1109-
result_type: Literal[None] = ...,
1109+
result_type: None = ...,
11101110
args=...,
11111111
**kwargs,
11121112
) -> Series[S1]: ...
1113+
# Since non-scalar type T is not supported in Series[T],
1114+
# we separate this overload from the above one
1115+
@overload
1116+
def apply(
1117+
self,
1118+
f: Callable[..., Mapping],
1119+
axis: AxisTypeIndex = ...,
1120+
raw: _bool = ...,
1121+
result_type: None = ...,
1122+
args=...,
1123+
**kwargs,
1124+
) -> Series: ...
11131125

11141126
# apply() overloads with keyword result_type, and axis does not matter
11151127
@overload
@@ -1126,7 +1138,7 @@ class DataFrame(NDFrame, OpsMixin):
11261138
@overload
11271139
def apply(
11281140
self,
1129-
f: Callable[..., ListLikeExceptSeriesAndStr | Series],
1141+
f: Callable[..., ListLikeExceptSeriesAndStr | Series | Mapping],
11301142
axis: AxisType = ...,
11311143
raw: _bool = ...,
11321144
args=...,
@@ -1137,7 +1149,7 @@ class DataFrame(NDFrame, OpsMixin):
11371149
@overload
11381150
def apply(
11391151
self,
1140-
f: Callable[..., ListLikeExceptSeriesAndStr],
1152+
f: Callable[..., ListLikeExceptSeriesAndStr | Mapping],
11411153
axis: AxisType = ...,
11421154
raw: _bool = ...,
11431155
args=...,
@@ -1148,7 +1160,7 @@ class DataFrame(NDFrame, OpsMixin):
11481160
@overload
11491161
def apply(
11501162
self,
1151-
f: Callable[..., ListLikeExceptSeriesAndStr | Series | Scalar],
1163+
f: Callable[..., ListLikeExceptSeriesAndStr | Series | Scalar | Mapping],
11521164
axis: AxisType = ...,
11531165
raw: _bool = ...,
11541166
args=...,
@@ -1176,7 +1188,7 @@ class DataFrame(NDFrame, OpsMixin):
11761188
self,
11771189
f: Callable[..., S1],
11781190
raw: _bool = ...,
1179-
result_type: Literal[None] = ...,
1191+
result_type: None = ...,
11801192
args=...,
11811193
*,
11821194
axis: AxisTypeColumn,
@@ -1185,9 +1197,9 @@ class DataFrame(NDFrame, OpsMixin):
11851197
@overload
11861198
def apply(
11871199
self,
1188-
f: Callable[..., ListLikeExceptSeriesAndStr],
1200+
f: Callable[..., ListLikeExceptSeriesAndStr | Mapping],
11891201
raw: _bool = ...,
1190-
result_type: Literal[None] = ...,
1202+
result_type: None = ...,
11911203
args=...,
11921204
*,
11931205
axis: AxisTypeColumn,
@@ -1198,7 +1210,7 @@ class DataFrame(NDFrame, OpsMixin):
11981210
self,
11991211
f: Callable[..., Series],
12001212
raw: _bool = ...,
1201-
result_type: Literal[None] = ...,
1213+
result_type: None = ...,
12021214
args=...,
12031215
*,
12041216
axis: AxisTypeColumn,

tests/test_frame.py

+78-66
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,9 @@ def returns_listlike_of_2(x: pd.Series) -> tuple[int, int]:
477477
def returns_listlike_of_3(x: pd.Series) -> tuple[int, int, int]:
478478
return (7, 8, 9)
479479

480+
def returns_dict(x: pd.Series) -> dict[str, int]:
481+
return {"col4": 7, "col5": 8}
482+
480483
# Misc checks
481484
check(assert_type(df.apply(np.exp), pd.DataFrame), pd.DataFrame)
482485
check(assert_type(df.apply(str), "pd.Series[str]"), pd.Series, str)
@@ -489,22 +492,15 @@ def gethead(s: pd.Series, y: int) -> pd.Series:
489492

490493
# Check various return types for default result_type (None) with default axis (0)
491494
check(assert_type(df.apply(returns_scalar), "pd.Series[int]"), pd.Series, int)
492-
check(
493-
assert_type(df.apply(returns_series), pd.DataFrame),
494-
pd.DataFrame,
495-
)
496-
check(
497-
assert_type(df.apply(returns_listlike_of_3), pd.DataFrame),
498-
pd.DataFrame,
499-
)
495+
check(assert_type(df.apply(returns_series), pd.DataFrame), pd.DataFrame)
496+
check(assert_type(df.apply(returns_listlike_of_3), pd.DataFrame), pd.DataFrame)
497+
check(assert_type(df.apply(returns_dict), pd.Series), pd.Series)
500498

501499
# Check various return types for result_type="expand" with default axis (0)
502500
check(
503-
assert_type(
504-
# Note that technically it does not make sense to pass a result_type of "expand" to a scalar return
505-
df.apply(returns_scalar, result_type="expand"),
506-
"pd.Series[int]",
507-
),
501+
# Note that technically it does not make sense
502+
# to pass a result_type of "expand" to a scalar return
503+
assert_type(df.apply(returns_scalar, result_type="expand"), "pd.Series[int]"),
508504
pd.Series,
509505
int,
510506
)
@@ -518,71 +514,47 @@ def gethead(s: pd.Series, y: int) -> pd.Series:
518514
),
519515
pd.DataFrame,
520516
)
517+
check(
518+
assert_type(df.apply(returns_dict, result_type="expand"), pd.DataFrame),
519+
pd.DataFrame,
520+
)
521521

522522
# Check various return types for result_type="reduce" with default axis (0)
523523
check(
524-
assert_type(
525-
# Note that technically it does not make sense to pass a result_type of "reduce" to a scalar return
526-
df.apply(returns_scalar, result_type="reduce"),
527-
"pd.Series[int]",
528-
),
524+
# Note that technically it does not make sense
525+
# to pass a result_type of "reduce" to a scalar return
526+
assert_type(df.apply(returns_scalar, result_type="reduce"), "pd.Series[int]"),
529527
pd.Series,
530528
int,
531529
)
532530
check(
533-
assert_type(
534-
# Note that technically it does not make sense to pass a result_type of "reduce" to a series return
535-
df.apply(returns_series, result_type="reduce"),
536-
pd.Series,
537-
),
531+
# Note that technically it does not make sense
532+
# to pass a result_type of "reduce" to a series return
533+
assert_type(df.apply(returns_series, result_type="reduce"), pd.Series),
538534
pd.Series, # This technically returns a pd.Series[pd.Series], but typing does not support that
539535
)
540536
check(
541537
assert_type(df.apply(returns_listlike_of_3, result_type="reduce"), pd.Series),
542538
pd.Series,
543539
)
544-
545-
# Check various return types for result_type="broadcast" with default axis (0)
546540
check(
547-
assert_type(
548-
# Note that technically it does not make sense to pass a result_type of "broadcast" to a scalar return
549-
df.apply(returns_scalar, result_type="broadcast"),
550-
pd.DataFrame,
551-
),
552-
pd.DataFrame,
553-
)
554-
check(
555-
assert_type(df.apply(returns_series, result_type="broadcast"), pd.DataFrame),
556-
pd.DataFrame,
557-
)
558-
check(
559-
assert_type(
560-
# Can only broadcast a list-like of 2 elements, not 3, because there are 2 rows
561-
df.apply(returns_listlike_of_2, result_type="broadcast"),
562-
pd.DataFrame,
563-
),
564-
pd.DataFrame,
541+
assert_type(df.apply(returns_dict, result_type="reduce"), pd.Series), pd.Series
565542
)
566543

567544
# Check various return types for default result_type (None) with axis=1
568545
check(
569546
assert_type(df.apply(returns_scalar, axis=1), "pd.Series[int]"), pd.Series, int
570547
)
571-
check(
572-
assert_type(df.apply(returns_series, axis=1), pd.DataFrame),
573-
pd.DataFrame,
574-
)
575-
check(
576-
assert_type(df.apply(returns_listlike_of_3, axis=1), pd.Series),
577-
pd.Series,
578-
)
548+
check(assert_type(df.apply(returns_series, axis=1), pd.DataFrame), pd.DataFrame)
549+
check(assert_type(df.apply(returns_listlike_of_3, axis=1), pd.Series), pd.Series)
550+
check(assert_type(df.apply(returns_dict, axis=1), pd.Series), pd.Series)
579551

580552
# Check various return types for result_type="expand" with axis=1
581553
check(
554+
# Note that technically it does not make sense
555+
# to pass a result_type of "expand" to a scalar return
582556
assert_type(
583-
# Note that technically it does not make sense to pass a result_type of "expand" to a scalar return
584-
df.apply(returns_scalar, axis=1, result_type="expand"),
585-
"pd.Series[int]",
557+
df.apply(returns_scalar, axis=1, result_type="expand"), "pd.Series[int]"
586558
),
587559
pd.Series,
588560
int,
@@ -599,22 +571,26 @@ def gethead(s: pd.Series, y: int) -> pd.Series:
599571
),
600572
pd.DataFrame,
601573
)
574+
check(
575+
assert_type(df.apply(returns_dict, axis=1, result_type="expand"), pd.DataFrame),
576+
pd.DataFrame,
577+
)
602578

603579
# Check various return types for result_type="reduce" with axis=1
604580
check(
581+
# Note that technically it does not make sense
582+
# to pass a result_type of "reduce" to a scalar return
605583
assert_type(
606-
# Note that technically it does not make sense to pass a result_type of "reduce" to a scalar return
607-
df.apply(returns_scalar, axis=1, result_type="reduce"),
608-
"pd.Series[int]",
584+
df.apply(returns_scalar, axis=1, result_type="reduce"), "pd.Series[int]"
609585
),
610586
pd.Series,
611587
int,
612588
)
613589
check(
590+
# Note that technically it does not make sense
591+
# to pass a result_type of "reduce" to a series return
614592
assert_type(
615-
# Note that technically it does not make sense to pass a result_type of "reduce" to a series return
616-
df.apply(returns_series, axis=1, result_type="reduce"),
617-
pd.DataFrame,
593+
df.apply(returns_series, axis=1, result_type="reduce"), pd.DataFrame
618594
),
619595
pd.DataFrame,
620596
)
@@ -624,13 +600,34 @@ def gethead(s: pd.Series, y: int) -> pd.Series:
624600
),
625601
pd.Series,
626602
)
603+
check(
604+
assert_type(df.apply(returns_dict, axis=1, result_type="reduce"), pd.Series),
605+
pd.Series,
606+
)
627607

628-
# Check various return types for result_type="broadcast" with axis=1
608+
# Check various return types for result_type="broadcast" with axis=0 and axis=1
609+
check(
610+
# Note that technically it does not make sense
611+
# to pass a result_type of "broadcast" to a scalar return
612+
assert_type(df.apply(returns_scalar, result_type="broadcast"), pd.DataFrame),
613+
pd.DataFrame,
614+
)
615+
check(
616+
assert_type(df.apply(returns_series, result_type="broadcast"), pd.DataFrame),
617+
pd.DataFrame,
618+
)
629619
check(
620+
# Can only broadcast a list-like of 2 elements, not 3, because there are 2 rows
630621
assert_type(
631-
# Note that technicaly it does not make sense to pass a result_type of "broadcast" to a scalar return
632-
df.apply(returns_scalar, axis=1, result_type="broadcast"),
633-
pd.DataFrame,
622+
df.apply(returns_listlike_of_2, result_type="broadcast"), pd.DataFrame
623+
),
624+
pd.DataFrame,
625+
)
626+
check(
627+
# Note that technically it does not make sense
628+
# to pass a result_type of "broadcast" to a scalar return
629+
assert_type(
630+
df.apply(returns_scalar, axis=1, result_type="broadcast"), pd.DataFrame
634631
),
635632
pd.DataFrame,
636633
)
@@ -642,14 +639,29 @@ def gethead(s: pd.Series, y: int) -> pd.Series:
642639
)
643640
check(
644641
assert_type(
645-
# Can only broadcast a list-like of 3 elements, not 2, as there are 3 columns
642+
# Can only broadcast a list-like of 3 elements, not 2,
643+
# as there are 3 columns
646644
df.apply(returns_listlike_of_3, axis=1, result_type="broadcast"),
647645
pd.DataFrame,
648646
),
649647
pd.DataFrame,
650648
)
649+
# Since dicts will be assigned to elements of np.ndarray inside broadcasting,
650+
# we need to use a DataFrame with object dtype to make the assignment possible.
651+
df2 = pd.DataFrame({"col1": ["a", "b"], "col2": ["c", "d"]})
652+
check(
653+
assert_type(df2.apply(returns_dict, result_type="broadcast"), pd.DataFrame),
654+
pd.DataFrame,
655+
)
656+
check(
657+
assert_type(
658+
df2.apply(returns_dict, axis=1, result_type="broadcast"), pd.DataFrame
659+
),
660+
pd.DataFrame,
661+
)
651662

652-
# Test various other positional/keyword argument combinations to ensure all overloads are supported
663+
# Test various other positional/keyword argument combinations
664+
# to ensure all overloads are supported
653665
check(
654666
assert_type(df.apply(returns_scalar, axis=0), "pd.Series[int]"), pd.Series, int
655667
)

0 commit comments

Comments
 (0)