Skip to content

Commit 8998001

Browse files
ruppysuppygithub-actions
and
github-actions
authored
feat: added prim's algorithm v2 (TheAlgorithms#2742)
* feat: added prim's algorithm v2 * updating DIRECTORY.md * chore: small tweaks * fixup! Format Python code with psf/black push * chore: added algorithm descriptor Co-authored-by: github-actions <${GITHUB_ACTOR}@users.noreply.github.com>
1 parent ec00ebf commit 8998001

File tree

2 files changed

+272
-0
lines changed

2 files changed

+272
-0
lines changed

Diff for: DIRECTORY.md

+1
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@
287287
* [Minimum Spanning Tree Kruskal](https://github.com/TheAlgorithms/Python/blob/master/graphs/minimum_spanning_tree_kruskal.py)
288288
* [Minimum Spanning Tree Kruskal2](https://github.com/TheAlgorithms/Python/blob/master/graphs/minimum_spanning_tree_kruskal2.py)
289289
* [Minimum Spanning Tree Prims](https://github.com/TheAlgorithms/Python/blob/master/graphs/minimum_spanning_tree_prims.py)
290+
* [Minimum Spanning Tree Prims2](https://github.com/TheAlgorithms/Python/blob/master/graphs/minimum_spanning_tree_prims2.py)
290291
* [Multi Heuristic Astar](https://github.com/TheAlgorithms/Python/blob/master/graphs/multi_heuristic_astar.py)
291292
* [Page Rank](https://github.com/TheAlgorithms/Python/blob/master/graphs/page_rank.py)
292293
* [Prim](https://github.com/TheAlgorithms/Python/blob/master/graphs/prim.py)

Diff for: graphs/minimum_spanning_tree_prims2.py

+271
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
"""
2+
Prim's (also known as Jarník's) algorithm is a greedy algorithm that finds a minimum
3+
spanning tree for a weighted undirected graph. This means it finds a subset of the
4+
edges that forms a tree that includes every vertex, where the total weight of all the
5+
edges in the tree is minimized. The algorithm operates by building this tree one vertex
6+
at a time, from an arbitrary starting vertex, at each step adding the cheapest possible
7+
connection from the tree to another vertex.
8+
"""
9+
10+
from sys import maxsize
11+
from typing import Dict, Optional, Tuple, Union
12+
13+
14+
def get_parent_position(position: int) -> int:
15+
"""
16+
heap helper function get the position of the parent of the current node
17+
18+
>>> get_parent_position(1)
19+
0
20+
>>> get_parent_position(2)
21+
0
22+
"""
23+
return (position - 1) // 2
24+
25+
26+
def get_child_left_position(position: int) -> int:
27+
"""
28+
heap helper function get the position of the left child of the current node
29+
30+
>>> get_child_left_position(0)
31+
1
32+
"""
33+
return (2 * position) + 1
34+
35+
36+
def get_child_right_position(position: int) -> int:
37+
"""
38+
heap helper function get the position of the right child of the current node
39+
40+
>>> get_child_right_position(0)
41+
2
42+
"""
43+
return (2 * position) + 2
44+
45+
46+
class MinPriorityQueue:
47+
"""
48+
Minimum Priority Queue Class
49+
50+
Functions:
51+
is_empty: function to check if the priority queue is empty
52+
push: function to add an element with given priority to the queue
53+
extract_min: function to remove and return the element with lowest weight (highest
54+
priority)
55+
update_key: function to update the weight of the given key
56+
_bubble_up: helper function to place a node at the proper position (upward
57+
movement)
58+
_bubble_down: helper function to place a node at the proper position (downward
59+
movement)
60+
_swap_nodes: helper function to swap the nodes at the given positions
61+
62+
>>> queue = MinPriorityQueue()
63+
64+
>>> queue.push(1, 1000)
65+
>>> queue.push(2, 100)
66+
>>> queue.push(3, 4000)
67+
>>> queue.push(4, 3000)
68+
69+
>>> print(queue.extract_min())
70+
2
71+
72+
>>> queue.update_key(4, 50)
73+
74+
>>> print(queue.extract_min())
75+
4
76+
>>> print(queue.extract_min())
77+
1
78+
>>> print(queue.extract_min())
79+
3
80+
"""
81+
82+
def __init__(self) -> None:
83+
self.heap = []
84+
self.position_map = {}
85+
self.elements = 0
86+
87+
def __len__(self) -> int:
88+
return self.elements
89+
90+
def __repr__(self) -> str:
91+
return str(self.heap)
92+
93+
def is_empty(self) -> bool:
94+
# Check if the priority queue is empty
95+
return self.elements == 0
96+
97+
def push(self, elem: Union[int, str], weight: int) -> None:
98+
# Add an element with given priority to the queue
99+
self.heap.append((elem, weight))
100+
self.position_map[elem] = self.elements
101+
self.elements += 1
102+
self._bubble_up(elem)
103+
104+
def extract_min(self) -> Union[int, str]:
105+
# Remove and return the element with lowest weight (highest priority)
106+
if self.elements > 1:
107+
self._swap_nodes(0, self.elements - 1)
108+
elem, _ = self.heap.pop()
109+
del self.position_map[elem]
110+
self.elements -= 1
111+
if self.elements > 0:
112+
bubble_down_elem, _ = self.heap[0]
113+
self._bubble_down(bubble_down_elem)
114+
return elem
115+
116+
def update_key(self, elem: Union[int, str], weight: int) -> None:
117+
# Update the weight of the given key
118+
position = self.position_map[elem]
119+
self.heap[position] = (elem, weight)
120+
if position > 0:
121+
parent_position = get_parent_position(position)
122+
_, parent_weight = self.heap[parent_position]
123+
if parent_weight > weight:
124+
self._bubble_up(elem)
125+
else:
126+
self._bubble_down(elem)
127+
else:
128+
self._bubble_down(elem)
129+
130+
def _bubble_up(self, elem: Union[int, str]) -> None:
131+
# Place a node at the proper position (upward movement) [to be used internally
132+
# only]
133+
curr_pos = self.position_map[elem]
134+
if curr_pos == 0:
135+
return
136+
parent_position = get_parent_position(curr_pos)
137+
_, weight = self.heap[curr_pos]
138+
_, parent_weight = self.heap[parent_position]
139+
if parent_weight > weight:
140+
self._swap_nodes(parent_position, curr_pos)
141+
return self._bubble_up(elem)
142+
return
143+
144+
def _bubble_down(self, elem: Union[int, str]) -> None:
145+
# Place a node at the proper position (downward movement) [to be used
146+
# internally only]
147+
curr_pos = self.position_map[elem]
148+
_, weight = self.heap[curr_pos]
149+
child_left_position = get_child_left_position(curr_pos)
150+
child_right_position = get_child_right_position(curr_pos)
151+
if child_left_position < self.elements and child_right_position < self.elements:
152+
_, child_left_weight = self.heap[child_left_position]
153+
_, child_right_weight = self.heap[child_right_position]
154+
if child_right_weight < child_left_weight:
155+
if child_right_weight < weight:
156+
self._swap_nodes(child_right_position, curr_pos)
157+
return self._bubble_down(elem)
158+
if child_left_position < self.elements:
159+
_, child_left_weight = self.heap[child_left_position]
160+
if child_left_weight < weight:
161+
self._swap_nodes(child_left_position, curr_pos)
162+
return self._bubble_down(elem)
163+
else:
164+
return
165+
if child_right_position < self.elements:
166+
_, child_right_weight = self.heap[child_right_position]
167+
if child_right_weight < weight:
168+
self._swap_nodes(child_right_position, curr_pos)
169+
return self._bubble_down(elem)
170+
else:
171+
return
172+
173+
def _swap_nodes(self, node1_pos: int, node2_pos: int) -> None:
174+
# Swap the nodes at the given positions
175+
node1_elem = self.heap[node1_pos][0]
176+
node2_elem = self.heap[node2_pos][0]
177+
self.heap[node1_pos], self.heap[node2_pos] = (
178+
self.heap[node2_pos],
179+
self.heap[node1_pos],
180+
)
181+
self.position_map[node1_elem] = node2_pos
182+
self.position_map[node2_elem] = node1_pos
183+
184+
185+
class GraphUndirectedWeighted:
186+
"""
187+
Graph Undirected Weighted Class
188+
189+
Functions:
190+
add_node: function to add a node in the graph
191+
add_edge: function to add an edge between 2 nodes in the graph
192+
"""
193+
194+
def __init__(self) -> None:
195+
self.connections = {}
196+
self.nodes = 0
197+
198+
def __repr__(self) -> str:
199+
return str(self.connections)
200+
201+
def __len__(self) -> int:
202+
return self.nodes
203+
204+
def add_node(self, node: Union[int, str]) -> None:
205+
# Add a node in the graph if it is not in the graph
206+
if node not in self.connections:
207+
self.connections[node] = {}
208+
self.nodes += 1
209+
210+
def add_edge(
211+
self, node1: Union[int, str], node2: Union[int, str], weight: int
212+
) -> None:
213+
# Add an edge between 2 nodes in the graph
214+
self.add_node(node1)
215+
self.add_node(node2)
216+
self.connections[node1][node2] = weight
217+
self.connections[node2][node1] = weight
218+
219+
220+
def prims_algo(
221+
graph: GraphUndirectedWeighted,
222+
) -> Tuple[Dict[str, int], Dict[str, Optional[str]]]:
223+
"""
224+
>>> graph = GraphUndirectedWeighted()
225+
226+
>>> graph.add_edge("a", "b", 3)
227+
>>> graph.add_edge("b", "c", 10)
228+
>>> graph.add_edge("c", "d", 5)
229+
>>> graph.add_edge("a", "c", 15)
230+
>>> graph.add_edge("b", "d", 100)
231+
232+
>>> dist, parent = prims_algo(graph)
233+
234+
>>> abs(dist["a"] - dist["b"])
235+
3
236+
>>> abs(dist["d"] - dist["b"])
237+
15
238+
>>> abs(dist["a"] - dist["c"])
239+
13
240+
"""
241+
# prim's algorithm for minimum spanning tree
242+
dist = {node: maxsize for node in graph.connections}
243+
parent = {node: None for node in graph.connections}
244+
priority_queue = MinPriorityQueue()
245+
[priority_queue.push(node, weight) for node, weight in dist.items()]
246+
if priority_queue.is_empty():
247+
return dist, parent
248+
249+
# initialization
250+
node = priority_queue.extract_min()
251+
dist[node] = 0
252+
for neighbour in graph.connections[node]:
253+
if dist[neighbour] > dist[node] + graph.connections[node][neighbour]:
254+
dist[neighbour] = dist[node] + graph.connections[node][neighbour]
255+
priority_queue.update_key(neighbour, dist[neighbour])
256+
parent[neighbour] = node
257+
# running prim's algorithm
258+
while not priority_queue.is_empty():
259+
node = priority_queue.extract_min()
260+
for neighbour in graph.connections[node]:
261+
if dist[neighbour] > dist[node] + graph.connections[node][neighbour]:
262+
dist[neighbour] = dist[node] + graph.connections[node][neighbour]
263+
priority_queue.update_key(neighbour, dist[neighbour])
264+
parent[neighbour] = node
265+
return dist, parent
266+
267+
268+
if __name__ == "__main__":
269+
from doctest import testmod
270+
271+
testmod()

0 commit comments

Comments
 (0)