Add filter windows, tensordot #95

Merged: 5 commits, merged on Mar 31, 2023

20 changes: 0 additions & 20 deletions autogen/numpy_api_dump.py
@@ -134,10 +134,6 @@ def binary_repr(num, width=None):
     raise NotImplementedError


-def blackman(M):
-    raise NotImplementedError
-
-
 def block(arrays):
     raise NotImplementedError

@@ -337,14 +333,6 @@ def gradient(f, *varargs, axis=None, edge_order=1):
     raise NotImplementedError


-def hamming(M):
-    raise NotImplementedError
-
-
-def hanning(M):
-    raise NotImplementedError
-
-
 def histogram(a, bins=10, range=None, normed=None, weights=None, density=None):
     raise NotImplementedError

@@ -409,10 +397,6 @@ def ix_(*args):
     raise NotImplementedError


-def kaiser(M, beta):
-    raise NotImplementedError
-
-
 def lexsort(keys, axis=-1):
     raise NotImplementedError

@@ -759,10 +743,6 @@ def take(a, indices, axis=None, out=None, mode="raise"):
     raise NotImplementedError


-def tensordot(a, b, axes=2):
-    raise NotImplementedError
-
-
 def trapz(y, x=None, dx=1.0, axis=-1):
     raise NotImplementedError
2 changes: 1 addition & 1 deletion torch_np/__init__.py
@@ -14,7 +14,7 @@

 inf = float("inf")
 nan = float("nan")
-from math import pi # isort: skip
+from math import pi, e # isort: skip

 False_ = asarray(False, bool_)
 True_ = asarray(True, bool_)
40 changes: 40 additions & 0 deletions torch_np/_funcs.py
@@ -1158,6 +1158,13 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
     return result.item()


+@normalizer
+def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
+    if isinstance(axes, (list, tuple)):
+        axes = [[ax] if isinstance(ax, int) else ax for ax in axes]
+    return torch.tensordot(a, b, dims=axes)
+
+
 @normalizer
 def dot(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None):
     dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
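The new wrapper bridges an API gap: NumPy's `axes` may be an int or a pair whose entries are ints or sequences, while `torch.tensordot` expects the entries of its `dims` pair to be lists, so bare ints inside the pair are wrapped into singleton lists. A standalone sketch of the same normalization, outside the `@normalizer` machinery (`_tensordot_sketch` is a hypothetical name for illustration):

```python
import torch

# Sketch of the normalization above: NumPy allows axes=(1, 0),
# torch.tensordot wants dims=([1], [0]).
def _tensordot_sketch(a, b, axes=2):
    if isinstance(axes, (list, tuple)):
        axes = [[ax] if isinstance(ax, int) else ax for ax in axes]
    return torch.tensordot(a, b, dims=axes)

a = torch.randn(3, 4)
b = torch.randn(4, 5)

# A NumPy-style pair of ints is rewritten to a pair of singleton lists.
assert _tensordot_sketch(a, b, axes=(1, 0)).shape == (3, 5)

# Pairs of sequences pass through unchanged.
a2 = torch.randn(3, 2, 2)
b2 = torch.randn(2, 2, 5)
assert _tensordot_sketch(a2, b2, axes=([1, 2], [0, 1])).shape == (3, 5)
```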
@@ -1850,3 +1857,36 @@ def __getitem__(self, item):

 index_exp = IndexExpression(maketuple=True)
 s_ = IndexExpression(maketuple=False)
+
+
+# ### Filter windows ###
+
+
+@normalizer
+def hamming(M):
+    dtype = _dtypes_impl.default_float_dtype
+    return torch.hamming_window(M, periodic=False, dtype=dtype)
+
+
+@normalizer
+def hanning(M):
+    dtype = _dtypes_impl.default_float_dtype
+    return torch.hann_window(M, periodic=False, dtype=dtype)
+
+
+@normalizer
+def kaiser(M, beta):
+    dtype = _dtypes_impl.default_float_dtype
+    return torch.kaiser_window(M, beta=beta, periodic=False, dtype=dtype)
+
+
+@normalizer
+def blackman(M):
+    dtype = _dtypes_impl.default_float_dtype
+    return torch.blackman_window(M, periodic=False, dtype=dtype)
+
+
+@normalizer
+def bartlett(M):
+    dtype = _dtypes_impl.default_float_dtype
+    return torch.bartlett_window(M, periodic=False, dtype=dtype)
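Each NumPy window maps onto the corresponding `torch.*_window` factory with `periodic=False`, because NumPy returns symmetric windows while the torch factories default to periodic (DFT-ready) ones. A minimal cross-check against NumPy, not part of the PR's test suite, assuming float64 to match NumPy's default:

```python
import numpy as np
import torch

# NumPy's hamming(M) is the symmetric window; periodic=False makes the
# torch factory compute the same M-point symmetric window.
M = 11
w_torch = torch.hamming_window(M, periodic=False, dtype=torch.float64)
w_numpy = np.hamming(M)
assert np.allclose(w_torch.numpy(), w_numpy)
```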
7 changes: 6 additions & 1 deletion torch_np/_ndarray.py
@@ -64,8 +64,13 @@ class ndarray:
     def __init__(self, t=None):
         if t is None:
             self.tensor = torch.Tensor()
+        elif isinstance(t, torch.Tensor):
+            self.tensor = t
         else:
-            self.tensor = torch.as_tensor(t)
+            raise ValueError(
+                "ndarray constructor is not recommended; prefer "
+                "either array(...) or zeros/empty(...)"
+            )

     @property
     def shape(self):
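The low-level constructor now only wraps an existing `torch.Tensor` (or builds an empty one); the old catch-all `torch.as_tensor` path is gone, so construction from Python data must go through the public factories. A sketch of the new behavior, assuming `ndarray` is re-exported at the package top level as the test changes below suggest:

```python
import torch
import torch_np as tnp

# Wrapping an existing torch.Tensor is still supported:
a = tnp.ndarray(torch.ones(3))

# Building from Python data now goes through the factories:
b = tnp.array([1, 2, 3])
c = tnp.zeros((3, 0))

# Passing raw Python data to the constructor raises instead:
try:
    tnp.ndarray([1, 2, 3])
except ValueError as exc:
    print(exc)  # "ndarray constructor is not recommended; ..."
```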
12 changes: 9 additions & 3 deletions torch_np/tests/numpy_tests/core/test_numeric.py
@@ -2988,15 +2988,21 @@ def test_shape_mismatch_error_message(self):
         np.broadcast([[1, 2, 3]], [[4], [5]], [6, 7])


-@pytest.mark.xfail(reason="TODO")
 class TestTensordot:

     def test_zero_dimension(self):
         # Test resolution to issue #5663
-        a = np.ndarray((3,0))
-        b = np.ndarray((0,4))
+        a = np.zeros((3,0))
+        b = np.zeros((0,4))
         td = np.tensordot(a, b, (1, 0))
         assert_array_equal(td, np.dot(a, b))

+    @pytest.mark.xfail(reason="no einsum")
+    def test_zero_dimension_einsum(self):
+        # Test resolution to issue #5663
+        a = np.zeros((3,0))
+        b = np.zeros((0,4))
+        td = np.tensordot(a, b, (1, 0))
+        assert_array_equal(td, np.einsum('ij,jk', a, b))
+
     def test_zero_dimensional(self):
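The updated test builds its zero-sized inputs with `zeros` rather than the raw `ndarray` constructor (which now raises, per the `_ndarray.py` change above), and exercises exactly the `(1, 0)` axes pair that the new `tensordot` wrapper normalizes. In plain NumPy terms, what the test checks is:

```python
import numpy as np

# Contracting a (3, 0) array with a (0, 4) array over the zero-length
# axes yields a (3, 4) array of zeros, identical to a matrix product.
a = np.zeros((3, 0))
b = np.zeros((0, 4))
td = np.tensordot(a, b, axes=(1, 0))
assert td.shape == (3, 4)
assert np.array_equal(td, np.dot(a, b))
```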
36 changes: 18 additions & 18 deletions torch_np/tests/numpy_tests/lib/test_function_base.py
@@ -26,14 +26,15 @@

 # FIXME: make from torch_np
 from numpy.lib import (
-    bartlett, blackman,
-    delete, digitize, extract, gradient, hamming, hanning,
-    insert, interp, kaiser, msort, piecewise, place,
+    delete, digitize, extract, gradient,
+    insert, interp, msort, piecewise, place,
     select, setxor1d, trapz, trim_zeros, unwrap, vectorize
 )
 from torch_np._detail._util import normalize_axis_tuple

 from torch_np import corrcoef, cov, i0, angle, sinc, diff, meshgrid, unique
+from torch_np import flipud, hamming, hanning, kaiser, blackman, bartlett


 def get_mat(n):
     data = np.arange(n)

@@ -1701,22 +1702,21 @@ def test_period(self):
         assert sm_discont.dtype == wrap_uneven.dtype


-@pytest.mark.xfail(reason='TODO: implement')
 @pytest.mark.parametrize(
     "dtype", np.typecodes["AllInteger"] + np.typecodes["Float"]
 )
 @pytest.mark.parametrize("M", [0, 1, 10])
 class TestFilterwindows:

     def test_hanning(self, dtype: str, M: int) -> None:
-        scalar = np.array(M, dtype=dtype)[()]
+        scalar = M

         w = hanning(scalar)
-        ref_dtype = np.result_type(scalar.dtype, np.float64)
+        ref_dtype = np.result_type(dtype, np.float64)
         assert w.dtype == ref_dtype

         # check symmetry
-        assert_equal(w, flipud(w))
+        assert_allclose(w, flipud(w), atol=1e-15)

         # check known value
         if scalar < 1:

@@ -1727,14 +1727,14 @@ def test_hanning(self, dtype: str, M: int) -> None:
         assert_almost_equal(np.sum(w, axis=0), 4.500, 4)

     def test_hamming(self, dtype: str, M: int) -> None:
-        scalar = np.array(M, dtype=dtype)[()]
+        scalar = M

         w = hamming(scalar)
-        ref_dtype = np.result_type(scalar.dtype, np.float64)
+        ref_dtype = np.result_type(dtype, np.float64)
         assert w.dtype == ref_dtype

         # check symmetry
-        assert_equal(w, flipud(w))
+        assert_allclose(w, flipud(w), atol=1e-15)

         # check known value
         if scalar < 1:

@@ -1745,14 +1745,14 @@ def test_hamming(self, dtype: str, M: int) -> None:
         assert_almost_equal(np.sum(w, axis=0), 4.9400, 4)

     def test_bartlett(self, dtype: str, M: int) -> None:
-        scalar = np.array(M, dtype=dtype)[()]
+        scalar = M

         w = bartlett(scalar)
-        ref_dtype = np.result_type(scalar.dtype, np.float64)
+        ref_dtype = np.result_type(dtype, np.float64)
         assert w.dtype == ref_dtype

         # check symmetry
-        assert_equal(w, flipud(w))
+        assert_allclose(w, flipud(w), atol=1e-15)

         # check known value
         if scalar < 1:

@@ -1763,14 +1763,14 @@ def test_bartlett(self, dtype: str, M: int) -> None:
         assert_almost_equal(np.sum(w, axis=0), 4.4444, 4)

     def test_blackman(self, dtype: str, M: int) -> None:
-        scalar = np.array(M, dtype=dtype)[()]
+        scalar = M

         w = blackman(scalar)
-        ref_dtype = np.result_type(scalar.dtype, np.float64)
+        ref_dtype = np.result_type(dtype, np.float64)
         assert w.dtype == ref_dtype

         # check symmetry
-        assert_equal(w, flipud(w))
+        assert_allclose(w, flipud(w), atol=1e-15)

         # check known value
         if scalar < 1:

@@ -1781,10 +1781,10 @@ def test_blackman(self, dtype: str, M: int) -> None:
         assert_almost_equal(np.sum(w, axis=0), 3.7800, 4)

     def test_kaiser(self, dtype: str, M: int) -> None:
-        scalar = np.array(M, dtype=dtype)[()]
+        scalar = M

         w = kaiser(scalar, 0)
-        ref_dtype = np.result_type(scalar.dtype, np.float64)
+        ref_dtype = np.result_type(dtype, np.float64)
         assert w.dtype == ref_dtype

         # check symmetry
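The window tests drop the NumPy-scalar `M` (the torch-backed windows take a plain int and always compute in the default float dtype) and relax the symmetry check from exact equality to `assert_allclose` with a tiny `atol`: the torch windows are evaluated pointwise rather than mirrored, so the reversed window can differ from the original by rounding. A sketch of the relaxed check, assuming float64:

```python
import numpy as np
import torch

# w and its reverse can differ by one ulp rather than matching
# bit-for-bit, hence allclose with a small absolute tolerance.
w = torch.hamming_window(10, periodic=False, dtype=torch.float64).numpy()
assert np.allclose(w, w[::-1], atol=1e-15)
```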