Skip to content

Commit 1f2d607

Browse files
authored
Graphs : Bidirectional A* (TheAlgorithms#2015)
* implement bidirectional astar * add type hints * add wikipedia url * format with black * changes from review
1 parent 965d02a commit 1f2d607

File tree

1 file changed

+218
-0
lines changed

1 file changed

+218
-0
lines changed

graphs/bidirectional_a_star.py

+218
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
"""
2+
https://en.wikipedia.org/wiki/Bidirectional_search
3+
"""
4+
5+
import time
6+
from typing import List, Tuple
7+
8+
grid = [
9+
[0, 0, 0, 0, 0, 0, 0],
10+
[0, 1, 0, 0, 0, 0, 0], # 0 are free path whereas 1's are obstacles
11+
[0, 0, 0, 0, 0, 0, 0],
12+
[0, 0, 1, 0, 0, 0, 0],
13+
[1, 0, 1, 0, 0, 0, 0],
14+
[0, 0, 0, 0, 0, 0, 0],
15+
[0, 0, 0, 0, 1, 0, 0],
16+
]
17+
18+
delta = [[-1, 0], [0, -1], [1, 0], [0, 1]] # up, left, down, right
19+
20+
21+
class Node:
22+
"""
23+
>>> k = Node(0, 0, 4, 5, 0, None)
24+
>>> k.calculate_heuristic()
25+
9
26+
>>> n = Node(1, 4, 3, 4, 2, None)
27+
>>> n.calculate_heuristic()
28+
2
29+
>>> l = [k, n]
30+
>>> n == l[0]
31+
False
32+
>>> l.sort()
33+
>>> n == l[0]
34+
True
35+
"""
36+
37+
def __init__(self, pos_x, pos_y, goal_x, goal_y, g_cost, parent):
38+
self.pos_x = pos_x
39+
self.pos_y = pos_y
40+
self.pos = (pos_y, pos_x)
41+
self.goal_x = goal_x
42+
self.goal_y = goal_y
43+
self.g_cost = g_cost
44+
self.parent = parent
45+
self.h_cost = self.calculate_heuristic()
46+
self.f_cost = self.g_cost + self.h_cost
47+
48+
def calculate_heuristic(self) -> float:
49+
"""
50+
The heuristic here is the Manhattan Distance
51+
Could elaborate to offer more than one choice
52+
"""
53+
dy = abs(self.pos_x - self.goal_x)
54+
dx = abs(self.pos_y - self.goal_y)
55+
return dx + dy
56+
57+
def __lt__(self, other):
58+
return self.f_cost < other.f_cost
59+
60+
61+
class AStar:
62+
def __init__(self, start, goal):
63+
self.start = Node(start[1], start[0], goal[1], goal[0], 0, None)
64+
self.target = Node(goal[1], goal[0], goal[1], goal[0], 99999, None)
65+
66+
self.open_nodes = [self.start]
67+
self.closed_nodes = []
68+
69+
self.reached = False
70+
71+
self.path = [(self.start.pos_y, self.start.pos_x)]
72+
self.costs = [0]
73+
74+
def search(self):
75+
while self.open_nodes:
76+
# Open Nodes are sorted using __lt__
77+
self.open_nodes.sort()
78+
current_node = self.open_nodes.pop(0)
79+
80+
if current_node.pos == self.target.pos:
81+
self.reached = True
82+
self.path = self.retrace_path(current_node)
83+
break
84+
85+
self.closed_nodes.append(current_node)
86+
successors = self.get_successors(current_node)
87+
88+
for child_node in successors:
89+
if child_node in self.closed_nodes:
90+
continue
91+
92+
if child_node not in self.open_nodes:
93+
self.open_nodes.append(child_node)
94+
else:
95+
# retrieve the best current path
96+
better_node = self.open_nodes.pop(self.open_nodes.index(child_node))
97+
98+
if child_node.g_cost < better_node.g_cost:
99+
self.open_nodes.append(child_node)
100+
else:
101+
self.open_nodes.append(better_node)
102+
103+
if not (self.reached):
104+
print("No path found")
105+
106+
def get_successors(self, parent: Node) -> List[Node]:
107+
"""
108+
Returns a list of successors (both in the grid and free spaces)
109+
"""
110+
successors = []
111+
for action in delta:
112+
pos_x = parent.pos_x + action[1]
113+
pos_y = parent.pos_y + action[0]
114+
if not (0 < pos_x < len(grid[0]) - 1 and 0 < pos_y < len(grid) - 1):
115+
continue
116+
117+
if grid[pos_y][pos_x] != 0:
118+
continue
119+
120+
node_ = Node(
121+
pos_x,
122+
pos_y,
123+
self.target.pos_y,
124+
self.target.pos_x,
125+
parent.g_cost + 1,
126+
parent,
127+
)
128+
successors.append(node_)
129+
return successors
130+
131+
def retrace_path(self, node: Node) -> List[Tuple[int]]:
132+
"""
133+
Retrace the path from parents to parents until start node
134+
"""
135+
current_node = node
136+
path = []
137+
while current_node is not None:
138+
path.append((current_node.pos_y, current_node.pos_x))
139+
current_node = current_node.parent
140+
path.reverse()
141+
return path
142+
143+
144+
class BidirectionalAStar:
145+
def __init__(self, start, goal):
146+
self.fwd_astar = AStar(start, goal)
147+
self.bwd_astar = AStar(goal, start)
148+
self.reached = False
149+
self.path = self.fwd_astar.path
150+
151+
def search(self):
152+
while self.fwd_astar.open_nodes or self.bwd_astar.open_nodes:
153+
self.fwd_astar.open_nodes.sort()
154+
self.bwd_astar.open_nodes.sort()
155+
current_fwd_node = self.fwd_astar.open_nodes.pop(0)
156+
current_bwd_node = self.bwd_astar.open_nodes.pop(0)
157+
158+
if current_bwd_node.pos == current_fwd_node.pos:
159+
self.reached = True
160+
self.retrace_bidirectional_path(current_fwd_node, current_bwd_node)
161+
break
162+
163+
self.fwd_astar.closed_nodes.append(current_fwd_node)
164+
self.bwd_astar.closed_nodes.append(current_bwd_node)
165+
166+
self.fwd_astar.target = current_bwd_node
167+
self.bwd_astar.target = current_fwd_node
168+
169+
successors = {
170+
self.fwd_astar: self.fwd_astar.get_successors(current_fwd_node),
171+
self.bwd_astar: self.bwd_astar.get_successors(current_bwd_node),
172+
}
173+
174+
for astar in [self.fwd_astar, self.bwd_astar]:
175+
for child_node in successors[astar]:
176+
if child_node in astar.closed_nodes:
177+
continue
178+
179+
if child_node not in astar.open_nodes:
180+
astar.open_nodes.append(child_node)
181+
else:
182+
# retrieve the best current path
183+
better_node = astar.open_nodes.pop(
184+
astar.open_nodes.index(child_node)
185+
)
186+
187+
if child_node.g_cost < better_node.g_cost:
188+
astar.open_nodes.append(child_node)
189+
else:
190+
astar.open_nodes.append(better_node)
191+
192+
def retrace_bidirectional_path(
193+
self, fwd_node: Node, bwd_node: Node
194+
) -> List[Tuple[int]]:
195+
fwd_path = self.fwd_astar.retrace_path(fwd_node)
196+
bwd_path = self.bwd_astar.retrace_path(bwd_node)
197+
fwd_path.reverse()
198+
path = fwd_path + bwd_path
199+
return path
200+
201+
202+
# all coordinates are given in format [y,x]
203+
init = (0, 0)
204+
goal = (len(grid) - 1, len(grid[0]) - 1)
205+
for elem in grid:
206+
print(elem)
207+
208+
start_time = time.time()
209+
a_star = AStar(init, goal)
210+
a_star.search()
211+
end_time = time.time() - start_time
212+
print(f"AStar execution time = {end_time:f} seconds")
213+
214+
bd_start_time = time.time()
215+
bidir_astar = BidirectionalAStar(init, goal)
216+
bidir_astar.search()
217+
bd_end_time = time.time() - bd_start_time
218+
print(f"BidirectionalAStar execution time = {bd_end_time:f} seconds")

0 commit comments

Comments
 (0)