Skip to content

Commit cecf1fd

Browse files
JadeKim042386pre-commit-ci[bot]tianyizheng02
authored
Fix greedy_best_first (#8775)
* fix: typo #8770 * refactor: delete unnecessary continue * add test grids * fix: add \_\_eq\_\_ in Node class #8770 * fix: delete unnecessary code - node in self.open_nodes is always better node #8770 * fix: docstring * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: docstring max length * refactor: get the successors using a list comprehension * Apply suggestions from code review --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tianyi Zheng <[email protected]>
1 parent 490e645 commit cecf1fd

File tree

1 file changed

+67
-53
lines changed

1 file changed

+67
-53
lines changed

Diff for: graphs/greedy_best_first.py

+67-53
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,32 @@
66

77
Path = list[tuple[int, int]]
88

9-
grid = [
10-
[0, 0, 0, 0, 0, 0, 0],
11-
[0, 1, 0, 0, 0, 0, 0], # 0 are free path whereas 1's are obstacles
12-
[0, 0, 0, 0, 0, 0, 0],
13-
[0, 0, 1, 0, 0, 0, 0],
14-
[1, 0, 1, 0, 0, 0, 0],
15-
[0, 0, 0, 0, 0, 0, 0],
16-
[0, 0, 0, 0, 1, 0, 0],
9+
# 0's are free path whereas 1's are obstacles
10+
TEST_GRIDS = [
11+
[
12+
[0, 0, 0, 0, 0, 0, 0],
13+
[0, 1, 0, 0, 0, 0, 0],
14+
[0, 0, 0, 0, 0, 0, 0],
15+
[0, 0, 1, 0, 0, 0, 0],
16+
[1, 0, 1, 0, 0, 0, 0],
17+
[0, 0, 0, 0, 0, 0, 0],
18+
[0, 0, 0, 0, 1, 0, 0],
19+
],
20+
[
21+
[0, 0, 0, 1, 1, 0, 0],
22+
[0, 0, 0, 0, 1, 0, 1],
23+
[0, 0, 0, 1, 1, 0, 0],
24+
[0, 1, 0, 0, 1, 0, 0],
25+
[1, 0, 0, 1, 1, 0, 1],
26+
[0, 0, 0, 0, 0, 0, 0],
27+
],
28+
[
29+
[0, 0, 1, 0, 0],
30+
[0, 1, 0, 0, 0],
31+
[0, 0, 1, 0, 1],
32+
[1, 0, 0, 1, 1],
33+
[0, 0, 0, 0, 0],
34+
],
1735
]
1836

1937
delta = ([-1, 0], [0, -1], [1, 0], [0, 1]) # up, left, down, right
@@ -65,10 +83,14 @@ def calculate_heuristic(self) -> float:
6583
def __lt__(self, other) -> bool:
6684
return self.f_cost < other.f_cost
6785

86+
def __eq__(self, other) -> bool:
87+
return self.pos == other.pos
88+
6889

6990
class GreedyBestFirst:
7091
"""
71-
>>> gbf = GreedyBestFirst((0, 0), (len(grid) - 1, len(grid[0]) - 1))
92+
>>> grid = TEST_GRIDS[2]
93+
>>> gbf = GreedyBestFirst(grid, (0, 0), (len(grid) - 1, len(grid[0]) - 1))
7294
>>> [x.pos for x in gbf.get_successors(gbf.start)]
7395
[(1, 0), (0, 1)]
7496
>>> (gbf.start.pos_y + delta[3][0], gbf.start.pos_x + delta[3][1])
@@ -78,11 +100,14 @@ class GreedyBestFirst:
78100
>>> gbf.retrace_path(gbf.start)
79101
[(0, 0)]
80102
>>> gbf.search() # doctest: +NORMALIZE_WHITESPACE
81-
[(0, 0), (1, 0), (2, 0), (3, 0), (3, 1), (4, 1), (5, 1), (6, 1),
82-
(6, 2), (6, 3), (5, 3), (5, 4), (5, 5), (6, 5), (6, 6)]
103+
[(0, 0), (1, 0), (2, 0), (2, 1), (3, 1), (4, 1), (4, 2), (4, 3),
104+
(4, 4)]
83105
"""
84106

85-
def __init__(self, start: tuple[int, int], goal: tuple[int, int]):
107+
def __init__(
108+
self, grid: list[list[int]], start: tuple[int, int], goal: tuple[int, int]
109+
):
110+
self.grid = grid
86111
self.start = Node(start[1], start[0], goal[1], goal[0], 0, None)
87112
self.target = Node(goal[1], goal[0], goal[1], goal[0], 99999, None)
88113

@@ -114,14 +139,6 @@ def search(self) -> Path | None:
114139

115140
if child_node not in self.open_nodes:
116141
self.open_nodes.append(child_node)
117-
else:
118-
# retrieve the best current path
119-
better_node = self.open_nodes.pop(self.open_nodes.index(child_node))
120-
121-
if child_node.g_cost < better_node.g_cost:
122-
self.open_nodes.append(child_node)
123-
else:
124-
self.open_nodes.append(better_node)
125142

126143
if not self.reached:
127144
return [self.start.pos]
@@ -131,28 +148,22 @@ def get_successors(self, parent: Node) -> list[Node]:
131148
"""
132149
Returns a list of successors (both in the grid and free spaces)
133150
"""
134-
successors = []
135-
for action in delta:
136-
pos_x = parent.pos_x + action[1]
137-
pos_y = parent.pos_y + action[0]
138-
139-
if not (0 <= pos_x <= len(grid[0]) - 1 and 0 <= pos_y <= len(grid) - 1):
140-
continue
141-
142-
if grid[pos_y][pos_x] != 0:
143-
continue
144-
145-
successors.append(
146-
Node(
147-
pos_x,
148-
pos_y,
149-
self.target.pos_y,
150-
self.target.pos_x,
151-
parent.g_cost + 1,
152-
parent,
153-
)
151+
return [
152+
Node(
153+
pos_x,
154+
pos_y,
155+
self.target.pos_x,
156+
self.target.pos_y,
157+
parent.g_cost + 1,
158+
parent,
159+
)
160+
for action in delta
161+
if (
162+
0 <= (pos_x := parent.pos_x + action[1]) < len(self.grid[0])
163+
and 0 <= (pos_y := parent.pos_y + action[0]) < len(self.grid)
164+
and self.grid[pos_y][pos_x] == 0
154165
)
155-
return successors
166+
]
156167

157168
def retrace_path(self, node: Node | None) -> Path:
158169
"""
@@ -168,18 +179,21 @@ def retrace_path(self, node: Node | None) -> Path:
168179

169180

170181
if __name__ == "__main__":
171-
init = (0, 0)
172-
goal = (len(grid) - 1, len(grid[0]) - 1)
173-
for elem in grid:
174-
print(elem)
175-
176-
print("------")
177-
178-
greedy_bf = GreedyBestFirst(init, goal)
179-
path = greedy_bf.search()
180-
if path:
181-
for pos_x, pos_y in path:
182-
grid[pos_x][pos_y] = 2
182+
for idx, grid in enumerate(TEST_GRIDS):
183+
print(f"==grid-{idx + 1}==")
183184

185+
init = (0, 0)
186+
goal = (len(grid) - 1, len(grid[0]) - 1)
184187
for elem in grid:
185188
print(elem)
189+
190+
print("------")
191+
192+
greedy_bf = GreedyBestFirst(grid, init, goal)
193+
path = greedy_bf.search()
194+
if path:
195+
for pos_x, pos_y in path:
196+
grid[pos_x][pos_y] = 2
197+
198+
for elem in grid:
199+
print(elem)

0 commit comments

Comments
 (0)