Skip to content

Commit 2111a97

Browse files
bashtage and Kevin Sheppard authored
ENH: Add more 1.5.0 features (#338)
* ENH: Add more 1.5.0 features
* ENH: Add union to RangeIndex
* ENH: Add check_like to assert_series_equal
* ENH: Add ArrowDtype
* TST: Add tests for 1.5.0 additions
* BUG: Add defaultdict throughout
* REF: Move test to better location
* TYP: Remove specific types from defaultdict

Co-authored-by: Kevin Sheppard <[email protected]>
1 parent 25e72c3 commit 2111a97

File tree

16 files changed: +159 additions, -18 deletions

pandas-stubs/__init__.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ from ._config import (
1717
)
1818
from .core.api import (
1919
NA as NA,
20+
ArrowDtype as ArrowDtype,
2021
BooleanDtype as BooleanDtype,
2122
Categorical as Categorical,
2223
CategoricalDtype as CategoricalDtype,

pandas-stubs/_libs/tslibs/offsets.pyi

+27-1
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,33 @@ class CustomBusinessHour(BusinessHour):
225225

226226
class CustomBusinessMonthEnd(_CustomBusinessMonth): ...
227227
class CustomBusinessMonthBegin(_CustomBusinessMonth): ...
228-
class DateOffset(RelativeDeltaOffset): ...
228+
229+
class DateOffset(RelativeDeltaOffset):
230+
def __init__(
231+
self,
232+
*,
233+
n: int = ...,
234+
normalize: bool = ...,
235+
years: int = ...,
236+
months: int = ...,
237+
weeks: int = ...,
238+
days: int = ...,
239+
hours: int = ...,
240+
minutes: int = ...,
241+
seconds: int = ...,
242+
milliseconds: int = ...,
243+
microseconds: int = ...,
244+
nanoseconds: int = ...,
245+
year: int = ...,
246+
month: int = ...,
247+
day: int = ...,
248+
weekday: int = ...,
249+
hour: int = ...,
250+
minute: int = ...,
251+
second: int = ...,
252+
microsecond: int = ...,
253+
nanosecond: int = ...,
254+
): ...
229255

230256
BDay = BusinessDay
231257
BMonthEnd = BusinessMonthEnd

pandas-stubs/_testing/__init__.pyi

+25-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ from typing import (
33
Any,
44
Generator,
55
Literal,
6+
overload,
67
)
78

89
from pandas import (
@@ -54,6 +55,7 @@ def assert_extension_array_equal(
5455
check_less_precise: bool = ...,
5556
check_exact: bool = ...,
5657
) -> None: ...
58+
@overload
5759
def assert_series_equal(
5860
left: Series,
5961
right: Series,
@@ -71,7 +73,29 @@ def assert_series_equal(
7173
atol: float = ...,
7274
obj: str = ...,
7375
*,
74-
check_index: bool = ...,
76+
check_index: Literal[False],
77+
check_like: Literal[False],
78+
) -> None: ...
79+
@overload
80+
def assert_series_equal(
81+
left: Series,
82+
right: Series,
83+
check_dtype: bool = ...,
84+
check_index_type: bool | str = ...,
85+
check_series_type: bool = ...,
86+
check_names: bool = ...,
87+
check_exact: bool = ...,
88+
check_datetimelike_compat: bool = ...,
89+
check_categorical: bool = ...,
90+
check_category_order: bool = ...,
91+
check_freq: bool = ...,
92+
check_flags: bool = ...,
93+
rtol: float = ...,
94+
atol: float = ...,
95+
obj: str = ...,
96+
*,
97+
check_index: Literal[True] = ...,
98+
check_like: bool = ...,
7599
) -> None: ...
76100
def assert_frame_equal(
77101
left: DataFrame,

pandas-stubs/core/api.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ from pandas.core.algorithms import (
44
value_counts as value_counts,
55
)
66
from pandas.core.arrays import Categorical as Categorical
7+
from pandas.core.arrays.arrow.dtype import ArrowDtype as ArrowDtype
78
from pandas.core.arrays.boolean import BooleanDtype as BooleanDtype
89
from pandas.core.arrays.floating import (
910
Float32Dtype as Float32Dtype,

pandas-stubs/core/arrays/arrow/dtype.pyi

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import numpy as np
2+
import pyarrow as pa
3+
4+
from pandas.core.dtypes.base import StorageExtensionDtype
5+
6+
class ArrowDtype(StorageExtensionDtype):
7+
pyarrow_dtype: pa.DataType
8+
def __init__(self, pyarrow_dtype: pa.DataType) -> None: ...

pandas-stubs/core/dtypes/base.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@ class ExtensionDtype:
2222
def construct_from_string(cls, string: str): ...
2323
@classmethod
2424
def is_dtype(cls, dtype) -> bool: ...
25+
26+
class StorageExtensionDtype(ExtensionDtype): ...

pandas-stubs/core/indexes/base.pyi

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ from pandas._typing import (
2828
Dtype,
2929
DtypeArg,
3030
DtypeObj,
31+
HashableT,
3132
IndexT,
3233
Label,
3334
Level,
@@ -148,7 +149,7 @@ class Index(IndexOpsMixin, PandasObject):
148149
def __neg__(self: IndexT) -> IndexT: ...
149150
def __nonzero__(self) -> None: ...
150151
__bool__ = ...
151-
def union(self, other: list[T1] | Index, sort=...) -> Index: ...
152+
def union(self, other: list[HashableT] | Index, sort=...) -> Index: ...
152153
def intersection(self, other: list[T1] | Index, sort: bool = ...) -> Index: ...
153154
def difference(self, other: list | Index) -> Index: ...
154155
def symmetric_difference(

pandas-stubs/core/indexes/multi.pyi

+8-1
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ from typing import (
66
)
77

88
import numpy as np
9+
import pandas as pd
910
from pandas.core.indexes.base import Index
1011

1112
from pandas._typing import (
1213
T1,
1314
DtypeArg,
15+
HashableT,
1416
np_ndarray_bool,
1517
)
1618

@@ -88,7 +90,12 @@ class MultiIndex(Index):
8890
def get_value(self, series, key): ...
8991
def get_level_values(self, level: str | int) -> Index: ...
9092
def unique(self, level=...): ...
91-
def to_frame(self, index: bool = ..., name=...): ...
93+
def to_frame(
94+
self,
95+
index: bool = ...,
96+
name: list[HashableT] = ...,
97+
allow_duplicates: bool = ...,
98+
) -> pd.DataFrame: ...
9299
def to_flat_index(self): ...
93100
@property
94101
def is_all_dates(self) -> bool: ...

pandas-stubs/core/indexes/range.pyi

+8-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import numpy as np
2+
from pandas.core.indexes.base import Index
23
from pandas.core.indexes.numeric import Int64Index
34

4-
from pandas._typing import npt
5+
from pandas._typing import (
6+
HashableT,
7+
npt,
8+
)
59

610
class RangeIndex(Int64Index):
711
def __new__(
@@ -70,3 +74,6 @@ class RangeIndex(Int64Index):
7074
def __floordiv__(self, other): ...
7175
def all(self) -> bool: ...
7276
def any(self) -> bool: ...
77+
def union(
78+
self, other: list[HashableT] | Index, sort=...
79+
) -> Index | Int64Index | RangeIndex: ...

pandas-stubs/core/window/ewm.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class ExponentialMovingWindow(BaseWindow[NDFrameT], Generic[NDFrameT]):
3535
adjust: bool = ...,
3636
ignore_na: bool = ...,
3737
axis: Axis = ...,
38-
times: str | np.ndarray | Series | None = ...,
38+
times: str | np.ndarray | Series | None | np.timedelta64 = ...,
3939
method: CalculationMethod = ...,
4040
) -> None: ...
4141
@overload

pandas-stubs/io/clipboards.pyi

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import defaultdict
12
import csv
23
from typing import (
34
Any,
@@ -38,7 +39,7 @@ def read_clipboard(
3839
| npt.NDArray
3940
| Callable[[str], bool]
4041
| None = ...,
41-
dtype: DtypeArg | None = ...,
42+
dtype: DtypeArg | defaultdict | None = ...,
4243
engine: CSVEngine | None = ...,
4344
converters: dict[int | str, Callable[[str], Any]] = ...,
4445
true_values: list[str] = ...,
@@ -101,7 +102,7 @@ def read_clipboard(
101102
| npt.NDArray
102103
| Callable[[str], bool]
103104
| None = ...,
104-
dtype: DtypeArg | None = ...,
105+
dtype: DtypeArg | defaultdict | None = ...,
105106
engine: CSVEngine | None = ...,
106107
converters: dict[int | str, Callable[[str], Any]] = ...,
107108
true_values: list[str] = ...,
@@ -164,7 +165,7 @@ def read_clipboard(
164165
| npt.NDArray
165166
| Callable[[str], bool]
166167
| None = ...,
167-
dtype: DtypeArg | None = ...,
168+
dtype: DtypeArg | defaultdict | None = ...,
168169
engine: CSVEngine | None = ...,
169170
converters: dict[int | str, Callable[[str], Any]] = ...,
170171
true_values: list[str] = ...,

pandas-stubs/io/parsers/readers.pyi

+10-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from collections import abc
1+
from collections import (
2+
abc,
3+
defaultdict,
4+
)
25
import csv
36
from types import TracebackType
47
from typing import (
@@ -44,7 +47,7 @@ def read_csv(
4447
| npt.NDArray
4548
| Callable[[str], bool]
4649
| None = ...,
47-
dtype: DtypeArg | None = ...,
50+
dtype: DtypeArg | defaultdict | None = ...,
4851
engine: CSVEngine | None = ...,
4952
converters: dict[int | str, Callable[[str], Any]] = ...,
5053
true_values: list[str] = ...,
@@ -107,7 +110,7 @@ def read_csv(
107110
| npt.NDArray
108111
| Callable[[str], bool]
109112
| None = ...,
110-
dtype: DtypeArg | None = ...,
113+
dtype: DtypeArg | defaultdict | None = ...,
111114
engine: CSVEngine | None = ...,
112115
converters: dict[int | str, Callable[[str], Any]] = ...,
113116
true_values: list[str] = ...,
@@ -170,7 +173,7 @@ def read_csv(
170173
| npt.NDArray
171174
| Callable[[str], bool]
172175
| None = ...,
173-
dtype: DtypeArg | None = ...,
176+
dtype: DtypeArg | defaultdict | None = ...,
174177
engine: CSVEngine | None = ...,
175178
converters: dict[int | str, Callable[[str], Any]] = ...,
176179
true_values: list[str] = ...,
@@ -233,7 +236,7 @@ def read_table(
233236
| npt.NDArray
234237
| Callable[[str], bool]
235238
| None = ...,
236-
dtype: DtypeArg | None = ...,
239+
dtype: DtypeArg | defaultdict | None = ...,
237240
engine: CSVEngine | None = ...,
238241
converters: dict[int | str, Callable[[str], Any]] = ...,
239242
true_values: list[str] = ...,
@@ -296,7 +299,7 @@ def read_table(
296299
| npt.NDArray
297300
| Callable[[str], bool]
298301
| None = ...,
299-
dtype: DtypeArg | None = ...,
302+
dtype: DtypeArg | defaultdict | None = ...,
300303
engine: CSVEngine | None = ...,
301304
converters: dict[int | str, Callable[[str], Any]] = ...,
302305
true_values: list[str] = ...,
@@ -359,7 +362,7 @@ def read_table(
359362
| npt.NDArray
360363
| Callable[[str], bool]
361364
| None = ...,
362-
dtype: DtypeArg | None = ...,
365+
dtype: DtypeArg | defaultdict | None = ...,
363366
engine: CSVEngine | None = ...,
364367
converters: dict[int | str, Callable[[str], Any]] = ...,
365368
true_values: list[str] = ...,

tests/test_indexes.py

+32
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from __future__ import annotations
22

3+
from typing import Union
4+
35
import numpy as np
46
from numpy import typing as npt
57
import pandas as pd
68
from pandas.core.indexes.numeric import NumericIndex
9+
import pytest
710
from typing_extensions import assert_type
811

912
from tests import check
@@ -31,6 +34,10 @@ def test_index_astype() -> None:
3134
mi = pd.MultiIndex.from_product([["a", "b"], ["c", "d"]], names=["ab", "cd"])
3235
mia = mi.astype(object) # object is only valid parameter for MultiIndex.astype()
3336
check(assert_type(mia, pd.MultiIndex), pd.MultiIndex)
37+
check(
38+
assert_type(mi.to_frame(name=[3, 7], allow_duplicates=True), pd.DataFrame),
39+
pd.DataFrame,
40+
)
3441

3542

3643
def test_multiindex_get_level_values() -> None:
@@ -148,3 +155,28 @@ def test_index_relops() -> None:
148155
check(assert_type(ind >= 2, npt.NDArray[np.bool_]), np.ndarray, np.bool_)
149156
check(assert_type(ind < 2, npt.NDArray[np.bool_]), np.ndarray, np.bool_)
150157
check(assert_type(ind > 2, npt.NDArray[np.bool_]), np.ndarray, np.bool_)
158+
159+
160+
def test_range_index_union():
161+
with pytest.warns(FutureWarning, match="pandas.Int64Index"):
162+
check(
163+
assert_type(
164+
pd.RangeIndex(0, 10).union(pd.RangeIndex(10, 20)),
165+
Union[pd.Index, pd.Int64Index, pd.RangeIndex],
166+
),
167+
pd.RangeIndex,
168+
)
169+
check(
170+
assert_type(
171+
pd.RangeIndex(0, 10).union([11, 12, 13]),
172+
Union[pd.Index, pd.Int64Index, pd.RangeIndex],
173+
),
174+
pd.Int64Index,
175+
)
176+
check(
177+
assert_type(
178+
pd.RangeIndex(0, 10).union(["a", "b", "c"]),
179+
Union[pd.Index, pd.Int64Index, pd.RangeIndex],
180+
),
181+
pd.Index,
182+
)

tests/test_io.py

+13
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import defaultdict
12
import csv
23
import io
34
import os.path
@@ -208,6 +209,10 @@ def test_clipboard():
208209
check(assert_type(read_clipboard(), DataFrame), DataFrame)
209210
check(assert_type(read_clipboard(iterator=False), DataFrame), DataFrame)
210211
check(assert_type(read_clipboard(chunksize=None), DataFrame), DataFrame)
212+
check(
213+
assert_type(read_clipboard(dtype=defaultdict(lambda: "f8")), DataFrame),
214+
DataFrame,
215+
)
211216

212217

213218
def test_clipboard_iterator():
@@ -425,6 +430,10 @@ def test_read_csv():
425430
check(assert_type(read_csv(sio), DataFrame), DataFrame)
426431
check(assert_type(read_csv(path, iterator=False), DataFrame), DataFrame)
427432
check(assert_type(read_csv(path, chunksize=None), DataFrame), DataFrame)
433+
check(
434+
assert_type(read_csv(path, dtype=defaultdict(lambda: "f8")), DataFrame),
435+
DataFrame,
436+
)
428437

429438

430439
def test_read_csv_iterator():
@@ -489,6 +498,10 @@ def test_read_table():
489498
check(assert_type(read_table(path), DataFrame), DataFrame)
490499
check(assert_type(read_table(path, iterator=False), DataFrame), DataFrame)
491500
check(assert_type(read_table(path, chunksize=None), DataFrame), DataFrame)
501+
check(
502+
assert_type(read_table(path, dtype=defaultdict(lambda: "f8")), DataFrame),
503+
DataFrame,
504+
)
492505

493506

494507
def test_read_table_iterator():

tests/test_pandas.py

+14
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from numpy import typing as npt
1212
import pandas as pd
1313
from pandas.api.extensions import ExtensionArray
14+
import pytest
1415
from typing_extensions import assert_type
1516

1617
from tests import check
@@ -261,3 +262,16 @@ def test_crosstab() -> None:
261262
),
262263
pd.DataFrame,
263264
)
265+
266+
267+
def test_arrow_dtype() -> None:
268+
pytest.importorskip("pyarrow")
269+
270+
import pyarrow as pa
271+
272+
check(
273+
assert_type(
274+
pd.ArrowDtype(pa.timestamp("s", tz="America/New_York")), pd.ArrowDtype
275+
),
276+
pd.ArrowDtype,
277+
)

0 commit comments

Comments
 (0)