Skip to content

Commit 322a4cb

Browse files
committed
ENH: Improve dtypes
1 parent 617529a commit 322a4cb

File tree

3 files changed

+143
-68
lines changed

3 files changed

+143
-68
lines changed

pandas-stubs/core/arrays/string_.pyi

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1+
from typing import Literal
2+
3+
import numpy as np
4+
import pandas as pd
15
from pandas.core.arrays import PandasArray
26

37
from pandas._typing import type_t
48

59
from pandas.core.dtypes.base import ExtensionDtype
610

711
class StringDtype(ExtensionDtype):
8-
name: str = ...
9-
na_value = ...
12+
def __init__(self, storage: Literal["python", "pyarrow"] | None) -> None: ...
1013
@property
1114
def type(self) -> type_t: ...
12-
@classmethod
13-
def construct_array_type(cls) -> type_t[StringArray]: ...
1415
def __from_arrow__(self, array): ...
1516

1617
class StringArray(PandasArray):

pandas-stubs/core/dtypes/dtypes.pyi

+20-64
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
from typing import (
2-
Any,
3-
Sequence,
4-
)
1+
import datetime as dt
2+
from typing import Any
53

4+
import numpy as np
65
from pandas.core.indexes.base import Index
6+
from pandas.core.series import Series
77

8-
from pandas._libs.tslibs import ( # , timezones as timezones
9-
Period as Period,
10-
Timestamp,
8+
from pandas._libs.tslibs import BaseOffset
9+
from pandas._typing import (
10+
Ordered,
11+
npt,
1112
)
12-
from pandas._typing import Ordered
1313

1414
from .base import ExtensionDtype as ExtensionDtype
1515

@@ -32,94 +32,50 @@ class PandasExtensionDtype(ExtensionDtype):
3232
@classmethod
3333
def reset_cache(cls) -> None: ...
3434

35-
class CategoricalDtypeType(type): ...
36-
3735
class CategoricalDtype(PandasExtensionDtype, ExtensionDtype):
38-
name: _str = ...
39-
type: type[CategoricalDtypeType] = ...
40-
kind: _str = ...
41-
str: _str = ...
42-
base = ...
4336
def __init__(
44-
self, categories: Sequence[Any] | None = ..., ordered: Ordered = ...
37+
self,
38+
categories: Series | Index | list[Any] | None = ...,
39+
ordered: Ordered = ...,
4540
) -> None: ...
46-
@classmethod
47-
def construct_from_string(cls, string: _str) -> CategoricalDtype: ...
4841
def __hash__(self) -> int: ...
4942
def __eq__(self, other) -> bool: ...
50-
@classmethod
51-
def construct_array_type(cls): ...
52-
@staticmethod
53-
def validate_ordered(ordered: Ordered) -> None: ...
54-
@staticmethod
55-
def validate_categories(categories, fastpath: bool = ...): ...
56-
def update_dtype(self, dtype: _str | CategoricalDtype) -> CategoricalDtype: ...
5743
@property
5844
def categories(self) -> Index: ...
5945
@property
6046
def ordered(self) -> Ordered: ...
6147

6248
class DatetimeTZDtype(PandasExtensionDtype):
63-
type: type[Timestamp] = ...
64-
kind: _str = ...
65-
str: _str = ...
66-
num: int = ...
67-
base = ...
68-
na_value = ...
69-
def __init__(self, unit: _str = ..., tz=...) -> None: ...
49+
def __init__(
50+
self, unit: _str = ..., tz: str | int | dt.tzinfo | None = ...
51+
) -> None: ...
7052
@property
7153
def unit(self): ...
7254
@property
7355
def tz(self): ...
74-
@classmethod
75-
def construct_array_type(cls): ...
76-
@classmethod
77-
def construct_from_string(cls, string: _str): ...
7856
@property
7957
def name(self) -> _str: ...
8058
def __hash__(self) -> int: ...
8159
def __eq__(self, other) -> bool: ...
8260

8361
class PeriodDtype(PandasExtensionDtype):
84-
type: type[Period] = ...
85-
kind: _str = ...
86-
str: _str = ...
87-
base = ...
88-
num: int = ...
89-
def __new__(cls, freq=...): ...
62+
def __new__(cls, freq: str | BaseOffset = ...): ...
63+
def __hash__(self) -> int: ...
64+
def __eq__(self, other) -> bool: ...
9065
@property
9166
def freq(self): ...
92-
@classmethod
93-
def construct_from_string(cls, string: _str): ...
9467
@property
9568
def name(self) -> _str: ...
9669
@property
9770
def na_value(self): ...
98-
def __hash__(self) -> int: ...
99-
def __eq__(self, other) -> bool: ...
100-
@classmethod
101-
def is_dtype(cls, dtype) -> bool: ...
102-
@classmethod
103-
def construct_array_type(cls): ...
10471
def __from_arrow__(self, array): ...
10572

10673
class IntervalDtype(PandasExtensionDtype):
107-
name: _str = ...
108-
kind: _str = ...
109-
str: _str = ...
110-
base = ...
111-
num: int = ...
112-
def __new__(cls, subtype=...): ...
74+
def __new__(cls, subtype: str | npt.DTypeLike | None = ...): ...
75+
def __hash__(self) -> int: ...
76+
def __eq__(self, other) -> bool: ...
11377
@property
11478
def subtype(self): ...
115-
@classmethod
116-
def construct_array_type(cls): ...
117-
@classmethod
118-
def construct_from_string(cls, string: _str): ...
11979
@property
12080
def type(self): ...
121-
def __hash__(self) -> int: ...
122-
def __eq__(self, other) -> bool: ...
123-
@classmethod
124-
def is_dtype(cls, dtype) -> bool: ...
12581
def __from_arrow__(self, array): ...

tests/test_dtypes.py

+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from datetime import (
2+
timedelta,
3+
timezone,
4+
)
5+
6+
import numpy as np
7+
import pandas as pd
8+
import pyarrow as pa
9+
from typing_extensions import assert_type
10+
11+
from tests import check
12+
13+
from pandas.tseries.offsets import (
14+
BusinessDay,
15+
CustomBusinessDay,
16+
Day,
17+
)
18+
19+
20+
def test_datetimetz_dtype() -> None:
21+
check(
22+
assert_type(pd.DatetimeTZDtype(unit="ns", tz="UTC"), pd.DatetimeTZDtype),
23+
pd.DatetimeTZDtype,
24+
)
25+
check(
26+
assert_type(
27+
pd.DatetimeTZDtype(unit="ns", tz=timezone(timedelta(hours=1))),
28+
pd.DatetimeTZDtype,
29+
),
30+
pd.DatetimeTZDtype,
31+
)
32+
33+
34+
def test_period_dtype() -> None:
35+
check(assert_type(pd.PeriodDtype(freq="D"), pd.PeriodDtype), pd.PeriodDtype)
36+
check(assert_type(pd.PeriodDtype(freq=Day()), pd.PeriodDtype), pd.PeriodDtype)
37+
check(
38+
assert_type(pd.PeriodDtype(freq=BusinessDay()), pd.PeriodDtype), pd.PeriodDtype
39+
)
40+
check(
41+
assert_type(pd.PeriodDtype(freq=CustomBusinessDay()), pd.PeriodDtype),
42+
pd.PeriodDtype,
43+
)
44+
45+
46+
def test_interval_dtype() -> None:
47+
check(
48+
assert_type(
49+
pd.Interval(pd.Timestamp("2017-01-01"), pd.Timestamp("2017-01-02")),
50+
"pd.Interval[pd.Timestamp]",
51+
),
52+
pd.Interval,
53+
)
54+
check(
55+
assert_type(pd.Interval(1, 2, closed="left"), "pd.Interval[int]"), pd.Interval
56+
)
57+
check(
58+
assert_type(pd.Interval(1.0, 2.5, closed="right"), "pd.Interval[float]"),
59+
pd.Interval,
60+
)
61+
check(
62+
assert_type(pd.Interval(1.0, 2.5, closed="both"), "pd.Interval[float]"),
63+
pd.Interval,
64+
)
65+
check(
66+
assert_type(
67+
pd.Interval(
68+
pd.Timedelta("1 day"), pd.Timedelta("2 days"), closed="neither"
69+
),
70+
"pd.Interval[pd.Timedelta]",
71+
),
72+
pd.Interval,
73+
)
74+
75+
76+
def test_int64_dtype() -> None:
77+
check(assert_type(pd.Int64Dtype(), pd.Int64Dtype), pd.Int64Dtype)
78+
79+
80+
def test_categorical_dtype() -> None:
81+
check(
82+
assert_type(
83+
pd.CategoricalDtype(categories=["a", "b", "c"], ordered=True),
84+
pd.CategoricalDtype,
85+
),
86+
pd.CategoricalDtype,
87+
)
88+
check(
89+
assert_type(pd.CategoricalDtype(categories=[1, 2, 3]), pd.CategoricalDtype),
90+
pd.CategoricalDtype,
91+
)
92+
93+
94+
def test_sparse_dtype() -> None:
95+
check(assert_type(pd.SparseDtype(str), pd.SparseDtype), pd.SparseDtype)
96+
check(assert_type(pd.SparseDtype(complex), pd.SparseDtype), pd.SparseDtype)
97+
check(assert_type(pd.SparseDtype(bool), pd.SparseDtype), pd.SparseDtype)
98+
check(assert_type(pd.SparseDtype(int), pd.SparseDtype), pd.SparseDtype)
99+
check(assert_type(pd.SparseDtype(np.int64), pd.SparseDtype), pd.SparseDtype)
100+
check(assert_type(pd.SparseDtype(str), pd.SparseDtype), pd.SparseDtype)
101+
check(assert_type(pd.SparseDtype(float), pd.SparseDtype), pd.SparseDtype)
102+
check(assert_type(pd.SparseDtype(np.datetime64), pd.SparseDtype), pd.SparseDtype)
103+
check(assert_type(pd.SparseDtype(np.timedelta64), pd.SparseDtype), pd.SparseDtype)
104+
check(assert_type(pd.SparseDtype("datetime64"), pd.SparseDtype), pd.SparseDtype)
105+
check(assert_type(pd.SparseDtype(), pd.SparseDtype), pd.SparseDtype)
106+
107+
108+
def test_string_dtype() -> None:
109+
check(assert_type(pd.StringDtype("pyarrow"), pd.StringDtype), pd.StringDtype)
110+
check(assert_type(pd.StringDtype("python"), pd.StringDtype), pd.StringDtype)
111+
112+
113+
def test_boolean_dtype() -> None:
114+
check(assert_type(pd.BooleanDtype(), pd.BooleanDtype), pd.BooleanDtype)
115+
116+
117+
def test_arrow_dtype() -> None:
118+
check(assert_type(pd.ArrowDtype(pa.int64()), pd.ArrowDtype), pd.ArrowDtype)

0 commit comments

Comments
 (0)