Skip to content

Commit 598359e

Browse files
committed
fix test file issues
1 parent 792ee57 commit 598359e

File tree

1 file changed

+60
-61
lines changed

1 file changed

+60
-61
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,60 @@
1-
import unittest
2-
from strassen_matrix_multiplication import split_matrix
3-
4-
5-
class TestSplitMatrix(unittest.TestCase):
6-
7-
def test_4x4_matrix(self):
8-
matrix = [
9-
[4, 3, 2, 4],
10-
[2, 3, 1, 1],
11-
[6, 5, 4, 3],
12-
[8, 4, 1, 6]
13-
]
14-
expected = (
15-
[[4, 3], [2, 3]],
16-
[[2, 4], [1, 1]],
17-
[[6, 5], [8, 4]],
18-
[[4, 3], [1, 6]]
19-
)
20-
self.assertEqual(split_matrix(matrix), expected)
21-
22-
def test_8x8_matrix(self):
23-
matrix = [
24-
[4, 3, 2, 4, 4, 3, 2, 4],
25-
[2, 3, 1, 1, 2, 3, 1, 1],
26-
[6, 5, 4, 3, 6, 5, 4, 3],
27-
[8, 4, 1, 6, 8, 4, 1, 6],
28-
[4, 3, 2, 4, 4, 3, 2, 4],
29-
[2, 3, 1, 1, 2, 3, 1, 1],
30-
[6, 5, 4, 3, 6, 5, 4, 3],
31-
[8, 4, 1, 6, 8, 4, 1, 6]
32-
]
33-
expected = (
34-
[[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]],
35-
[[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]],
36-
[[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]],
37-
[[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]]
38-
)
39-
self.assertEqual(split_matrix(matrix), expected)
40-
41-
def test_invalid_odd_matrix(self):
42-
matrix = [
43-
[1, 2, 3],
44-
[4, 5, 6],
45-
[7, 8, 9]
46-
]
47-
with self.assertRaises(Exception):
48-
split_matrix(matrix)
49-
50-
def test_invalid_non_square_matrix(self):
51-
matrix = [
52-
[1, 2, 3, 4],
53-
[5, 6, 7, 8],
54-
[9, 10, 11, 12]
55-
]
56-
with self.assertRaises(Exception):
57-
split_matrix(matrix)
58-
59-
60-
if __name__ == "__main__":
61-
unittest.main()
1+
import pytest
2+
from divide_and_conquer.strassen_matrix_multiplication import split_matrix
3+
4+
5+
def test_4x4_matrix():
6+
matrix = [
7+
[4, 3, 2, 4],
8+
[2, 3, 1, 1],
9+
[6, 5, 4, 3],
10+
[8, 4, 1, 6]
11+
]
12+
expected = (
13+
[[4, 3], [2, 3]],
14+
[[2, 4], [1, 1]],
15+
[[6, 5], [8, 4]],
16+
[[4, 3], [1, 6]]
17+
)
18+
assert split_matrix(matrix) == expected
19+
20+
21+
def test_8x8_matrix():
22+
matrix = [
23+
[4, 3, 2, 4, 4, 3, 2, 4],
24+
[2, 3, 1, 1, 2, 3, 1, 1],
25+
[6, 5, 4, 3, 6, 5, 4, 3],
26+
[8, 4, 1, 6, 8, 4, 1, 6],
27+
[4, 3, 2, 4, 4, 3, 2, 4],
28+
[2, 3, 1, 1, 2, 3, 1, 1],
29+
[6, 5, 4, 3, 6, 5, 4, 3],
30+
[8, 4, 1, 6, 8, 4, 1, 6]
31+
]
32+
expected = (
33+
[[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]],
34+
[[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]],
35+
[[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]],
36+
[[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]]
37+
)
38+
assert split_matrix(matrix) == expected
39+
40+
41+
def test_invalid_odd_matrix():
42+
matrix = [
43+
[1, 2, 3],
44+
[4, 5, 6],
45+
[7, 8, 9]
46+
]
47+
with pytest.raises(Exception, match="Odd matrices are not supported!"):
48+
split_matrix(matrix)
49+
50+
51+
def test_invalid_non_square_matrix():
52+
matrix = [
53+
[1, 2, 3, 4],
54+
[5, 6, 7, 8],
55+
[9, 10, 11, 12],
56+
[13, 14, 15, 16],
57+
[17, 18, 19, 20]
58+
]
59+
with pytest.raises(Exception, match="Odd matrices are not supported!"):
60+
split_matrix(matrix)

0 commit comments

Comments
 (0)