Skip to content

Commit 792ee57

Browse files
committed
optimize split_matrix function by removing duplicate code to the extract_submatrix function, add tests
1 parent 6c92c5a commit 792ee57

File tree

3 files changed

+72
-9
lines changed

3 files changed

+72
-9
lines changed

divide_and_conquer/strassen_matrix_multiplication.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,20 @@ def split_matrix(a: list) -> tuple[list, list, list, list]:
4949
if len(a) % 2 != 0 or len(a[0]) % 2 != 0:
5050
raise Exception("Odd matrices are not supported!")
5151

52-
matrix_length = len(a)
53-
mid = matrix_length // 2
52+
def extract_submatrix(rows, cols):
53+
return [[a[i][j] for j in cols] for i in rows]
5454

55-
top_right = [[a[i][j] for j in range(mid, matrix_length)] for i in range(mid)]
56-
bot_right = [
57-
[a[i][j] for j in range(mid, matrix_length)] for i in range(mid, matrix_length)
58-
]
55+
mid = len(a) // 2
5956

60-
top_left = [[a[i][j] for j in range(mid)] for i in range(mid)]
61-
bot_left = [[a[i][j] for j in range(mid)] for i in range(mid, matrix_length)]
57+
rows_top, rows_bot = range(mid), range(mid, len(a))
58+
cols_left, cols_right = range(mid), range(mid, len(a))
6259

63-
return top_left, top_right, bot_left, bot_right
60+
return (
61+
extract_submatrix(rows_top, cols_left), # Top-left
62+
extract_submatrix(rows_top, cols_right), # Top-right
63+
extract_submatrix(rows_bot, cols_left), # Bottom-left
64+
extract_submatrix(rows_bot, cols_right), # Bottom-right
65+
)
6466

6567

6668
def matrix_dimensions(matrix: list) -> tuple[int, int]:

divide_and_conquer/tests/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import unittest
2+
from strassen_matrix_multiplication import split_matrix
3+
4+
5+
class TestSplitMatrix(unittest.TestCase):
6+
7+
def test_4x4_matrix(self):
8+
matrix = [
9+
[4, 3, 2, 4],
10+
[2, 3, 1, 1],
11+
[6, 5, 4, 3],
12+
[8, 4, 1, 6]
13+
]
14+
expected = (
15+
[[4, 3], [2, 3]],
16+
[[2, 4], [1, 1]],
17+
[[6, 5], [8, 4]],
18+
[[4, 3], [1, 6]]
19+
)
20+
self.assertEqual(split_matrix(matrix), expected)
21+
22+
def test_8x8_matrix(self):
23+
matrix = [
24+
[4, 3, 2, 4, 4, 3, 2, 4],
25+
[2, 3, 1, 1, 2, 3, 1, 1],
26+
[6, 5, 4, 3, 6, 5, 4, 3],
27+
[8, 4, 1, 6, 8, 4, 1, 6],
28+
[4, 3, 2, 4, 4, 3, 2, 4],
29+
[2, 3, 1, 1, 2, 3, 1, 1],
30+
[6, 5, 4, 3, 6, 5, 4, 3],
31+
[8, 4, 1, 6, 8, 4, 1, 6]
32+
]
33+
expected = (
34+
[[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]],
35+
[[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]],
36+
[[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]],
37+
[[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]]
38+
)
39+
self.assertEqual(split_matrix(matrix), expected)
40+
41+
def test_invalid_odd_matrix(self):
42+
matrix = [
43+
[1, 2, 3],
44+
[4, 5, 6],
45+
[7, 8, 9]
46+
]
47+
with self.assertRaises(Exception):
48+
split_matrix(matrix)
49+
50+
def test_invalid_non_square_matrix(self):
51+
matrix = [
52+
[1, 2, 3, 4],
53+
[5, 6, 7, 8],
54+
[9, 10, 11, 12]
55+
]
56+
with self.assertRaises(Exception):
57+
split_matrix(matrix)
58+
59+
60+
if __name__ == "__main__":
61+
unittest.main()

0 commit comments

Comments
 (0)