Skip to content

Commit 08cc8ed

Browse files
committed
ENH: introduce scalar type hierarchy
Try to duck-type NumPy, only without array scalars, so that * dtype('float32').type --> np.float32 and * np.float32(3) --> array(3.0, dtype='float32) IOW, the attempt is to only have zero-dim arrays (we have them anyway) and have them everywhere where NumPy creates a scalar.
1 parent 5ce3c70 commit 08cc8ed

File tree

3 files changed

+148
-16
lines changed

3 files changed

+148
-16
lines changed

torch_np/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ._dtypes import *
2+
from ._scalar_types import *
23
from ._wrapper import *
34
from . import testing
45

torch_np/_dtypes.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import builtins
99
import torch
1010

11+
from . import _scalar_types
12+
1113
# Define analogs of numpy dtypes supported by pytorch.
1214

1315
class dtype:
@@ -22,6 +24,9 @@ def __init__(self, name):
2224
_name = typecode_chars_dict[name]
2325
elif name in dt_aliases_dict:
2426
_name = dt_aliases_dict[name]
27+
# the check must come last, so that 'name' is not a string
28+
elif issubclass(name, _scalar_types.generic):
29+
_name = name.name
2530
else:
2631
raise TypeError(f"data type '{name}' not understood")
2732
self._name = _name
@@ -30,6 +35,10 @@ def __init__(self, name):
3035
def name(self):
3136
return self._name
3237

38+
@property
39+
def type(self):
40+
return _scalar_types._typemap[self._name]
41+
3342
@property
3443
def typecode(self):
3544
return _typecodes_from_dtype_dict[self._name]
@@ -104,21 +113,6 @@ def __repr__(self):
104113
}
105114

106115

107-
float16 = dtype("float16")
108-
float32 = dtype("float32")
109-
float64 = dtype("float64")
110-
complex64 = dtype("complex64")
111-
complex128 = dtype("complex128")
112-
uint8 = dtype("uint8")
113-
int8 = dtype("int8")
114-
int16 = dtype("int16")
115-
int32 = dtype("int32")
116-
int64 = dtype("int64")
117-
bool = dtype("bool")
118-
119-
intp = int64 # XXX
120-
int_ = int64
121-
122116
# Map the torch-suppored subset dtypes to local analogs
123117
# "quantized" types not available in numpy, skip
124118
_dtype_from_torch_dict = {
@@ -217,5 +211,5 @@ def is_integer(dtyp):
217211
########################## end autogenerated part
218212

219213

220-
__all__ = ['dtype_from_torch', 'dtype', 'typecodes'] + dt_names + ['intp', 'int_']
214+
__all__ = ['dtype_from_torch', 'dtype', 'typecodes']
221215

torch_np/_scalar_types.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""Replicate the NumPy scalar type hierarchy
2+
"""
3+
4+
import abc
5+
import torch
6+
7+
class generic(abc.ABC):
8+
@property
9+
@abc.abstractmethod
10+
def name(self):
11+
return self.__class__.__name__
12+
13+
def __new__(self, value):
14+
#
15+
# Yes, a call to np.float32(4) produces a zero-dim array.
16+
#
17+
from . import _dtypes
18+
from . import _ndarray
19+
20+
torch_dtype = _dtypes.torch_dtype_from(self.name)
21+
tensor = torch.as_tensor(value, dtype=torch_dtype)
22+
return _ndarray.ndarray._from_tensor_and_base(tensor, None)
23+
24+
25+
##### these are abstract types
26+
27+
class number(generic):
28+
pass
29+
30+
31+
class integer(generic):
32+
pass
33+
34+
35+
class inexact(generic):
36+
pass
37+
38+
39+
class signedinteger(generic):
40+
pass
41+
42+
43+
class unsignedinteger(generic):
44+
pass
45+
46+
47+
class inexact(generic):
48+
pass
49+
50+
51+
class floating(generic):
52+
pass
53+
54+
55+
class complexfloating(generic):
56+
pass
57+
58+
59+
# ##### concrete types
60+
61+
# signed integers
62+
63+
class int8(signedinteger):
64+
name = 'int8'
65+
66+
67+
class int16(signedinteger):
68+
name = 'int16'
69+
70+
71+
class int32(signedinteger):
72+
name = 'int32'
73+
74+
75+
class int64(signedinteger):
76+
name = 'int64'
77+
78+
79+
# unsigned integers
80+
81+
class uint8(unsignedinteger):
82+
name = 'uint8'
83+
84+
85+
# floating point
86+
87+
class float16(floating):
88+
name = 'float16'
89+
90+
91+
class float32(floating):
92+
name = 'float32'
93+
94+
95+
class float64(floating):
96+
name = 'float64'
97+
98+
99+
class complex64(complexfloating):
100+
name = 'complex64'
101+
102+
103+
class complex128(complexfloating):
104+
name = 'complex128'
105+
106+
107+
class bool_(generic):
108+
name = 'bool'
109+
110+
111+
# name aliases : FIXME (OS, bitness)
112+
intp = int64
113+
int_ = int64
114+
115+
116+
_typemap ={
117+
'int8' : int8,
118+
'int16' : int16,
119+
'int32' : int32,
120+
'int64' : int64,
121+
'uint8' : uint8,
122+
'float16': float16,
123+
'float32': float32,
124+
'float64': float64,
125+
'complex64': complex64,
126+
'complex128': complex128,
127+
'bool': bool_
128+
}
129+
130+
131+
__all__ = list(_typemap.keys())
132+
__all__.remove('bool')
133+
134+
__all__ += ['bool_', 'intp', 'int_']
135+
__all__ += ['generic', 'number',
136+
'signedinteger', 'unsignedinteger',
137+
'inexact', 'floating', 'complexfloating']

0 commit comments

Comments
 (0)