Skip to content

Commit 51a8d12

Browse files
committed
ENH: tile, kron
1 parent d16d36a commit 51a8d12

File tree

3 files changed

+18
-21
lines changed

3 files changed

+18
-21
lines changed

autogen/numpy_api_dump.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -481,10 +481,6 @@ def kaiser(M, beta):
481481
raise NotImplementedError
482482

483483

484-
def kron(a, b):
485-
raise NotImplementedError
486-
487-
488484
def lexsort(keys, axis=-1):
489485
raise NotImplementedError
490486

torch_np/_wrapper.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,21 @@ def dsplit(ary, indices_or_sections):
215215
return tuple(maybe_set_base(_, base) for _ in result)
216216

217217

218+
def kron(a, b):
219+
a_tensor, b_tensor = _helpers.to_tensors(a, b)
220+
result = torch.kron(a_tensor, b_tensor)
221+
return asarray(result)
222+
223+
224+
def tile(A, reps):
225+
a_tensor = asarray(A).get()
226+
if isinstance(reps, int):
227+
reps = (reps,)
228+
229+
result = torch.tile(a_tensor, reps)
230+
return asarray(result)
231+
232+
218233
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0):
219234
if axis != 0 or retstep or not endpoint:
220235
raise NotImplementedError

torch_np/tests/numpy_tests/lib/test_shape_base_.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
import sys
33
import pytest
44

5-
from numpy.lib.shape_base import (apply_along_axis, apply_over_axes, kron, tile,
5+
from numpy.lib.shape_base import (apply_along_axis, apply_over_axes,
66
take_along_axis, put_along_axis)
77

88
import torch_np as np
99
from torch_np import (column_stack, dstack, expand_dims, array_split,
10-
split, hsplit, dsplit, vsplit,)
10+
split, hsplit, dsplit, vsplit, kron, tile,)
1111

12-
from torch_np.random import rand
12+
from torch_np.random import rand, randint
1313

1414
from torch_np.testing import assert_array_equal, assert_equal, assert_
1515
from pytest import raises as assert_raises
@@ -635,7 +635,6 @@ def test_squeeze_axis_handling(self):
635635
np.squeeze(np.array([[1], [2], [3]]), axis=0)
636636

637637

638-
@pytest.mark.xfail(reason="TODO: implement")
639638
class TestKron:
640639
def test_basic(self):
641640
# Using 0-dimensional ndarray
@@ -666,16 +665,6 @@ def test_basic(self):
666665
k = np.array([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
667666
assert_array_equal(np.kron(a, b), k)
668667

669-
def test_return_type(self):
670-
class myarray(np.ndarray):
671-
__array_priority__ = 1.0
672-
673-
a = np.ones([2, 2])
674-
ma = myarray(a.shape, a.dtype, a.data)
675-
assert_equal(type(kron(a, a)), np.ndarray)
676-
assert_equal(type(kron(ma, ma)), myarray)
677-
assert_equal(type(kron(a, ma)), myarray)
678-
assert_equal(type(kron(ma, a)), myarray)
679668

680669
@pytest.mark.parametrize(
681670
"shape_a,shape_b", [
@@ -698,7 +687,6 @@ def test_kron_shape(self, shape_a, shape_b):
698687
k.shape, expected_shape), "Unexpected shape from kron"
699688

700689

701-
@pytest.mark.xfail(reason="TODO: implement")
702690
class TestTile:
703691
def test_basic(self):
704692
a = np.array([0, 1, 2])
@@ -726,8 +714,6 @@ def test_empty(self):
726714
assert_equal(d, (3, 2, 0))
727715

728716
def test_kroncompare(self):
729-
from numpy.random import randint
730-
731717
reps = [(2,), (1, 2), (2, 1), (2, 2), (2, 3, 2), (3, 2)]
732718
shape = [(3,), (2, 3), (3, 4, 3), (3, 2, 3), (4, 3, 2, 4), (2, 2)]
733719
for s in shape:

0 commit comments

Comments
 (0)