Skip to content

Commit a02de96

Browse files
Reduce the complexity of graphs/minimum_spanning_tree_prims.py (#7952)
* Lower the --max-complexity threshold in the file .flake8 * Add test * Reduce the complexity of graphs/minimum_spanning_tree_prims.py * Remove backslashes * Remove # noqa: E741 * Fix the flake8 E741 issues * Refactor * Fix
1 parent db5215f commit a02de96

File tree

2 files changed

+76
-53
lines changed

2 files changed

+76
-53
lines changed

Diff for: .flake8

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[flake8]
22
max-line-length = 88
33
# max-complexity should be 10
4-
max-complexity = 21
4+
max-complexity = 20
55
extend-ignore =
66
# Formatting style for `black`
77
E203 # Whitespace before ':'

Diff for: graphs/minimum_spanning_tree_prims.py

+75-52
Original file line numberDiff line numberDiff line change
@@ -2,40 +2,45 @@
22
from collections import defaultdict
33

44

5-
def prisms_algorithm(l): # noqa: E741
5+
class Heap:
6+
def __init__(self):
7+
self.node_position = []
68

7-
node_position = []
9+
def get_position(self, vertex):
10+
return self.node_position[vertex]
811

9-
def get_position(vertex):
10-
return node_position[vertex]
12+
def set_position(self, vertex, pos):
13+
self.node_position[vertex] = pos
1114

12-
def set_position(vertex, pos):
13-
node_position[vertex] = pos
14-
15-
def top_to_bottom(heap, start, size, positions):
15+
def top_to_bottom(self, heap, start, size, positions):
1616
if start > size // 2 - 1:
1717
return
1818
else:
1919
if 2 * start + 2 >= size:
20-
m = 2 * start + 1
20+
smallest_child = 2 * start + 1
2121
else:
2222
if heap[2 * start + 1] < heap[2 * start + 2]:
23-
m = 2 * start + 1
23+
smallest_child = 2 * start + 1
2424
else:
25-
m = 2 * start + 2
26-
if heap[m] < heap[start]:
27-
temp, temp1 = heap[m], positions[m]
28-
heap[m], positions[m] = heap[start], positions[start]
25+
smallest_child = 2 * start + 2
26+
if heap[smallest_child] < heap[start]:
27+
temp, temp1 = heap[smallest_child], positions[smallest_child]
28+
heap[smallest_child], positions[smallest_child] = (
29+
heap[start],
30+
positions[start],
31+
)
2932
heap[start], positions[start] = temp, temp1
3033

31-
temp = get_position(positions[m])
32-
set_position(positions[m], get_position(positions[start]))
33-
set_position(positions[start], temp)
34+
temp = self.get_position(positions[smallest_child])
35+
self.set_position(
36+
positions[smallest_child], self.get_position(positions[start])
37+
)
38+
self.set_position(positions[start], temp)
3439

35-
top_to_bottom(heap, m, size, positions)
40+
self.top_to_bottom(heap, smallest_child, size, positions)
3641

3742
# Update function if value of any node in min-heap decreases
38-
def bottom_to_top(val, index, heap, position):
43+
def bottom_to_top(self, val, index, heap, position):
3944
temp = position[index]
4045

4146
while index != 0:
@@ -47,70 +52,88 @@ def bottom_to_top(val, index, heap, position):
4752
if val < heap[parent]:
4853
heap[index] = heap[parent]
4954
position[index] = position[parent]
50-
set_position(position[parent], index)
55+
self.set_position(position[parent], index)
5156
else:
5257
heap[index] = val
5358
position[index] = temp
54-
set_position(temp, index)
59+
self.set_position(temp, index)
5560
break
5661
index = parent
5762
else:
5863
heap[0] = val
5964
position[0] = temp
60-
set_position(temp, 0)
65+
self.set_position(temp, 0)
6166

62-
def heapify(heap, positions):
67+
def heapify(self, heap, positions):
6368
start = len(heap) // 2 - 1
6469
for i in range(start, -1, -1):
65-
top_to_bottom(heap, i, len(heap), positions)
70+
self.top_to_bottom(heap, i, len(heap), positions)
6671

67-
def delete_minimum(heap, positions):
72+
def delete_minimum(self, heap, positions):
6873
temp = positions[0]
6974
heap[0] = sys.maxsize
70-
top_to_bottom(heap, 0, len(heap), positions)
75+
self.top_to_bottom(heap, 0, len(heap), positions)
7176
return temp
7277

73-
visited = [0 for i in range(len(l))]
74-
nbr_tv = [-1 for i in range(len(l))] # Neighboring Tree Vertex of selected vertex
78+
79+
def prisms_algorithm(adjacency_list):
80+
"""
81+
>>> adjacency_list = {0: [[1, 1], [3, 3]],
82+
... 1: [[0, 1], [2, 6], [3, 5], [4, 1]],
83+
... 2: [[1, 6], [4, 5], [5, 2]],
84+
... 3: [[0, 3], [1, 5], [4, 1]],
85+
... 4: [[1, 1], [2, 5], [3, 1], [5, 4]],
86+
... 5: [[2, 2], [4, 4]]}
87+
>>> prisms_algorithm(adjacency_list)
88+
[(0, 1), (1, 4), (4, 3), (4, 5), (5, 2)]
89+
"""
90+
91+
heap = Heap()
92+
93+
visited = [0] * len(adjacency_list)
94+
nbr_tv = [-1] * len(adjacency_list) # Neighboring Tree Vertex of selected vertex
7595
# Minimum Distance of explored vertex with neighboring vertex of partial tree
7696
# formed in graph
7797
distance_tv = [] # Heap of Distance of vertices from their neighboring vertex
7898
positions = []
7999

80-
for x in range(len(l)):
81-
p = sys.maxsize
82-
distance_tv.append(p)
83-
positions.append(x)
84-
node_position.append(x)
100+
for vertex in range(len(adjacency_list)):
101+
distance_tv.append(sys.maxsize)
102+
positions.append(vertex)
103+
heap.node_position.append(vertex)
85104

86105
tree_edges = []
87106
visited[0] = 1
88107
distance_tv[0] = sys.maxsize
89-
for x in l[0]:
90-
nbr_tv[x[0]] = 0
91-
distance_tv[x[0]] = x[1]
92-
heapify(distance_tv, positions)
108+
for neighbor, distance in adjacency_list[0]:
109+
nbr_tv[neighbor] = 0
110+
distance_tv[neighbor] = distance
111+
heap.heapify(distance_tv, positions)
93112

94-
for _ in range(1, len(l)):
95-
vertex = delete_minimum(distance_tv, positions)
113+
for _ in range(1, len(adjacency_list)):
114+
vertex = heap.delete_minimum(distance_tv, positions)
96115
if visited[vertex] == 0:
97116
tree_edges.append((nbr_tv[vertex], vertex))
98117
visited[vertex] = 1
99-
for v in l[vertex]:
100-
if visited[v[0]] == 0 and v[1] < distance_tv[get_position(v[0])]:
101-
distance_tv[get_position(v[0])] = v[1]
102-
bottom_to_top(v[1], get_position(v[0]), distance_tv, positions)
103-
nbr_tv[v[0]] = vertex
118+
for neighbor, distance in adjacency_list[vertex]:
119+
if (
120+
visited[neighbor] == 0
121+
and distance < distance_tv[heap.get_position(neighbor)]
122+
):
123+
distance_tv[heap.get_position(neighbor)] = distance
124+
heap.bottom_to_top(
125+
distance, heap.get_position(neighbor), distance_tv, positions
126+
)
127+
nbr_tv[neighbor] = vertex
104128
return tree_edges
105129

106130

107131
if __name__ == "__main__": # pragma: no cover
108132
# < --------- Prims Algorithm --------- >
109-
n = int(input("Enter number of vertices: ").strip())
110-
e = int(input("Enter number of edges: ").strip())
111-
adjlist = defaultdict(list)
112-
for x in range(e):
113-
l = [int(x) for x in input().strip().split()] # noqa: E741
114-
adjlist[l[0]].append([l[1], l[2]])
115-
adjlist[l[1]].append([l[0], l[2]])
116-
print(prisms_algorithm(adjlist))
133+
edges_number = int(input("Enter number of edges: ").strip())
134+
adjacency_list = defaultdict(list)
135+
for _ in range(edges_number):
136+
edge = [int(x) for x in input().strip().split()]
137+
adjacency_list[edge[0]].append([edge[1], edge[2]])
138+
adjacency_list[edge[1]].append([edge[0], edge[2]])
139+
print(prisms_algorithm(adjacency_list))

0 commit comments

Comments
 (0)