2
2
import sys
3
3
import pytest
4
4
5
- from numpy .lib .shape_base import (apply_along_axis , apply_over_axes , kron , tile ,
5
+ from numpy .lib .shape_base import (apply_along_axis , apply_over_axes ,
6
6
take_along_axis , put_along_axis )
7
7
8
8
import torch_np as np
9
9
from torch_np import (column_stack , dstack , expand_dims , array_split ,
10
- split , hsplit , dsplit , vsplit ,)
10
+ split , hsplit , dsplit , vsplit , kron , tile , )
11
11
12
- from torch_np .random import rand
12
+ from torch_np .random import rand , randint
13
13
14
14
from torch_np .testing import assert_array_equal , assert_equal , assert_
15
15
from pytest import raises as assert_raises
@@ -635,7 +635,6 @@ def test_squeeze_axis_handling(self):
635
635
np .squeeze (np .array ([[1 ], [2 ], [3 ]]), axis = 0 )
636
636
637
637
638
- @pytest .mark .xfail (reason = "TODO: implement" )
639
638
class TestKron :
640
639
def test_basic (self ):
641
640
# Using 0-dimensional ndarray
@@ -666,16 +665,6 @@ def test_basic(self):
666
665
k = np .array ([[[1 , 2 ], [3 , 4 ]], [[2 , 4 ], [6 , 8 ]]])
667
666
assert_array_equal (np .kron (a , b ), k )
668
667
669
- def test_return_type (self ):
670
- class myarray (np .ndarray ):
671
- __array_priority__ = 1.0
672
-
673
- a = np .ones ([2 , 2 ])
674
- ma = myarray (a .shape , a .dtype , a .data )
675
- assert_equal (type (kron (a , a )), np .ndarray )
676
- assert_equal (type (kron (ma , ma )), myarray )
677
- assert_equal (type (kron (a , ma )), myarray )
678
- assert_equal (type (kron (ma , a )), myarray )
679
668
680
669
@pytest .mark .parametrize (
681
670
"shape_a,shape_b" , [
@@ -698,7 +687,6 @@ def test_kron_shape(self, shape_a, shape_b):
698
687
k .shape , expected_shape ), "Unexpected shape from kron"
699
688
700
689
701
- @pytest .mark .xfail (reason = "TODO: implement" )
702
690
class TestTile :
703
691
def test_basic (self ):
704
692
a = np .array ([0 , 1 , 2 ])
@@ -726,8 +714,6 @@ def test_empty(self):
726
714
assert_equal (d , (3 , 2 , 0 ))
727
715
728
716
def test_kroncompare (self ):
729
- from numpy .random import randint
730
-
731
717
reps = [(2 ,), (1 , 2 ), (2 , 1 ), (2 , 2 ), (2 , 3 , 2 ), (3 , 2 )]
732
718
shape = [(3 ,), (2 , 3 ), (3 , 4 , 3 ), (3 , 2 , 3 ), (4 , 3 , 2 , 4 ), (2 , 2 )]
733
719
for s in shape :
0 commit comments