Skip to content

Commit 09553b0

Browse files
authored
added return types for "SeriesGroupBy" and "DataFrameGroupBy" (#455)
* added return types for "SeriesGroupBy" and "DataFrameGroupBy" with tests * added function callable in transform * added 'str' to ' func' in transform * added the tests with a string for transform
1 parent c74909e commit 09553b0

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

pandas-stubs/core/groupby/generic.pyi

+2-2
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class SeriesGroupBy(GroupBy, Generic[S1]):
6161
def agg(self, func: list[AggFuncTypeBase], *args, **kwargs) -> DataFrame: ...
6262
@overload
6363
def agg(self, func: AggFuncTypeBase, *args, **kwargs) -> Series: ...
64-
def transform(self, func, *args, **kwargs): ...
64+
def transform(self, func: Callable | str, *args, **kwargs) -> Series: ...
6565
def filter(self, func, dropna: bool = ..., *args, **kwargs): ...
6666
def nunique(self, dropna: bool = ...) -> Series: ...
6767
def describe(self, **kwargs) -> DataFrame: ...
@@ -166,7 +166,7 @@ class DataFrameGroupBy(GroupBy):
166166
) -> DataFrame: ...
167167
def aggregate(self, arg: AggFuncTypeFrame = ..., *args, **kwargs) -> DataFrame: ...
168168
agg = aggregate
169-
def transform(self, func, *args, **kwargs): ...
169+
def transform(self, func: Callable | str, *args, **kwargs) -> DataFrame: ...
170170
def filter(
171171
self, func: Callable, dropna: bool = ..., *args, **kwargs
172172
) -> DataFrame: ...

tests/test_frame.py

+26
Original file line numberDiff line numberDiff line change
@@ -2071,3 +2071,29 @@ def test_setitem_none() -> None:
20712071
sb = pd.Series([1, 2, 3], dtype=int)
20722072
sb.loc["y"] = None
20732073
sb.iloc[0] = None
2074+
2075+
2076+
def test_groupby_and_transform() -> None:
2077+
df = pd.DataFrame(
2078+
{
2079+
"A": ["foo", "bar", "foo", "bar", "foo", "bar"],
2080+
"B": ["one", "one", "two", "three", "two", "two"],
2081+
"C": [1, 5, 5, 2, 5, 5],
2082+
"D": [2.0, 5.0, 8.0, 1.0, 2.0, 9.0],
2083+
}
2084+
)
2085+
ser = pd.Series(
2086+
[390.0, 350.0, 30.0, 20.0],
2087+
index=["Falcon", "Falcon", "Parrot", "Parrot"],
2088+
name="Max Speed",
2089+
)
2090+
grouped = df.groupby("A")[["C", "D"]]
2091+
grouped1 = ser.groupby(ser > 100)
2092+
c1 = grouped.transform("sum")
2093+
c2 = grouped.transform(lambda x: (x - x.mean()) / x.std())
2094+
c3 = grouped1.transform("cumsum")
2095+
c4 = grouped1.transform(lambda x: x.max() - x.min())
2096+
check(assert_type(c1, pd.DataFrame), pd.DataFrame)
2097+
check(assert_type(c2, pd.DataFrame), pd.DataFrame)
2098+
check(assert_type(c3, pd.Series), pd.Series)
2099+
check(assert_type(c4, pd.Series), pd.Series)

0 commit comments

Comments
 (0)