Skip to content

Commit 005852f

Browse files
committed
Remove library-specific stuff from common/_typing.py
1 parent 420c0da commit 005852f

File tree

3 files changed

+92
-39
lines changed

3 files changed

+92
-39
lines changed

array_api_compat/common/_typing.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,20 @@
11
from __future__ import annotations
22

33
__all__ = [
4-
"ndarray",
5-
"Device",
6-
"Dtype",
74
"NestedSequence",
85
"SupportsBufferProtocol",
96
]
107

11-
import sys
128
from typing import (
139
Any,
14-
Literal,
15-
Union,
16-
TYPE_CHECKING,
1710
TypeVar,
1811
Protocol,
1912
)
2013

21-
from numpy import (
22-
ndarray,
23-
dtype,
24-
int8,
25-
int16,
26-
int32,
27-
int64,
28-
uint8,
29-
uint16,
30-
uint32,
31-
uint64,
32-
float32,
33-
float64,
34-
)
35-
3614
_T_co = TypeVar("_T_co", covariant=True)
3715

3816
class NestedSequence(Protocol[_T_co]):
3917
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
4018
def __len__(self, /) -> int: ...
4119

42-
Device = Literal["cpu"]
43-
if TYPE_CHECKING or sys.version_info >= (3, 9):
44-
Dtype = dtype[Union[
45-
int8,
46-
int16,
47-
int32,
48-
int64,
49-
uint8,
50-
uint16,
51-
uint32,
52-
uint64,
53-
float32,
54-
float64,
55-
]]
56-
else:
57-
Dtype = dtype
58-
5920
SupportsBufferProtocol = Any

array_api_compat/cupy/_typing.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from __future__ import annotations
2+
3+
__all__ = [
4+
"ndarray",
5+
"Device",
6+
"Dtype",
7+
]
8+
9+
import sys
10+
from typing import (
11+
Union,
12+
TYPE_CHECKING,
13+
)
14+
15+
from cupy import (
16+
ndarray,
17+
dtype,
18+
int8,
19+
int16,
20+
int32,
21+
int64,
22+
uint8,
23+
uint16,
24+
uint32,
25+
uint64,
26+
float32,
27+
float64,
28+
)
29+
30+
from cupy.cuda.device import Device
31+
32+
if TYPE_CHECKING or sys.version_info >= (3, 9):
33+
Dtype = dtype[Union[
34+
int8,
35+
int16,
36+
int32,
37+
int64,
38+
uint8,
39+
uint16,
40+
uint32,
41+
uint64,
42+
float32,
43+
float64,
44+
]]
45+
else:
46+
Dtype = dtype

array_api_compat/numpy/_typing.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from __future__ import annotations
2+
3+
__all__ = [
4+
"ndarray",
5+
"Device",
6+
"Dtype",
7+
]
8+
9+
import sys
10+
from typing import (
11+
Literal,
12+
Union,
13+
TYPE_CHECKING,
14+
)
15+
16+
from numpy import (
17+
ndarray,
18+
dtype,
19+
int8,
20+
int16,
21+
int32,
22+
int64,
23+
uint8,
24+
uint16,
25+
uint32,
26+
uint64,
27+
float32,
28+
float64,
29+
)
30+
31+
Device = Literal["cpu"]
32+
if TYPE_CHECKING or sys.version_info >= (3, 9):
33+
Dtype = dtype[Union[
34+
int8,
35+
int16,
36+
int32,
37+
int64,
38+
uint8,
39+
uint16,
40+
uint32,
41+
uint64,
42+
float32,
43+
float64,
44+
]]
45+
else:
46+
Dtype = dtype

0 commit comments

Comments
 (0)