Skip to content

Commit 25bfad9

Browse files
committed
ENH: add copyto, histogram
1 parent d8f1461 commit 25bfad9

File tree

2 files changed

+854
-1
lines changed

2 files changed

+854
-1
lines changed

torch_np/_funcs_impl.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# Contents of this module ends up in the main namespace via _funcs.py
77
# where type annotations are used in conjunction with the @normalizer decorator.
88

9-
9+
import operator
1010
from typing import Optional, Sequence
1111

1212
import torch
@@ -36,6 +36,14 @@ def copy(a: ArrayLike, order="K", subok: SubokLike = False):
3636
return a.clone()
3737

3838

39+
def copyto(dst: NDArray, src: ArrayLike, casting="same_kind", where=NoValue):
40+
if where is not NoValue:
41+
raise NotImplementedError
42+
(src,) = _util.typecast_tensors((src,), dst.tensor.dtype, casting=casting)
43+
dst.tensor.copy_(src)
44+
45+
46+
3947
def atleast_1d(*arys: ArrayLike):
4048
res = torch.atleast_1d(*arys)
4149
if isinstance(res, tuple):
@@ -1811,3 +1819,49 @@ def common_type(*tensors: ArrayLike):
18111819
return array_type[1][precision]
18121820
else:
18131821
return array_type[0][precision]
1822+
1823+
1824+
1825+
# ### histograms ###
1826+
1827+
1828+
def histogram(
1829+
a: ArrayLike,
1830+
bins: ArrayLike = 10,
1831+
range=None,
1832+
normed=None,
1833+
weights: Optional[ArrayLike] = None,
1834+
density=None,
1835+
):
1836+
if normed is not None:
1837+
raise ValueError("normed argument is deprecated, use density= instead")
1838+
1839+
is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex)
1840+
is_w_int = weights is None or not weights.dtype.is_floating_point
1841+
if is_a_int:
1842+
a = a.to(float)
1843+
1844+
if weights is not None:
1845+
weights = _util.cast_if_needed(weights, a.dtype)
1846+
1847+
if isinstance(bins, torch.Tensor):
1848+
if bins.ndim == 0:
1849+
# bins was a single int
1850+
bins = operator.index(bins)
1851+
else:
1852+
bins = _util.cast_if_needed(bins, a.dtype)
1853+
1854+
if range is None:
1855+
h, b = torch.histogram(a, bins, weight=weights, density=bool(density))
1856+
else:
1857+
h, b = torch.histogram(
1858+
a, bins, range=range, weight=weights, density=bool(density)
1859+
)
1860+
1861+
if not density and is_w_int:
1862+
h = h.to(int)
1863+
if is_a_int:
1864+
b = b.to(int)
1865+
1866+
return h, b
1867+

0 commit comments

Comments
 (0)