Skip to content

Dataframe loc #575

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ class _LocIndexerFrame(_LocIndexer):
| Callable[[DataFrame], IndexType | MaskType | list[HashableT]]
| list[HashableT]
| tuple[
IndexType | MaskType | list[HashableT] | Hashable,
list[HashableT] | slice | Series[bool] | Callable,
Iterable[HashableT] | slice | Hashable,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Iterable[HashableT] is too wide, as it will match a plain string, which, if supplied, would return a Series.

So please put back IndexType | MaskType | list[HashableT] and replace Hashable with _IndexSliceTuple | Callable

Including slice is fine.

If you make the suggested change, that will cause the test test_frame.py:test_loc_slice() to fail, but I now realize that the expression used there is ambiguous:

>>> df1 = pd.DataFrame(
...         {"x": [1, 2, 3, 4]},
...         index=pd.MultiIndex.from_product([[1, 2], ["a", "b"]], names=["num", "let"]),
...     )
>>> df1.loc[1, :]
     x
let
a    1
b    2
>>> df2 = pd.DataFrame({"x": [1,2,3,4]}, index=[10, 20, 30, 40])
>>> df2.loc[10, :]
x    1
Name: 10, dtype: int64

So the first argument as an integer could return a DataFrame or Series, dependent on whether the underlying index is a regular Index or MultiIndex

The solution is then to add another overload in _LocIndexerFrame.__getitem__():

    @overload
    def __getitem__(self, idx: tuple[ScalarT, slice]) -> Series | DataFrame: ...

Then modify the test in test_index_slice() to check that the type is Union[pd.Series, pd.DataFrame], and add another test corresponding to df2 above.

Copy link
Contributor Author

@randolf-scholz randolf-scholz Mar 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Matching str is fine here, because the second component of the tuple ensures multiple columns are selected.

Both df.loc["row", ["col1", "col2", "col3"]] and df.loc[["r", "o", "w"], ["col1", "col2", "col3"]] return DataFrame which is the only thing this overload ensures.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Matching str is fine here, because the second component of the tuple ensures multiple columns are selected.

Both df.loc["row", ["col1", "col2", "col3"]] and df.loc[["r", "o", "w"], ["col1", "col2", "col3"]] return DataFrame which is the only thing this overload ensures.

No, you are incorrect. The first example could create a Series:

>>> import pandas as pd
>>> df = pd.DataFrame({"x":[1,2,3], "y":[4,5,6]}, index=["a", "b",
 "c"])
>>> df
   x  y
a  1  4
b  2  5
c  3  6
>>> df.loc["a", ["x", "y"]]
x    1
y    4
Name: a, dtype: int64
>>> type(df.loc["a", ["x", "y"]])
<class 'pandas.core.series.Series'>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, right.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need HashableT for Iterable as it is covariant: Iterable[Hashable]

list[HashableT] | Series[bool] | slice | Callable,
],
) -> DataFrame: ...
@overload
Expand Down
6 changes: 6 additions & 0 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2483,3 +2483,9 @@ def test_xs_frame_new() -> None:
s2 = df.xs("num_wings", axis=1)
check(assert_type(s1, Union[pd.Series, pd.DataFrame]), pd.DataFrame)
check(assert_type(s2, Union[pd.Series, pd.DataFrame]), pd.Series)


def test_loc_tuple_slice_list() -> None:
"""Test DataFrame.loc[index, columns]."""
foo = pd.DataFrame(np.random.rand(10, 3), columns=["a", "b", "c"])
check(assert_type(foo.loc[4:5, ["a", "b"]], pd.DataFrame), pd.DataFrame)