Skip to content

Commit 9df7b54

Browse files
authored
Merge pull request #65 from Quansight-Labs/where_et_al
wrap more functions to complete a minimal viable feature set
2 parents 1ca7255 + d70cab6 commit 9df7b54

13 files changed

+663
-804
lines changed

autogen/numpy_api_dump.py

Lines changed: 0 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -172,10 +172,6 @@ def choose(a, choices, out=None, mode="raise"):
172172
raise NotImplementedError
173173

174174

175-
def clip(a, a_min, a_max, out=None, **kwargs):
176-
raise NotImplementedError
177-
178-
179175
def common_type(*arrays):
180176
raise NotImplementedError
181177

@@ -216,22 +212,6 @@ def deprecate_with_doc(msg):
216212
raise NotImplementedError
217213

218214

219-
def diag_indices(n, ndim=2):
220-
raise NotImplementedError
221-
222-
223-
def diag_indices_from(arr):
224-
raise NotImplementedError
225-
226-
227-
def diagflat(v, k=0):
228-
raise NotImplementedError
229-
230-
231-
def diagonal(a, offset=0, axis1=0, axis2=1):
232-
raise NotImplementedError
233-
234-
235215
def digitize(x, bins, right=False):
236216
raise NotImplementedError
237217

@@ -240,10 +220,6 @@ def disp(mesg, device=None, linefeed=True):
240220
raise NotImplementedError
241221

242222

243-
def dot(a, b, out=None):
244-
raise NotImplementedError
245-
246-
247223
def ediff1d(ary, to_end=None, to_begin=None):
248224
raise NotImplementedError
249225

@@ -260,10 +236,6 @@ def extract(condition, arr):
260236
raise NotImplementedError
261237

262238

263-
def fill_diagonal(a, val, wrap=False):
264-
raise NotImplementedError
265-
266-
267239
def find_common_type(array_types, scalar_types):
268240
raise NotImplementedError
269241

@@ -393,14 +365,6 @@ def in1d(ar1, ar2, assume_unique=False, invert=False):
393365
raise NotImplementedError
394366

395367

396-
def indices(dimensions, dtype=int, sparse=False):
397-
raise NotImplementedError
398-
399-
400-
def inner(a, b, /):
401-
raise NotImplementedError
402-
403-
404368
def insert(arr, obj, values, axis=None):
405369
raise NotImplementedError
406370

@@ -417,10 +381,6 @@ def is_busday(dates, weekmask="1111100", holidays=None, busdaycal=None, out=None
417381
raise NotImplementedError
418382

419383

420-
def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
421-
raise NotImplementedError
422-
423-
424384
def isfortran(a):
425385
raise NotImplementedError
426386

@@ -437,14 +397,6 @@ def issctype(rep):
437397
raise NotImplementedError
438398

439399

440-
def issubclass_(arg1, arg2):
441-
raise NotImplementedError
442-
443-
444-
def issubdtype(arg1, arg2):
445-
raise NotImplementedError
446-
447-
448400
def issubsctype(arg1, arg2):
449401
raise NotImplementedError
450402

@@ -624,10 +576,6 @@ def obj2sctype(rep, default=None):
624576
raise NotImplementedError
625577

626578

627-
def outer(a, b, out=None):
628-
raise NotImplementedError
629-
630-
631579
def packbits(a, /, axis=None, bitorder="big"):
632580
raise NotImplementedError
633581

@@ -696,10 +644,6 @@ def putmask(a, mask, values):
696644
raise NotImplementedError
697645

698646

699-
def ravel(a, order="C"):
700-
raise NotImplementedError
701-
702-
703647
def recfromcsv(fname, **kwargs):
704648
raise NotImplementedError
705649

@@ -754,10 +698,6 @@ def sctype2char(sctype):
754698
raise NotImplementedError
755699

756700

757-
def searchsorted(a, v, side="left", sorter=None):
758-
raise NotImplementedError
759-
760-
761701
def select(condlist, choicelist, default=0):
762702
raise NotImplementedError
763703

@@ -811,22 +751,10 @@ def show():
811751
raise NotImplementedError
812752

813753

814-
def sinc(x):
815-
raise NotImplementedError
816-
817-
818-
def sometrue(*args, **kwargs):
819-
raise NotImplementedError
820-
821-
822754
def sort_complex(a):
823755
raise NotImplementedError
824756

825757

826-
def swapaxes(a, axis1, axis2):
827-
raise NotImplementedError
828-
829-
830758
def take(a, indices, axis=None, out=None, mode="raise"):
831759
raise NotImplementedError
832760

@@ -835,10 +763,6 @@ def tensordot(a, b, axes=2):
835763
raise NotImplementedError
836764

837765

838-
def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None):
839-
raise NotImplementedError
840-
841-
842766
def trapz(y, x=None, dx=1.0, axis=-1):
843767
raise NotImplementedError
844768

@@ -855,18 +779,6 @@ def union1d(ar1, ar2):
855779
raise NotImplementedError
856780

857781

858-
def unique(
859-
ar,
860-
return_index=False,
861-
return_inverse=False,
862-
return_counts=False,
863-
axis=None,
864-
*,
865-
equal_nan=True,
866-
):
867-
raise NotImplementedError
868-
869-
870782
def unpackbits(a, /, axis=None, count=None, bitorder="big"):
871783
raise NotImplementedError
872784

@@ -875,13 +787,5 @@ def unwrap(p, discont=None, axis=-1, *, period=6.283185307179586):
875787
raise NotImplementedError
876788

877789

878-
def vdot(a, b, /):
879-
raise NotImplementedError
880-
881-
882-
def where(condition, x, y, /):
883-
raise NotImplementedError
884-
885-
886790
def who(vardict=None):
887791
raise NotImplementedError

torch_np/_detail/_flips.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ def rot90(m_tensor, k=1, axes=(0, 1)):
2929

3030

3131
def swapaxes(tensor, axis1, axis2):
32+
axis1 = _util.normalize_axis_index(axis1, tensor.ndim)
33+
axis2 = _util.normalize_axis_index(axis2, tensor.ndim)
3234
return torch.swapaxes(tensor, axis1, axis2)
3335

3436

torch_np/_detail/_reductions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,19 @@ def mean(tensor, axis=None, dtype=None, *, where=NoValue):
143143

144144
dtype = _atleast_float(dtype, tensor.dtype)
145145

146+
is_half = dtype == torch.float16
147+
if is_half:
148+
# XXX revisit when the pytorch version has pytorch/pytorch#95166
149+
dtype = torch.float32
150+
146151
if axis is None:
147152
result = tensor.mean(dtype=dtype)
148153
else:
149154
result = tensor.mean(dtype=dtype, dim=axis)
150155

156+
if is_half:
157+
result = result.to(torch.float16)
158+
151159
return result
152160

153161

torch_np/_detail/_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class UFuncTypeError(TypeError, RuntimeError):
3636

3737
def cast_if_needed(tensor, dtype):
3838
# NB: no casting if dtype=None
39-
if tensor.dtype != dtype:
39+
if dtype is not None and tensor.dtype != dtype:
4040
tensor = tensor.to(dtype)
4141
return tensor
4242

0 commit comments

Comments
 (0)