3
3
"""
4
4
5
5
import time
6
+ from math import sqrt
6
7
from typing import List , Tuple
7
8
9
+ # 1 for manhattan, 0 for euclidean
10
+ HEURISTIC = 0
11
+
8
12
grid = [
9
13
[0 , 0 , 0 , 0 , 0 , 0 , 0 ],
10
14
[0 , 1 , 0 , 0 , 0 , 0 , 0 ], # 0 are free path whereas 1's are obstacles
20
24
21
25
class Node :
22
26
"""
23
- >>> k = Node(0, 0, 4, 5 , 0, None)
27
+ >>> k = Node(0, 0, 4, 3 , 0, None)
24
28
>>> k.calculate_heuristic()
25
- 9
29
+ 5.0
26
30
>>> n = Node(1, 4, 3, 4, 2, None)
27
31
>>> n.calculate_heuristic()
28
- 2
32
+ 2.0
29
33
>>> l = [k, n]
30
34
>>> n == l[0]
31
35
False
@@ -47,18 +51,35 @@ def __init__(self, pos_x, pos_y, goal_x, goal_y, g_cost, parent):
47
51
48
52
def calculate_heuristic (self ) -> float :
49
53
"""
50
- The heuristic here is the Manhattan Distance
51
- Could elaborate to offer more than one choice
54
+ Heuristic for the A*
52
55
"""
53
- dy = abs (self .pos_x - self .goal_x )
54
- dx = abs (self .pos_y - self .goal_y )
55
- return dx + dy
56
-
57
- def __lt__ (self , other ):
56
+ dy = self .pos_x - self .goal_x
57
+ dx = self .pos_y - self .goal_y
58
+ if HEURISTIC == 1 :
59
+ return abs (dx ) + abs (dy )
60
+ else :
61
+ return sqrt (dy ** 2 + dx ** 2 )
62
+
63
+ def __lt__ (self , other ) -> bool :
58
64
return self .f_cost < other .f_cost
59
65
60
66
61
67
class AStar :
68
+ """
69
+ >>> astar = AStar((0, 0), (len(grid) - 1, len(grid[0]) - 1))
70
+ >>> (astar.start.pos_y + delta[3][0], astar.start.pos_x + delta[3][1])
71
+ (0, 1)
72
+ >>> [x.pos for x in astar.get_successors(astar.start)]
73
+ [(1, 0), (0, 1)]
74
+ >>> (astar.start.pos_y + delta[2][0], astar.start.pos_x + delta[2][1])
75
+ (1, 0)
76
+ >>> astar.retrace_path(astar.start)
77
+ [(0, 0)]
78
+ >>> astar.search() # doctest: +NORMALIZE_WHITESPACE
79
+ [(0, 0), (1, 0), (2, 0), (2, 1), (2, 2), (2, 3), (3, 3),
80
+ (4, 3), (4, 4), (5, 4), (5, 5), (6, 5), (6, 6)]
81
+ """
82
+
62
83
def __init__ (self , start , goal ):
63
84
self .start = Node (start [1 ], start [0 ], goal [1 ], goal [0 ], 0 , None )
64
85
self .target = Node (goal [1 ], goal [0 ], goal [1 ], goal [0 ], 99999 , None )
@@ -68,19 +89,15 @@ def __init__(self, start, goal):
68
89
69
90
self .reached = False
70
91
71
- self .path = [(self .start .pos_y , self .start .pos_x )]
72
- self .costs = [0 ]
73
-
74
- def search (self ):
92
+ def search (self ) -> List [Tuple [int ]]:
75
93
while self .open_nodes :
76
94
# Open Nodes are sorted using __lt__
77
95
self .open_nodes .sort ()
78
96
current_node = self .open_nodes .pop (0 )
79
97
80
98
if current_node .pos == self .target .pos :
81
99
self .reached = True
82
- self .path = self .retrace_path (current_node )
83
- break
100
+ return self .retrace_path (current_node )
84
101
85
102
self .closed_nodes .append (current_node )
86
103
successors = self .get_successors (current_node )
@@ -101,7 +118,7 @@ def search(self):
101
118
self .open_nodes .append (better_node )
102
119
103
120
if not (self .reached ):
104
- print ( "No path found" )
121
+ return [( self . start . pos )]
105
122
106
123
def get_successors (self , parent : Node ) -> List [Node ]:
107
124
"""
@@ -111,21 +128,22 @@ def get_successors(self, parent: Node) -> List[Node]:
111
128
for action in delta :
112
129
pos_x = parent .pos_x + action [1 ]
113
130
pos_y = parent .pos_y + action [0 ]
114
- if not (0 < pos_x < len (grid [0 ]) - 1 and 0 < pos_y < len (grid ) - 1 ):
131
+ if not (0 <= pos_x <= len (grid [0 ]) - 1 and 0 <= pos_y <= len (grid ) - 1 ):
115
132
continue
116
133
117
134
if grid [pos_y ][pos_x ] != 0 :
118
135
continue
119
136
120
- node_ = Node (
121
- pos_x ,
122
- pos_y ,
123
- self .target .pos_y ,
124
- self .target .pos_x ,
125
- parent .g_cost + 1 ,
126
- parent ,
137
+ successors .append (
138
+ Node (
139
+ pos_x ,
140
+ pos_y ,
141
+ self .target .pos_y ,
142
+ self .target .pos_x ,
143
+ parent .g_cost + 1 ,
144
+ parent ,
145
+ )
127
146
)
128
- successors .append (node_ )
129
147
return successors
130
148
131
149
def retrace_path (self , node : Node ) -> List [Tuple [int ]]:
@@ -142,13 +160,24 @@ def retrace_path(self, node: Node) -> List[Tuple[int]]:
142
160
143
161
144
162
class BidirectionalAStar :
163
+ """
164
+ >>> bd_astar = BidirectionalAStar((0, 0), (len(grid) - 1, len(grid[0]) - 1))
165
+ >>> bd_astar.fwd_astar.start.pos == bd_astar.bwd_astar.target.pos
166
+ True
167
+ >>> bd_astar.retrace_bidirectional_path(bd_astar.fwd_astar.start,
168
+ ... bd_astar.bwd_astar.start)
169
+ [(0, 0)]
170
+ >>> bd_astar.search() # doctest: +NORMALIZE_WHITESPACE
171
+ [(0, 0), (0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (2, 4),
172
+ (2, 5), (3, 5), (4, 5), (5, 5), (5, 6), (6, 6)]
173
+ """
174
+
145
175
def __init__ (self , start , goal ):
146
176
self .fwd_astar = AStar (start , goal )
147
177
self .bwd_astar = AStar (goal , start )
148
178
self .reached = False
149
- self .path = self .fwd_astar .path
150
179
151
- def search (self ):
180
+ def search (self ) -> List [ Tuple [ int ]] :
152
181
while self .fwd_astar .open_nodes or self .bwd_astar .open_nodes :
153
182
self .fwd_astar .open_nodes .sort ()
154
183
self .bwd_astar .open_nodes .sort ()
@@ -157,8 +186,9 @@ def search(self):
157
186
158
187
if current_bwd_node .pos == current_fwd_node .pos :
159
188
self .reached = True
160
- self .retrace_bidirectional_path (current_fwd_node , current_bwd_node )
161
- break
189
+ return self .retrace_bidirectional_path (
190
+ current_fwd_node , current_bwd_node
191
+ )
162
192
163
193
self .fwd_astar .closed_nodes .append (current_fwd_node )
164
194
self .bwd_astar .closed_nodes .append (current_bwd_node )
@@ -189,30 +219,38 @@ def search(self):
189
219
else :
190
220
astar .open_nodes .append (better_node )
191
221
222
+ if not self .reached :
223
+ return [self .fwd_astar .start .pos ]
224
+
192
225
def retrace_bidirectional_path (
193
226
self , fwd_node : Node , bwd_node : Node
194
227
) -> List [Tuple [int ]]:
195
228
fwd_path = self .fwd_astar .retrace_path (fwd_node )
196
229
bwd_path = self .bwd_astar .retrace_path (bwd_node )
197
- fwd_path .reverse ()
230
+ bwd_path .pop ()
231
+ bwd_path .reverse ()
198
232
path = fwd_path + bwd_path
199
233
return path
200
234
201
235
202
- # all coordinates are given in format [y,x]
203
- init = (0 , 0 )
204
- goal = (len (grid ) - 1 , len (grid [0 ]) - 1 )
205
- for elem in grid :
206
- print (elem )
236
+ if __name__ == "__main__" :
237
+ # all coordinates are given in format [y,x]
238
+ import doctest
239
+
240
+ doctest .testmod ()
241
+ init = (0 , 0 )
242
+ goal = (len (grid ) - 1 , len (grid [0 ]) - 1 )
243
+ for elem in grid :
244
+ print (elem )
207
245
208
- start_time = time .time ()
209
- a_star = AStar (init , goal )
210
- a_star .search ()
211
- end_time = time .time () - start_time
212
- print (f"AStar execution time = { end_time :f} seconds" )
246
+ start_time = time .time ()
247
+ a_star = AStar (init , goal )
248
+ path = a_star .search ()
249
+ end_time = time .time () - start_time
250
+ print (f"AStar execution time = { end_time :f} seconds" )
213
251
214
- bd_start_time = time .time ()
215
- bidir_astar = BidirectionalAStar (init , goal )
216
- bidir_astar .search ()
217
- bd_end_time = time .time () - bd_start_time
218
- print (f"BidirectionalAStar execution time = { bd_end_time :f} seconds" )
252
+ bd_start_time = time .time ()
253
+ bidir_astar = BidirectionalAStar (init , goal )
254
+ path = bidir_astar .search ()
255
+ bd_end_time = time .time () - bd_start_time
256
+ print (f"BidirectionalAStar execution time = { bd_end_time :f} seconds" )
0 commit comments