Skip to content

Commit 987c12b

Browse files
authored
Fixup typing of dtypes (#321)
* fixup dtypes typing * noop
1 parent 7be00b6 commit 987c12b

File tree

2 files changed

+54
-36
lines changed

2 files changed

+54
-36
lines changed

spec/API_specification/dataframe_api/typing.py

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,41 +17,6 @@
1717
if TYPE_CHECKING:
1818
from collections.abc import Sequence
1919

20-
from .dtypes import (
21-
Bool,
22-
Date,
23-
Datetime,
24-
Duration,
25-
Float32,
26-
Float64,
27-
Int8,
28-
Int16,
29-
Int32,
30-
Int64,
31-
String,
32-
UInt8,
33-
UInt16,
34-
UInt32,
35-
UInt64,
36-
)
37-
38-
DType = Union[
39-
Bool,
40-
Float64,
41-
Float32,
42-
Int64,
43-
Int32,
44-
Int16,
45-
Int8,
46-
UInt64,
47-
UInt32,
48-
UInt16,
49-
UInt8,
50-
String,
51-
Date,
52-
Datetime,
53-
Duration,
54-
]
5520

5621
# Type alias: Mypy needs Any, but for readability we need to make clear this
5722
# is a Python scalar (i.e., an instance of `bool`, `int`, `float`, `str`, etc.)
@@ -104,7 +69,14 @@ class Datetime:
10469
def __init__( # noqa: ANN204
10570
self,
10671
time_unit: Literal["ms", "us"],
107-
time_zone: str | None,
72+
time_zone: str | None = None,
73+
):
74+
...
75+
76+
class Duration:
77+
def __init__( # noqa: ANN204
78+
self,
79+
time_unit: Literal["ms", "us"],
10880
):
10981
...
11082

@@ -155,6 +127,25 @@ def date(self, year: int, month: int, day: int) -> Scalar:
155127
...
156128

157129

130+
DType = Union[
131+
Namespace.Bool,
132+
Namespace.Float64,
133+
Namespace.Float32,
134+
Namespace.Int64,
135+
Namespace.Int32,
136+
Namespace.Int16,
137+
Namespace.Int8,
138+
Namespace.UInt64,
139+
Namespace.UInt32,
140+
Namespace.UInt16,
141+
Namespace.UInt8,
142+
Namespace.String,
143+
Namespace.Date,
144+
Namespace.Datetime,
145+
Namespace.Duration,
146+
]
147+
148+
158149
class SupportsDataFrameAPI(Protocol):
159150
def __dataframe_consortium_standard__(
160151
self,
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any, Callable
4+
5+
if TYPE_CHECKING:
6+
from dataframe_api.typing import SupportsDataFrameAPI
7+
8+
some_array_function: Callable[[Any], Any]
9+
10+
11+
def main(df_raw: SupportsDataFrameAPI) -> SupportsDataFrameAPI:
12+
df = df_raw.__dataframe_consortium_standard__(api_version="2023-11.beta").persist()
13+
namespace = df.__dataframe_namespace__()
14+
df = df.select(
15+
*[
16+
col_name
17+
for col_name in df.column_names
18+
if isinstance(df.col(col_name).dtype, namespace.Int64)
19+
],
20+
)
21+
arr = df.to_array(namespace.Int64())
22+
arr = some_array_function(arr)
23+
df = namespace.dataframe_from_2d_array(
24+
arr,
25+
schema={"a": df.col("a").dtype, "b": namespace.Float64()},
26+
)
27+
return df.dataframe

0 commit comments

Comments
 (0)