Skip to content

Commit bfa107b

Browse files
authored
allow np.uint64 to be used in indexing. Support numpy 1.24.1 (#510)
1 parent 261eabb commit bfa107b

File tree

5 files changed

+21
-6
lines changed

5 files changed

+21
-6
lines changed

pandas-stubs/_typing.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ np_ndarray_anyint: TypeAlias = npt.NDArray[np.integer]
189189
np_ndarray_bool: TypeAlias = npt.NDArray[np.bool_]
190190
np_ndarray_str: TypeAlias = npt.NDArray[np.str_]
191191

192-
IndexType: TypeAlias = Union[slice, np_ndarray_int64, Index, list[int], Series[int]]
192+
IndexType: TypeAlias = Union[slice, np_ndarray_anyint, Index, list[int], Series[int]]
193193
MaskType: TypeAlias = Union[Series[bool], np_ndarray_bool, list[bool]]
194194
# Scratch types for generics
195195
S1 = TypeVar(

pandas-stubs/core/indexes/base.pyi

+3-2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ from pandas._typing import (
3838
Level,
3939
NaPosition,
4040
Scalar,
41+
np_ndarray_anyint,
4142
np_ndarray_bool,
4243
np_ndarray_int64,
4344
type_t,
@@ -192,10 +193,10 @@ class Index(IndexOpsMixin, PandasObject):
192193
@overload
193194
def __getitem__(
194195
self: IndexT,
195-
idx: slice | np_ndarray_int64 | Index | Series[bool] | np_ndarray_bool,
196+
idx: slice | np_ndarray_anyint | Index | Series[bool] | np_ndarray_bool,
196197
) -> IndexT: ...
197198
@overload
198-
def __getitem__(self, idx: int | tuple[np_ndarray_int64, ...]) -> Scalar: ...
199+
def __getitem__(self, idx: int | tuple[np_ndarray_anyint, ...]) -> Scalar: ...
199200
def append(self, other): ...
200201
def putmask(self, mask, value): ...
201202
def equals(self, other) -> bool: ...

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ pyright = ">=1.1.286"
4242
poethepoet = ">=0.16.5"
4343
loguru = ">=0.6.0"
4444
pandas = "1.5.2"
45-
numpy = "<=1.23.5"
45+
numpy = ">=1.24.1"
4646
typing-extensions = ">=4.2.0"
4747
matplotlib = ">=3.5.1"
4848
pre-commit = ">=2.19.0"

tests/test_frame.py

+14
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626

2727
import numpy as np
28+
import numpy.typing as npt
2829
import pandas as pd
2930
from pandas._testing import (
3031
ensure_clean,
@@ -2363,3 +2364,16 @@ def test_frame_dropna_subset() -> None:
23632364
assert_type(df.dropna(subset=df.columns.drop("col1")), pd.DataFrame),
23642365
pd.DataFrame,
23652366
)
2367+
2368+
2369+
def test_npint_loc_indexer() -> None:
2370+
# GH 508
2371+
2372+
df = pd.DataFrame(dict(x=[1, 2, 3]), index=np.array([10, 20, 30], dtype="uint64"))
2373+
2374+
def get_NDArray(df: pd.DataFrame, key: npt.NDArray[np.uint64]) -> pd.DataFrame:
2375+
df2 = df.loc[key]
2376+
return df2
2377+
2378+
a: npt.NDArray[np.uint64] = np.array([10, 30], dtype="uint64")
2379+
check(assert_type(get_NDArray(df, a), pd.DataFrame), pd.DataFrame)

tests/test_pandas.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import pandas as pd
1414
from pandas import Grouper
1515
from pandas.api.extensions import ExtensionArray
16+
from pandas.util.version import Version
1617
import pytest
1718
from typing_extensions import assert_type
1819

@@ -1705,7 +1706,7 @@ def test_pivot_table() -> None:
17051706
),
17061707
pd.DataFrame,
17071708
)
1708-
with pytest.warns(np.VisibleDeprecationWarning):
1709+
if Version(np.__version__) <= Version("1.23.5"):
17091710
check(
17101711
assert_type(
17111712
pd.pivot_table(
@@ -1719,7 +1720,6 @@ def test_pivot_table() -> None:
17191720
),
17201721
pd.DataFrame,
17211722
)
1722-
with pytest.warns(np.VisibleDeprecationWarning):
17231723
check(
17241724
assert_type(
17251725
pd.pivot_table(

0 commit comments

Comments
 (0)