diff --git a/autogen/numpy_api_dump.py b/autogen/numpy_api_dump.py index 4ac61f12..450fef1b 100644 --- a/autogen/numpy_api_dump.py +++ b/autogen/numpy_api_dump.py @@ -134,10 +134,6 @@ def binary_repr(num, width=None): raise NotImplementedError -def blackman(M): - raise NotImplementedError - - def block(arrays): raise NotImplementedError @@ -337,14 +333,6 @@ def gradient(f, *varargs, axis=None, edge_order=1): raise NotImplementedError -def hamming(M): - raise NotImplementedError - - -def hanning(M): - raise NotImplementedError - - def histogram(a, bins=10, range=None, normed=None, weights=None, density=None): raise NotImplementedError @@ -409,10 +397,6 @@ def ix_(*args): raise NotImplementedError -def kaiser(M, beta): - raise NotImplementedError - - def lexsort(keys, axis=-1): raise NotImplementedError @@ -759,10 +743,6 @@ def take(a, indices, axis=None, out=None, mode="raise"): raise NotImplementedError -def tensordot(a, b, axes=2): - raise NotImplementedError - - def trapz(y, x=None, dx=1.0, axis=-1): raise NotImplementedError diff --git a/torch_np/__init__.py b/torch_np/__init__.py index 3e377e9a..d2b3f539 100644 --- a/torch_np/__init__.py +++ b/torch_np/__init__.py @@ -14,7 +14,7 @@ inf = float("inf") nan = float("nan") -from math import pi # isort: skip +from math import pi, e # isort: skip False_ = asarray(False, bool_) True_ = asarray(True, bool_) diff --git a/torch_np/_funcs.py b/torch_np/_funcs.py index 5396a310..115d7c64 100644 --- a/torch_np/_funcs.py +++ b/torch_np/_funcs.py @@ -1158,6 +1158,13 @@ def vdot(a: ArrayLike, b: ArrayLike, /): return result.item() +@normalizer +def tensordot(a: ArrayLike, b: ArrayLike, axes=2): + if isinstance(axes, (list, tuple)): + axes = [[ax] if isinstance(ax, int) else ax for ax in axes] + return torch.tensordot(a, b, dims=axes) + + @normalizer def dot(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None): dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype)) @@ -1850,3 +1857,36 @@ def __getitem__(self, item): index_exp = IndexExpression(maketuple=True) s_ = IndexExpression(maketuple=False) + + +# ### Filter windows ### + + +@normalizer +def hamming(M): + dtype = _dtypes_impl.default_float_dtype + return torch.hamming_window(M, periodic=False, dtype=dtype) + + +@normalizer +def hanning(M): + dtype = _dtypes_impl.default_float_dtype + return torch.hann_window(M, periodic=False, dtype=dtype) + + +@normalizer +def kaiser(M, beta): + dtype = _dtypes_impl.default_float_dtype + return torch.kaiser_window(M, beta=beta, periodic=False, dtype=dtype) + + +@normalizer +def blackman(M): + dtype = _dtypes_impl.default_float_dtype + return torch.blackman_window(M, periodic=False, dtype=dtype) + + +@normalizer +def bartlett(M): + dtype = _dtypes_impl.default_float_dtype + return torch.bartlett_window(M, periodic=False, dtype=dtype) diff --git a/torch_np/_ndarray.py b/torch_np/_ndarray.py index c5643705..3a02e277 100644 --- a/torch_np/_ndarray.py +++ b/torch_np/_ndarray.py @@ -64,8 +64,13 @@ class ndarray: def __init__(self, t=None): if t is None: self.tensor = torch.Tensor() + elif isinstance(t, torch.Tensor): + self.tensor = t else: - self.tensor = torch.as_tensor(t) + raise ValueError( + "ndarray constructor is not recommended; prefer" + "either array(...) or zeros/empty(...)" + ) @property def shape(self): diff --git a/torch_np/tests/numpy_tests/core/test_numeric.py b/torch_np/tests/numpy_tests/core/test_numeric.py index 0d1b01ab..0ded85b7 100644 --- a/torch_np/tests/numpy_tests/core/test_numeric.py +++ b/torch_np/tests/numpy_tests/core/test_numeric.py @@ -2988,15 +2988,21 @@ def test_shape_mismatch_error_message(self): np.broadcast([[1, 2, 3]], [[4], [5]], [6, 7]) -@pytest.mark.xfail(reason="TODO") class TestTensordot: def test_zero_dimension(self): # Test resolution to issue #5663 - a = np.ndarray((3,0)) - b = np.ndarray((0,4)) + a = np.zeros((3,0)) + b = np.zeros((0,4)) td = np.tensordot(a, b, (1, 0)) assert_array_equal(td, np.dot(a, b)) + + @pytest.mark.xfail(reason="no einsum") + def test_zero_dimension_einsum(self): + # Test resolution to issue #5663 + a = np.zeros((3,0)) + b = np.zeros((0,4)) + td = np.tensordot(a, b, (1, 0)) assert_array_equal(td, np.einsum('ij,jk', a, b)) def test_zero_dimensional(self): diff --git a/torch_np/tests/numpy_tests/lib/test_function_base.py b/torch_np/tests/numpy_tests/lib/test_function_base.py index 168911a2..0013da39 100644 --- a/torch_np/tests/numpy_tests/lib/test_function_base.py +++ b/torch_np/tests/numpy_tests/lib/test_function_base.py @@ -26,14 +26,15 @@ # FIXME: make from torch_np from numpy.lib import ( - bartlett, blackman, - delete, digitize, extract, gradient, hamming, hanning, - insert, interp, kaiser, msort, piecewise, place, + delete, digitize, extract, gradient, + insert, interp, msort, piecewise, place, select, setxor1d, trapz, trim_zeros, unwrap, vectorize ) from torch_np._detail._util import normalize_axis_tuple from torch_np import corrcoef, cov, i0, angle, sinc, diff, meshgrid, unique +from torch_np import flipud, hamming, hanning, kaiser, blackman, bartlett + def get_mat(n): data = np.arange(n) @@ -1701,7 +1702,6 @@ def test_period(self): assert sm_discont.dtype == wrap_uneven.dtype -@pytest.mark.xfail(reason='TODO: implement') @pytest.mark.parametrize( "dtype", np.typecodes["AllInteger"] + np.typecodes["Float"] ) @@ -1709,14 +1709,14 @@ def test_period(self): class TestFilterwindows: def test_hanning(self, dtype: str, M: int) -> None: - scalar = np.array(M, dtype=dtype)[()] + scalar = M w = hanning(scalar) - ref_dtype = np.result_type(scalar.dtype, np.float64) + ref_dtype = np.result_type(dtype, np.float64) assert w.dtype == ref_dtype # check symmetry - assert_equal(w, flipud(w)) + assert_allclose(w, flipud(w), atol=1e-15) # check known value if scalar < 1: @@ -1727,14 +1727,14 @@ def test_hanning(self, dtype: str, M: int) -> None: assert_almost_equal(np.sum(w, axis=0), 4.500, 4) def test_hamming(self, dtype: str, M: int) -> None: - scalar = np.array(M, dtype=dtype)[()] + scalar = M w = hamming(scalar) - ref_dtype = np.result_type(scalar.dtype, np.float64) + ref_dtype = np.result_type(dtype, np.float64) assert w.dtype == ref_dtype # check symmetry - assert_equal(w, flipud(w)) + assert_allclose(w, flipud(w), atol=1e-15) # check known value if scalar < 1: @@ -1745,14 +1745,14 @@ def test_hamming(self, dtype: str, M: int) -> None: assert_almost_equal(np.sum(w, axis=0), 4.9400, 4) def test_bartlett(self, dtype: str, M: int) -> None: - scalar = np.array(M, dtype=dtype)[()] + scalar = M w = bartlett(scalar) - ref_dtype = np.result_type(scalar.dtype, np.float64) + ref_dtype = np.result_type(dtype, np.float64) assert w.dtype == ref_dtype # check symmetry - assert_equal(w, flipud(w)) + assert_allclose(w, flipud(w), atol=1e-15) # check known value if scalar < 1: @@ -1763,14 +1763,14 @@ def test_bartlett(self, dtype: str, M: int) -> None: assert_almost_equal(np.sum(w, axis=0), 4.4444, 4) def test_blackman(self, dtype: str, M: int) -> None: - scalar = np.array(M, dtype=dtype)[()] + scalar = M w = blackman(scalar) - ref_dtype = np.result_type(scalar.dtype, np.float64) + ref_dtype = np.result_type(dtype, np.float64) assert w.dtype == ref_dtype # check symmetry - assert_equal(w, flipud(w)) + assert_allclose(w, flipud(w), atol=1e-15) # check known value if scalar < 1: @@ -1781,10 +1781,10 @@ def test_blackman(self, dtype: str, M: int) -> None: assert_almost_equal(np.sum(w, axis=0), 3.7800, 4) def test_kaiser(self, dtype: str, M: int) -> None: - scalar = np.array(M, dtype=dtype)[()] + scalar = M w = kaiser(scalar, 0) - ref_dtype = np.result_type(scalar.dtype, np.float64) + ref_dtype = np.result_type(dtype, np.float64) assert w.dtype == ref_dtype # check symmetry