Skip to content

Commit 69be71f — "MAINT: add prod, var, std to follow sum" (1 parent: d92b127)

File tree

4 files changed: +96 additions, −50 deletions

torch_np/_helpers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,13 @@ def to_tensors(*inputs):
131131
"""Convert all ndarrays from `inputs` to tensors."""
132132
return tuple([value.get() if isinstance(value, ndarray) else value
133133
for value in inputs])
134+
135+
136+
def float_or_default(dtype, self_dtype):
    """Resolve the torch dtype to use for a reduction.

    Falls back to *self_dtype* when *dtype* is None, then promotes any
    integer dtype to the default float type before converting to the
    corresponding torch dtype.

    NOTE(review): the integer→float promotion applies to every caller,
    including sum/prod — confirm that divergence from NumPy's integer
    sum/prod is intended.
    """
    chosen = self_dtype if dtype is None else dtype
    if _dtypes.is_integer(chosen):
        chosen = _dtypes.default_float_type()
    return _dtypes.torch_dtype_from(chosen)

torch_np/_ndarray.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -416,12 +416,7 @@ def mean(self, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoVal
416416
if where is not None:
417417
raise NotImplementedError
418418

419-
if dtype is None:
420-
dtype = self.dtype
421-
if _dtypes.is_integer(dtype):
422-
dtype = _dtypes.default_float_type()
423-
torch_dtype = _dtypes.torch_dtype_from(dtype)
424-
419+
torch_dtype = _helpers.float_or_default(dtype, self.dtype)
425420
if axis is None:
426421
result = self._tensor.mean(dtype=torch_dtype)
427422
else:
@@ -436,19 +431,57 @@ def sum(self, axis=None, dtype=None, out=None, keepdims=NoValue,
436431
if initial is not None or where is not None:
437432
raise NotImplementedError
438433

439-
if dtype is None:
440-
dtype = self.dtype
441-
if _dtypes.is_integer(dtype):
442-
dtype = _dtypes.default_float_type()
443-
torch_dtype = _dtypes.torch_dtype_from(dtype)
444-
434+
torch_dtype = _helpers.float_or_default(dtype, self.dtype)
445435
if axis is None:
446436
result = self._tensor.sum(dtype=torch_dtype)
447437
else:
448438
result = self._tensor.sum(dtype=torch_dtype, dim=axis)
449439

450440
return result
451441

442+
@axis_out_keepdims_wrapper
def prod(self, axis=None, dtype=None, out=None, keepdims=NoValue,
         initial=NoValue, where=NoValue):
    """Product of array elements over the given axis.

    `initial` and `where` are not supported and raise NotImplementedError.
    Only a single axis is accepted (enforced by the helper below), since
    the reduction is delegated to torch, one `dim` at a time.
    """
    if initial is not None or where is not None:
        raise NotImplementedError

    # Reject tuple axes; keep a single int (or None for a full reduction).
    axis = _helpers.allow_only_single_axis(axis)

    torch_dtype = _helpers.float_or_default(dtype, self.dtype)
    if axis is None:
        return self._tensor.prod(dtype=torch_dtype)
    return self._tensor.prod(dtype=torch_dtype, dim=axis)
457+
458+
459+
@axis_out_keepdims_wrapper
def std(self, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *,
        where=NoValue):
    """Standard deviation along `axis`; `ddof` maps to torch's `correction`."""
    if where is not None:
        raise NotImplementedError

    # Cast up front so integer inputs take the floating-point path.
    # XXX: is the explicit cast needed, or does torch promote internally?
    work = self._tensor.to(_helpers.float_or_default(dtype, self.dtype))
    return work.std(dim=axis, correction=ddof)
471+
472+
@axis_out_keepdims_wrapper
def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *,
        where=NoValue):
    """Variance along `axis`; `ddof` maps to torch's `correction`."""
    if where is not None:
        raise NotImplementedError

    # Cast up front so integer inputs take the floating-point path.
    # XXX: is the explicit cast needed, or does torch promote internally?
    work = self._tensor.to(_helpers.float_or_default(dtype, self.dtype))
    return work.var(dim=axis, correction=ddof)
484+
452485

453486
### indexing ###
454487
def __getitem__(self, *args, **kwds):

torch_np/_wrapper.py

Lines changed: 20 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -307,20 +307,6 @@ def identity(n, dtype=None, *, like=None):
307307
###### misc/unordered
308308

309309

310-
#YYY: pattern: initial=...
311-
@asarray_replacer()
312-
def prod(a, axis=None, dtype=None, out=None, keepdims=NoValue,
313-
initial=NoValue, where=NoValue):
314-
if initial is not None or where is not None:
315-
raise NotImplementedError
316-
if axis is None:
317-
if keepdims is not None:
318-
raise NotImplementedError
319-
return torch.prod(a, dtype=dtype)
320-
elif _util.is_sequence(axis):
321-
raise NotImplementedError
322-
return torch.prod(a, dim=axis, dtype=dtype, keepdim=bool(keepdims), out=out)
323-
324310

325311

326312
@asarray_replacer()
@@ -639,13 +625,33 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoValue)
639625
return arr.mean(axis=axis, dtype=dtype, out=out, keepdims=keepdims, where=where)
640626

641627

628+
#YYY: pattern: initial=...
629+
642630
def sum(a, axis=None, dtype=None, out=None, keepdims=NoValue,
        initial=NoValue, where=NoValue):
    """Sum of array elements over a given axis; delegates to ndarray.sum."""
    return asarray(a).sum(axis=axis, dtype=dtype, out=out,
                          keepdims=keepdims, initial=initial, where=where)
647635

648636

637+
def prod(a, axis=None, dtype=None, out=None, keepdims=NoValue,
         initial=NoValue, where=NoValue):
    """Product of array elements over a given axis; delegates to ndarray.prod."""
    return asarray(a).prod(axis=axis, dtype=dtype, out=out,
                           keepdims=keepdims, initial=initial, where=where)
642+
643+
644+
#YYY: pattern : ddof
645+
646+
def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, where=NoValue):
    """Standard deviation along a given axis; delegates to ndarray.std."""
    return asarray(a).std(axis=axis, dtype=dtype, out=out, ddof=ddof,
                          keepdims=keepdims, where=where)
649+
650+
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, where=NoValue):
    """Variance along a given axis; delegates to ndarray.var."""
    return asarray(a).var(axis=axis, dtype=dtype, out=out, ddof=ddof,
                          keepdims=keepdims, where=where)
653+
654+
649655
@asarray_replacer()
650656
def nanmean(a, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoValue):
651657
if where is not None:
@@ -663,30 +669,6 @@ def nanmean(a, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoVal
663669
return result
664670

665671

666-
# YYY: pattern : std, var
667-
@asarray_replacer()
668-
def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, where=NoValue):
669-
if where is not None:
670-
raise NotImplementedError
671-
if dtype is not None:
672-
raise NotImplementedError
673-
if not torch.is_floating_point(a):
674-
a = a * 1.0
675-
return torch.std(a, axis, correction=ddof, keepdim=bool(keepdims), out=out)
676-
677-
678-
@asarray_replacer()
679-
def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=NoValue, *, where=NoValue):
680-
if where is not None:
681-
raise NotImplementedError
682-
if dtype is not None:
683-
raise NotImplementedError
684-
if not torch.is_floating_point(a):
685-
a = a * 1.0
686-
return torch.var(a, axis, correction=ddof, keepdim=bool(keepdims), out=out)
687-
688-
689-
690672
@asarray_replacer()
691673
def argsort(a, axis=-1, kind=None, order=None):
692674
if order is not None:

torch_np/tests/test_reductions.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,3 +593,24 @@ def setup_method(self):
593593
self.allowed_axes = [0, 1, 2, -1, -2,
594594
(0, 1), (1, 0), (0, 1, 2), (1, -1, 0)]
595595

596+
597+
class TestProdGeneric(_GenericReductionsTestMixin, _GenericHasOutTestMixin):
    """Generic reduction-interface tests for np.prod."""

    def setup_method(self):
        self.func = np.prod
        # Tuple axes are deliberately excluded — prod only supports a
        # single axis: (0, 1), (1, 0), (0, 1, 2), (1, -1, 0)
        self.allowed_axes = [0, 1, 2, -1, -2]
602+
603+
604+
class TestStdGeneric(_GenericReductionsTestMixin, _GenericHasOutTestMixin):
    """Generic reduction-interface tests for np.std."""

    def setup_method(self):
        self.func = np.std
        self.allowed_axes = [
            0, 1, 2, -1, -2,
            (0, 1), (1, 0), (0, 1, 2), (1, -1, 0),
        ]
609+
610+
611+
class TestVarGeneric(_GenericReductionsTestMixin, _GenericHasOutTestMixin):
    """Generic reduction-interface tests for np.var."""

    def setup_method(self):
        self.func = np.var
        self.allowed_axes = [
            0, 1, 2, -1, -2,
            (0, 1), (1, 0), (0, 1, 2), (1, -1, 0),
        ]
616+

0 commit comments

Comments
 (0)