Skip to content

Commit bc75eb2

Browse files
authored
Merge pull request #12 from Quansight-Labs/dtypes2
dtypes and arrays scalars
2 parents 5ce3c70 + 2830ada commit bc75eb2

19 files changed

+491
-653
lines changed

.gitignore

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
__pycache__/*
2-
autogen/__pycache__
3-
torch_np/__pycache__/*
4-
torch_np/tests/__pycache__/*
5-
torch_np/tests/numpy_tests/core/__pycache__/*
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
64
.coverage
75

torch_np/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from ._dtypes import *
2+
from ._scalar_types import *
23
from ._wrapper import *
3-
from . import testing
4+
#from . import testing
45

56
from ._unary_ufuncs import *
67
from ._binary_ufuncs import *
78
from ._ndarray import can_cast, result_type, newaxis
89
from ._util import AxisError
9-
10+
from ._getlimits import iinfo, finfo
11+
from ._getlimits import errstate
1012

1113
inf = float('inf')
1214
nan = float('nan')
15+

torch_np/_dtypes.py

Lines changed: 62 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,20 @@
88
import builtins
99
import torch
1010

11+
from . import _scalar_types
12+
13+
14+
__all__ = ['dtype_from_torch', 'dtype', 'typecodes', 'issubdtype']
15+
16+
1117
# Define analogs of numpy dtypes supported by pytorch.
1218

1319
class dtype:
14-
def __init__(self, name):
20+
def __init__(self, name, /):
1521
if isinstance(name, dtype):
1622
_name = name.name
23+
elif hasattr(name, 'dtype'):
24+
_name = name.dtype.name
1725
elif name in python_types_dict:
1826
_name = python_types_dict[name]
1927
elif name in dt_names:
@@ -22,6 +30,9 @@ def __init__(self, name):
2230
_name = typecode_chars_dict[name]
2331
elif name in dt_aliases_dict:
2432
_name = dt_aliases_dict[name]
33+
# the check must come last, so that 'name' is not a string
34+
elif issubclass(name, _scalar_types.generic):
35+
_name = name.name
2536
else:
2637
raise TypeError(f"data type '{name}' not understood")
2738
self._name = _name
@@ -30,6 +41,10 @@ def __init__(self, name):
3041
def name(self):
3142
return self._name
3243

44+
@property
45+
def type(self):
46+
return _scalar_types._typemap[self._name]
47+
3348
@property
3449
def typecode(self):
3550
return _typecodes_from_dtype_dict[self._name]
@@ -38,14 +53,31 @@ def __eq__(self, other):
3853
if isinstance(other, dtype):
3954
return self._name == other.name
4055
else:
41-
other_instance = dtype(other)
56+
try:
57+
other_instance = dtype(other)
58+
except TypeError:
59+
return False
4260
return self._name == other_instance.name
4361

62+
def __hash__(self):
63+
return hash(self._name)
64+
4465
def __repr__(self):
4566
return f'dtype("{self.name}")'
4667

4768
__str__ = __repr__
4869

70+
def itemsize(self):
71+
elem = self.type(1)
72+
return elem.get().element_size()
73+
74+
def __getstate__(self):
75+
return self._name
76+
77+
def __setstate__(self, value):
78+
self._name = value
79+
80+
4981

5082
dt_names = ['float16', 'float32', 'float64',
5183
'complex64', 'complex128',
@@ -58,6 +90,7 @@ def __repr__(self):
5890

5991

6092
dt_aliases_dict = {
93+
'u1' : 'uint8',
6194
'i1' : 'int8',
6295
'i2' : 'int16',
6396
'i4' : 'int32',
@@ -75,7 +108,12 @@ def __repr__(self):
75108
python_types_dict = {
76109
int: 'int64',
77110
float: 'float64',
78-
builtins.bool: 'bool'
111+
complex: 'complex128',
112+
builtins.bool: 'bool',
113+
# also allow stringified names of python types
114+
int.__name__ : 'int64',
115+
float.__name__ : 'float64',
116+
complex.__name__: 'complex128',
79117
}
80118

81119

@@ -101,24 +139,13 @@ def __repr__(self):
101139
typecodes = {'All': 'efdFDBbhil?',
102140
'AllFloat': 'efdFD',
103141
'AllInteger': 'Bbhil',
142+
'Integer': 'bhil',
143+
'UnsignedInteger': 'B',
144+
'Float': 'efd',
145+
'Complex': 'FD',
104146
}
105147

106148

107-
float16 = dtype("float16")
108-
float32 = dtype("float32")
109-
float64 = dtype("float64")
110-
complex64 = dtype("complex64")
111-
complex128 = dtype("complex128")
112-
uint8 = dtype("uint8")
113-
int8 = dtype("int8")
114-
int16 = dtype("int16")
115-
int32 = dtype("int32")
116-
int64 = dtype("int64")
117-
bool = dtype("bool")
118-
119-
intp = int64 # XXX
120-
int_ = int64
121-
122149
# Map the torch-suppored subset dtypes to local analogs
123150
# "quantized" types not available in numpy, skip
124151
_dtype_from_torch_dict = {
@@ -183,6 +210,23 @@ def is_integer(dtyp):
183210
return dtyp.typecode in typecodes['AllInteger']
184211

185212

213+
214+
def issubclass_(arg, klass):
215+
try:
216+
return issubclass(arg, klass)
217+
except TypeError:
218+
return False
219+
220+
221+
def issubdtype(arg1, arg2):
222+
# cf https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numerictypes.py#L356-L420
223+
if not issubclass_(arg1, _scalar_types.generic):
224+
arg1 = dtype(arg1).type
225+
if not issubclass_(arg2, _scalar_types.generic):
226+
arg2 = dtype(arg2).type
227+
return issubclass(arg1, arg2)
228+
229+
186230
# The casting below is defined *with dtypes only*, so no value-based casting!
187231

188232
# These two dicts are autogenerated with autogen/gen_dtypes.py,
@@ -216,6 +260,3 @@ def is_integer(dtyp):
216260

217261
########################## end autogenerated part
218262

219-
220-
__all__ = ['dtype_from_torch', 'dtype', 'typecodes'] + dt_names + ['intp', 'int_']
221-

torch_np/_getlimits.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import torch
2+
from . import _dtypes
3+
4+
def finfo(dtyp):
5+
torch_dtype = _dtypes.torch_dtype_from(dtyp)
6+
return torch.finfo(torch_dtype)
7+
8+
9+
def iinfo(dtyp):
10+
torch_dtype = _dtypes.torch_dtype_from(dtyp)
11+
return torch.iinfo(torch_dtype)
12+
13+
14+
import contextlib
15+
16+
# FIXME: this is only a stub
17+
@contextlib.contextmanager
18+
def errstate(*args, **kwds):
19+
yield

torch_np/_helpers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,9 @@ def allow_only_single_axis(axis):
125125
if len(axis) != 1:
126126
raise NotImplementedError("does not handle tuple axis")
127127
return axis[0]
128+
129+
130+
def to_tensors(*inputs):
131+
"""Convert all ndarrays from `inputs` to tensors."""
132+
return tuple([value.get() if isinstance(value, ndarray) else value
133+
for value in inputs])

torch_np/_ndarray.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,16 @@ def __neq__(self, other):
110110
def __gt__(self, other):
111111
return asarray(self._tensor > asarray(other).get())
112112

113+
def __lt__(self, other):
114+
return asarray(self._tensor < asarray(other).get())
115+
116+
def __ge__(self, other):
117+
return asarray(self._tensor >= asarray(other).get())
118+
119+
def __le__(self, other):
120+
return asarray(self._tensor <= asarray(other).get())
121+
122+
113123
def __bool__(self):
114124
try:
115125
return bool(self._tensor)
@@ -131,6 +141,15 @@ def __hash__(self):
131141
def __float__(self):
132142
return float(self._tensor)
133143

144+
# XXX : are single-element ndarrays scalars?
145+
def is_integer(self):
146+
if self.shape == ():
147+
if _dtypes.is_integer(self.dtype):
148+
return True
149+
return self._tensor.item().is_integer()
150+
else:
151+
return False
152+
134153

135154
### sequence ###
136155
def __len__(self):
@@ -162,6 +181,15 @@ def __truediv__(self, other):
162181
other_tensor = asarray(other).get()
163182
return asarray(self._tensor.__truediv__(other_tensor))
164183

184+
def __or__(self, other):
185+
other_tensor = asarray(other).get()
186+
return asarray(self._tensor.__or__(other_tensor))
187+
188+
def __ior__(self, other):
189+
other_tensor = asarray(other).get()
190+
return asarray(self._tensor.__ior__(other_tensor))
191+
192+
165193
def __invert__(self):
166194
return asarray(self._tensor.__invert__())
167195

@@ -307,7 +335,8 @@ def sum(self, axis=None, dtype=None, out=None, keepdims=NoValue,
307335

308336
### indexing ###
309337
def __getitem__(self, *args, **kwds):
310-
return ndarray._from_tensor_and_base(self._tensor.__getitem__(*args, **kwds), self)
338+
t_args = _helpers.to_tensors(*args)
339+
return ndarray._from_tensor_and_base(self._tensor.__getitem__(*t_args, **kwds), self)
311340

312341
def __setitem__(self, index, value):
313342
value = asarray(value).get()
@@ -320,6 +349,8 @@ def asarray(a, dtype=None, order=None, *, like=None):
320349
raise NotImplementedError
321350

322351
if isinstance(a, ndarray):
352+
if dtype is not None and dtype != a.dtype:
353+
a = a.astype(dtype)
323354
return a
324355

325356
if isinstance(a, (list, tuple)):
@@ -356,6 +387,10 @@ def array(object, dtype=None, *, copy=True, order='K', subok=False, ndmin=0,
356387

357388
if isinstance(object, ndarray):
358389
result = object._tensor
390+
391+
if dtype != object.dtype:
392+
torch_dtype = _dtypes.torch_dtype_from(dtype)
393+
result = result.to(torch_dtype)
359394
else:
360395
torch_dtype = _dtypes.torch_dtype_from(dtype)
361396
result = torch.as_tensor(object, dtype=torch_dtype)

0 commit comments

Comments
 (0)