Skip to content

Reduce the complexity of graphs/minimum_spanning_tree_prims.py #7952

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[flake8]
max-line-length = 88
# max-complexity should be 10
max-complexity = 21
max-complexity = 20
extend-ignore =
# Formatting style for `black`
E203 # Whitespace before ':'
Expand Down
127 changes: 75 additions & 52 deletions graphs/minimum_spanning_tree_prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,45 @@
from collections import defaultdict


def prisms_algorithm(l): # noqa: E741
class Heap:
def __init__(self):
self.node_position = []

node_position = []
def get_position(self, vertex):
return self.node_position[vertex]

def get_position(vertex):
return node_position[vertex]
def set_position(self, vertex, pos):
self.node_position[vertex] = pos

def set_position(vertex, pos):
node_position[vertex] = pos

def top_to_bottom(heap, start, size, positions):
def top_to_bottom(self, heap, start, size, positions):
if start > size // 2 - 1:
return
else:
if 2 * start + 2 >= size:
m = 2 * start + 1
smallest_child = 2 * start + 1
else:
if heap[2 * start + 1] < heap[2 * start + 2]:
m = 2 * start + 1
smallest_child = 2 * start + 1
else:
m = 2 * start + 2
if heap[m] < heap[start]:
temp, temp1 = heap[m], positions[m]
heap[m], positions[m] = heap[start], positions[start]
smallest_child = 2 * start + 2
if heap[smallest_child] < heap[start]:
temp, temp1 = heap[smallest_child], positions[smallest_child]
heap[smallest_child], positions[smallest_child] = (
heap[start],
positions[start],
)
heap[start], positions[start] = temp, temp1

temp = get_position(positions[m])
set_position(positions[m], get_position(positions[start]))
set_position(positions[start], temp)
temp = self.get_position(positions[smallest_child])
self.set_position(
positions[smallest_child], self.get_position(positions[start])
)
self.set_position(positions[start], temp)

top_to_bottom(heap, m, size, positions)
self.top_to_bottom(heap, smallest_child, size, positions)

# Update function if value of any node in min-heap decreases
def bottom_to_top(val, index, heap, position):
def bottom_to_top(self, val, index, heap, position):
temp = position[index]

while index != 0:
Expand All @@ -47,70 +52,88 @@ def bottom_to_top(val, index, heap, position):
if val < heap[parent]:
heap[index] = heap[parent]
position[index] = position[parent]
set_position(position[parent], index)
self.set_position(position[parent], index)
else:
heap[index] = val
position[index] = temp
set_position(temp, index)
self.set_position(temp, index)
break
index = parent
else:
heap[0] = val
position[0] = temp
set_position(temp, 0)
self.set_position(temp, 0)

def heapify(heap, positions):
def heapify(self, heap, positions):
start = len(heap) // 2 - 1
for i in range(start, -1, -1):
top_to_bottom(heap, i, len(heap), positions)
self.top_to_bottom(heap, i, len(heap), positions)

def delete_minimum(heap, positions):
def delete_minimum(self, heap, positions):
temp = positions[0]
heap[0] = sys.maxsize
top_to_bottom(heap, 0, len(heap), positions)
self.top_to_bottom(heap, 0, len(heap), positions)
return temp

visited = [0 for i in range(len(l))]
nbr_tv = [-1 for i in range(len(l))] # Neighboring Tree Vertex of selected vertex

def prisms_algorithm(adjacency_list):
"""
>>> adjacency_list = {0: [[1, 1], [3, 3]],
... 1: [[0, 1], [2, 6], [3, 5], [4, 1]],
... 2: [[1, 6], [4, 5], [5, 2]],
... 3: [[0, 3], [1, 5], [4, 1]],
... 4: [[1, 1], [2, 5], [3, 1], [5, 4]],
... 5: [[2, 2], [4, 4]]}
>>> prisms_algorithm(adjacency_list)
[(0, 1), (1, 4), (4, 3), (4, 5), (5, 2)]
"""

heap = Heap()

visited = [0] * len(adjacency_list)
nbr_tv = [-1] * len(adjacency_list) # Neighboring Tree Vertex of selected vertex
# Minimum Distance of explored vertex with neighboring vertex of partial tree
# formed in graph
distance_tv = [] # Heap of Distance of vertices from their neighboring vertex
positions = []

for x in range(len(l)):
p = sys.maxsize
distance_tv.append(p)
positions.append(x)
node_position.append(x)
for vertex in range(len(adjacency_list)):
distance_tv.append(sys.maxsize)
positions.append(vertex)
heap.node_position.append(vertex)

tree_edges = []
visited[0] = 1
distance_tv[0] = sys.maxsize
for x in l[0]:
nbr_tv[x[0]] = 0
distance_tv[x[0]] = x[1]
heapify(distance_tv, positions)
for neighbor, distance in adjacency_list[0]:
nbr_tv[neighbor] = 0
distance_tv[neighbor] = distance
heap.heapify(distance_tv, positions)

for _ in range(1, len(l)):
vertex = delete_minimum(distance_tv, positions)
for _ in range(1, len(adjacency_list)):
vertex = heap.delete_minimum(distance_tv, positions)
if visited[vertex] == 0:
tree_edges.append((nbr_tv[vertex], vertex))
visited[vertex] = 1
for v in l[vertex]:
if visited[v[0]] == 0 and v[1] < distance_tv[get_position(v[0])]:
distance_tv[get_position(v[0])] = v[1]
bottom_to_top(v[1], get_position(v[0]), distance_tv, positions)
nbr_tv[v[0]] = vertex
for neighbor, distance in adjacency_list[vertex]:
if (
visited[neighbor] == 0
and distance < distance_tv[heap.get_position(neighbor)]
):
distance_tv[heap.get_position(neighbor)] = distance
heap.bottom_to_top(
distance, heap.get_position(neighbor), distance_tv, positions
)
nbr_tv[neighbor] = vertex
return tree_edges


if __name__ == "__main__": # pragma: no cover
# < --------- Prims Algorithm --------- >
n = int(input("Enter number of vertices: ").strip())
e = int(input("Enter number of edges: ").strip())
adjlist = defaultdict(list)
for x in range(e):
l = [int(x) for x in input().strip().split()] # noqa: E741
adjlist[l[0]].append([l[1], l[2]])
adjlist[l[1]].append([l[0], l[2]])
print(prisms_algorithm(adjlist))
edges_number = int(input("Enter number of edges: ").strip())
adjacency_list = defaultdict(list)
for _ in range(edges_number):
edge = [int(x) for x in input().strip().split()]
adjacency_list[edge[0]].append([edge[1], edge[2]])
adjacency_list[edge[1]].append([edge[0], edge[2]])
print(prisms_algorithm(adjacency_list))