Skip to content

Commit d160883

Browse files
authored
Merge pull request #8 from Quansight-Labs/reductions
Implement reduction operations
2 parents a68d8d4 + acb032c commit d160883

13 files changed

+1472
-134
lines changed

autogen/numpy_api_dump.py

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -58,25 +58,13 @@ def seterrobj(errobj):
5858

5959

6060

61-
def alen(a):
62-
raise NotImplementedError
63-
6461

6562
def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
6663
raise NotImplementedError
6764

6865
def alltrue(*args, **kwargs):
6966
raise NotImplementedError
7067

71-
def amax(a, axis=None, out=None, keepdims=NoValue, initial=NoValue, where=NoValue):
72-
raise NotImplementedError
73-
74-
def amin(a, axis=None, out=None, keepdims=NoValue, initial=NoValue, where=NoValue):
75-
raise NotImplementedError
76-
77-
78-
79-
8068
def append(arr, values, axis=None):
8169
raise NotImplementedError
8270

@@ -95,10 +83,6 @@ def argpartition(a, kth, axis=-1, kind='introselect', order=None):
9583
def argsort(a, axis=-1, kind=None, order=None):
9684
raise NotImplementedError
9785

98-
def argwhere(a):
99-
raise NotImplementedError
100-
101-
10286
def array2string(a, max_line_width=None, precision=None, suppress_small=None, separator=' ', prefix='', style=NoValue, formatter=None, threshold=None, edgeitems=None, sign=None, floatmode=None, suffix='', *, legacy=None):
10387
raise NotImplementedError
10488

@@ -126,9 +110,6 @@ def asfarray(a, dtype='numpy.float64'):
126110
def asmatrix(data, dtype=None):
127111
raise NotImplementedError
128112

129-
def asscalar(a):
130-
raise NotImplementedError
131-
132113

133114
def average(a, axis=None, weights=None, returned=False, *, keepdims=NoValue):
134115
raise NotImplementedError
@@ -189,8 +170,6 @@ def copyto(dst, src, casting='same_kind', where=True):
189170
def correlate(a, v, mode='valid'):
190171
raise NotImplementedError
191172

192-
def count_nonzero(a, axis=None, *, keepdims=False):
193-
raise NotImplementedError
194173

195174
def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None, *, dtype=None):
196175
raise NotImplementedError
@@ -277,9 +256,6 @@ def find_common_type(array_types, scalar_types):
277256
def fix(x, out=None):
278257
raise NotImplementedError
279258

280-
def flatnonzero(a):
281-
raise NotImplementedError
282-
283259
def flip(m, axis=None):
284260
raise NotImplementedError
285261

@@ -444,8 +420,7 @@ def mask_indices(n, mask_func, k=0):
444420
def asmatrix(data, dtype=None):
445421
raise NotImplementedError
446422

447-
def amax(a, axis=None, out=None, keepdims=NoValue, initial=NoValue, where=NoValue):
448-
raise NotImplementedError
423+
449424

450425
def maximum_sctype(t):
451426
raise NotImplementedError
@@ -461,8 +436,7 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
461436
def meshgrid(*xi, copy=True, sparse=False, indexing='xy'):
462437
raise NotImplementedError
463438

464-
def amin(a, axis=None, out=None, keepdims=NoValue, initial=NoValue, where=NoValue):
465-
raise NotImplementedError
439+
466440

467441
def min_scalar_type(a, /):
468442
raise NotImplementedError

torch_np/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,7 @@
66
from ._binary_ufuncs import *
77
from ._ndarray import can_cast, result_type, newaxis
88
from ._util import AxisError
9+
10+
11+
inf = float('inf')
12+
nan = float('nan')

torch_np/_dtypes.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ def __init__(self, name):
1818
_name = python_types_dict[name]
1919
elif name in dt_names:
2020
_name = name
21+
elif name in typecode_chars_dict:
22+
_name = typecode_chars_dict[name]
2123
elif name in dt_aliases_dict:
2224
_name = dt_aliases_dict[name]
2325
else:
@@ -28,6 +30,10 @@ def __init__(self, name):
2830
def name(self):
2931
return self._name
3032

33+
@property
34+
def typecode(self):
35+
return _typecodes_from_dtype_dict[self._name]
36+
3137
def __eq__(self, other):
3238
if isinstance(other, dtype):
3339
return self._name == other.name
@@ -62,6 +68,7 @@ def __repr__(self):
6268
'f8' : 'float64',
6369
'c8' : 'complex64',
6470
'c16': 'complex128',
71+
'?' : 'bool',
6572
}
6673

6774

@@ -72,6 +79,31 @@ def __repr__(self):
7279
}
7380

7481

82+
typecode_chars_dict = {
83+
'e': 'float16',
84+
'f': 'float32',
85+
'd': 'float64',
86+
'F': 'complex64',
87+
'D': 'complex128',
88+
'B': 'uint8',
89+
'b': 'int8',
90+
'h': 'int16',
91+
'i': 'int32',
92+
'l': 'int64',
93+
'?': 'bool'
94+
}
95+
96+
# reverse mapping
97+
_typecodes_from_dtype_dict = {typecode_chars_dict[key]: key
98+
for key in typecode_chars_dict}
99+
100+
101+
typecodes = {'All': 'efdFDBbhil?',
102+
'AllFloat': 'efdFD',
103+
'AllInteger': 'Bbhil',
104+
}
105+
106+
75107
float16 = dtype("float16")
76108
float32 = dtype("float32")
77109
float64 = dtype("float64")
@@ -84,6 +116,8 @@ def __repr__(self):
84116
int64 = dtype("int64")
85117
bool = dtype("bool")
86118

119+
intp = int64 # XXX
120+
int_ = int64
87121

88122
# Map the torch-suppored subset dtypes to local analogs
89123
# "quantized" types not available in numpy, skip
@@ -132,6 +166,22 @@ def torch_dtype_from(dtyp):
132166
raise TypeError
133167

134168

169+
def default_int_type():
170+
return dtype('int64')
171+
172+
173+
def default_float_type():
174+
return dtype('float64')
175+
176+
177+
def is_floating(dtyp):
178+
dtyp = dtype(dtyp)
179+
return dtyp.typecode in typecodes['AllFloat']
180+
181+
def is_integer(dtyp):
182+
dtyp = dtype(dtyp)
183+
return dtyp.typecode in typecodes['AllInteger']
184+
135185

136186
# The casting below is defined *with dtypes only*, so no value-based casting!
137187

@@ -167,5 +217,5 @@ def torch_dtype_from(dtyp):
167217
########################## end autogenerated part
168218

169219

170-
__all__ = ['dtype_from_torch', 'dtype'] + dt_names
220+
__all__ = ['dtype_from_torch', 'dtype', 'typecodes'] + dt_names + ['intp', 'int_']
171221

torch_np/_helpers.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import operator
2+
13
import torch
24
from . import _dtypes
35
from ._ndarray import can_cast, ndarray, asarray
6+
from . import _util
47

58
def cast_and_broadcast(arrays, out, casting):
69
"""Cast dtypes of arrays to out.dtype and broadcast if needed.
@@ -82,14 +85,43 @@ def axis_none_ravel(*arrays, axis=None):
8285
return arrays, axis
8386

8487

85-
def result_or_out(result, out=None):
88+
def result_or_out(result_tensor, out_array=None):
8689
"""A helper for returns with out= argument."""
87-
if out is not None:
88-
if result.shape != out.shape:
90+
if out_array is not None:
91+
if result_tensor.shape != out_array.shape:
8992
raise ValueError
90-
out_tensor = out.get()
91-
out_tensor.copy_(result)
92-
return out
93+
out_tensor = out_array.get()
94+
out_tensor.copy_(result_tensor)
95+
return out_array
96+
else:
97+
return asarray(result_tensor)
98+
99+
100+
def apply_keepdims(tensor, axis, ndim):
101+
if axis is None:
102+
# tensor was a scalar
103+
tensor = torch.full((1,)*ndim, fill_value=tensor)
93104
else:
94-
return asarray(result)
105+
shape = _util.expand_shape(tensor.shape, axis)
106+
tensor = tensor.reshape(shape)
107+
return tensor
95108

109+
110+
def standardize_axis_arg(axis, ndim):
111+
"""Return axis as either None or a tuple of normalized axes."""
112+
if isinstance(axis, ndarray):
113+
axis = operator.index(axis)
114+
115+
if axis is not None:
116+
if not isinstance(axis, (list, tuple)):
117+
axis = (axis,)
118+
axis = _util.normalize_axis_tuple(axis, ndim)
119+
return axis
120+
121+
122+
def allow_only_single_axis(axis):
123+
if axis is None:
124+
return axis
125+
if len(axis) != 1:
126+
raise NotImplementedError("does not handle tuple axis")
127+
return axis[0]

0 commit comments

Comments
 (0)