|
6 | 6 | # Contents of this module ends up in the main namespace via _funcs.py
|
7 | 7 | # where type annotations are used in conjunction with the @normalizer decorator.
|
8 | 8 |
|
9 |
| - |
| 9 | +import operator |
10 | 10 | from typing import Optional, Sequence
|
11 | 11 |
|
12 | 12 | import torch
|
@@ -36,6 +36,14 @@ def copy(a: ArrayLike, order="K", subok: SubokLike = False):
|
36 | 36 | return a.clone()
|
37 | 37 |
|
38 | 38 |
|
| 39 | +def copyto(dst: NDArray, src: ArrayLike, casting="same_kind", where=NoValue): |
| 40 | + if where is not NoValue: |
| 41 | + raise NotImplementedError |
| 42 | + (src,) = _util.typecast_tensors((src,), dst.tensor.dtype, casting=casting) |
| 43 | + dst.tensor.copy_(src) |
| 44 | + |
| 45 | + |
| 46 | + |
39 | 47 | def atleast_1d(*arys: ArrayLike):
|
40 | 48 | res = torch.atleast_1d(*arys)
|
41 | 49 | if isinstance(res, tuple):
|
@@ -1811,3 +1819,49 @@ def common_type(*tensors: ArrayLike):
|
1811 | 1819 | return array_type[1][precision]
|
1812 | 1820 | else:
|
1813 | 1821 | return array_type[0][precision]
|
| 1822 | + |
| 1823 | + |
| 1824 | + |
| 1825 | +# ### histograms ### |
| 1826 | + |
| 1827 | + |
| 1828 | +def histogram( |
| 1829 | + a: ArrayLike, |
| 1830 | + bins: ArrayLike = 10, |
| 1831 | + range=None, |
| 1832 | + normed=None, |
| 1833 | + weights: Optional[ArrayLike] = None, |
| 1834 | + density=None, |
| 1835 | +): |
| 1836 | + if normed is not None: |
| 1837 | + raise ValueError("normed argument is deprecated, use density= instead") |
| 1838 | + |
| 1839 | + is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex) |
| 1840 | + is_w_int = weights is None or not weights.dtype.is_floating_point |
| 1841 | + if is_a_int: |
| 1842 | + a = a.to(float) |
| 1843 | + |
| 1844 | + if weights is not None: |
| 1845 | + weights = _util.cast_if_needed(weights, a.dtype) |
| 1846 | + |
| 1847 | + if isinstance(bins, torch.Tensor): |
| 1848 | + if bins.ndim == 0: |
| 1849 | + # bins was a single int |
| 1850 | + bins = operator.index(bins) |
| 1851 | + else: |
| 1852 | + bins = _util.cast_if_needed(bins, a.dtype) |
| 1853 | + |
| 1854 | + if range is None: |
| 1855 | + h, b = torch.histogram(a, bins, weight=weights, density=bool(density)) |
| 1856 | + else: |
| 1857 | + h, b = torch.histogram( |
| 1858 | + a, bins, range=range, weight=weights, density=bool(density) |
| 1859 | + ) |
| 1860 | + |
| 1861 | + if not density and is_w_int: |
| 1862 | + h = h.to(int) |
| 1863 | + if is_a_int: |
| 1864 | + b = b.to(int) |
| 1865 | + |
| 1866 | + return h, b |
| 1867 | + |
0 commit comments