Skip to content

Commit 3cd9201

Browse files
committed
Correct implementation and add tests for dfs and bfs
1 parent 40f65e8 commit 3cd9201

File tree

1 file changed

+193
-63
lines changed

1 file changed

+193
-63
lines changed

graphs/directed_and_undirected_weighted_graph.py

+193-63
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,18 @@ def __init__(self):
1414
# adding the weight is optional
1515
# handles repetition
1616
def add_pair(self, u, v, w=1):
17+
"""
18+
Adds a directed edge u->v with weight w.
19+
>>> dg = DirectedGraph()
20+
>>> dg.add_pair(-1,2)
21+
>>> dg.add_pair(1,3,5)
22+
>>> dg.add_pair(1,3,5)
23+
>>> dg.add_pair(1,3,6)
24+
>>> dg.all_nodes()
25+
[-1, 2, 1, 3]
26+
>>> dg.graph[1]
27+
[[5, 3], [6, 3]]
28+
"""
1729
if self.graph.get(u):
1830
if self.graph[u].count([w, v]) == 0:
1931
self.graph[u].append([w, v])
@@ -23,53 +35,94 @@ def add_pair(self, u, v, w=1):
2335
self.graph[v] = []
2436

2537
def all_nodes(self):
38+
"""
39+
Returns list of all nodes in the graph.
40+
>>> dg = DirectedGraph()
41+
>>> dg.all_nodes()
42+
[]
43+
>>> dg.add_pair(1,1)
44+
>>> dg.all_nodes()
45+
[1]
46+
>>> dg.add_pair(2,3,3)
47+
>>> dg.all_nodes()
48+
[1, 2, 3]
49+
"""
2650
return list(self.graph)
2751

2852
# handles if the input does not exist
2953
def remove_pair(self, u, v):
54+
"""
55+
Removes all edges u->v if it exists.
56+
>>> dg = DirectedGraph()
57+
>>> dg.remove_pair(1,2) # silently exits
58+
>>> dg.add_pair(0,5,2)
59+
>>> dg.graph[0]
60+
[[2, 5]]
61+
>>> dg.remove_pair(5,0)
62+
>>> dg.graph[0]
63+
[[2, 5]]
64+
>>> dg.remove_pair(0,5)
65+
>>> dg.graph[0]
66+
[]
67+
"""
3068
if self.graph.get(u):
3169
for _ in self.graph[u]:
3270
if _[1] == v:
3371
self.graph[u].remove(_)
3472

3573
# if no destination is meant the default value is -1
3674
def dfs(self, s=-2, d=-1):
37-
if s == d:
38-
return []
75+
"""
76+
Performs depth first search from s to find d.
77+
Returns the path s->d as a list.
78+
Returns dfs from s if d is not found
79+
>>> dg = DirectedGraph()
80+
>>> dg.dfs()
81+
[]
82+
>>> dg.add_pair(1,1)
83+
>>> dg.dfs(1,1)
84+
[1]
85+
>>> dg = DirectedGraph()
86+
>>> dg.add_pair(0,1)
87+
>>> dg.add_pair(0,2)
88+
>>> dg.add_pair(1,3)
89+
>>> dg.add_pair(1,4)
90+
>>> dg.add_pair(1,5)
91+
>>> dg.add_pair(2,5)
92+
>>> dg.add_pair(5,6)
93+
>>> dg.dfs(0,6)
94+
[0, 2, 5, 6]
95+
>>> dg.dfs(1,6)
96+
[1, 5, 6]
97+
>>> dg.dfs()
98+
[0, 2, 5, 6, 1, 4, 3]
99+
>>> dg.dfs(1,0)
100+
[1, 5, 6, 4, 3]
101+
"""
39102
stack = []
40103
visited = []
41104
if s == -2:
42-
s = next(iter(self.graph))
43-
stack.append(s)
44-
visited.append(s)
45-
ss = s
46-
47-
while True:
48-
# check if there is any non isolated nodes
49-
if len(self.graph[s]) != 0:
50-
ss = s
51-
for node in self.graph[s]:
52-
if visited.count(node[1]) < 1:
53-
if node[1] == d:
54-
visited.append(d)
55-
return visited
56-
else:
57-
stack.append(node[1])
58-
visited.append(node[1])
59-
ss = node[1]
60-
break
61-
62-
# check if all the children are visited
63-
if s == ss:
64-
stack.pop()
65-
if len(stack) != 0:
66-
s = stack[len(stack) - 1]
105+
if self.graph.get(s,None):
106+
pass # -2 is a node
107+
elif len(self.graph) > 0:
108+
s = next(iter(self.graph))
67109
else:
68-
s = ss
110+
return [] # Graph empty
111+
stack.append(s)
69112

70-
# check if se have reached the starting point
71-
if len(stack) == 0:
72-
return visited
113+
# Run dfs
114+
while len(stack) > 0:
115+
s = stack.pop()
116+
visited.append(s)
117+
# If reached d, return
118+
if s==d:
119+
break
120+
121+
# add not visited child nodes to stack
122+
for _,ss in self.graph[s]:
123+
if visited.count(ss) < 1:
124+
stack.append(ss)
125+
return visited
73126

74127
# c is the count of nodes you want and if you leave it or pass -1 to the function
75128
# the count will be random from 10 to 10000
@@ -84,12 +137,42 @@ def fill_graph_randomly(self, c=-1):
84137
self.add_pair(i, n, 1)
85138

86139
def bfs(self, s=-2):
140+
"""
141+
Performs breadth first search from s
142+
Returns list.
143+
>>> dg = DirectedGraph()
144+
>>> dg.bfs()
145+
[]
146+
>>> dg.add_pair(1,1)
147+
>>> dg.bfs(1)
148+
[1]
149+
>>> dg = DirectedGraph()
150+
>>> dg.add_pair(0,1)
151+
>>> dg.add_pair(0,2)
152+
>>> dg.add_pair(1,3)
153+
>>> dg.add_pair(1,4)
154+
>>> dg.add_pair(1,5)
155+
>>> dg.add_pair(2,5)
156+
>>> dg.add_pair(5,6)
157+
>>> dg.bfs(0)
158+
[0, 1, 2, 3, 4, 5, 6]
159+
>>> dg.bfs(1)
160+
[1, 3, 4, 5, 6]
161+
>>> dg.bfs()
162+
[0, 1, 2, 3, 4, 5, 6]
163+
"""
87164
d = deque()
88165
visited = []
89166
if s == -2:
90-
s = next(iter(self.graph))
167+
if self.graph.get(s,None):
168+
pass # -2 is a node
169+
elif len(self.graph) > 0:
170+
s = next(iter(self.graph))
171+
else:
172+
return [] # Graph empty
91173
d.append(s)
92174
visited.append(s)
175+
# Run bfs
93176
while d:
94177
s = d.popleft()
95178
if len(self.graph[s]) != 0:
@@ -300,42 +383,60 @@ def remove_pair(self, u, v):
300383

301384
# if no destination is meant the default value is -1
302385
def dfs(self, s=-2, d=-1):
303-
if s == d:
304-
return []
386+
"""
387+
Performs depth first search from s to find d.
388+
Returns the path s->d as a list.
389+
Returns dfs from s if d is not found
390+
>>> ug = Graph()
391+
>>> ug.dfs()
392+
[]
393+
>>> ug.add_pair(1,1)
394+
>>> ug.dfs(1,1)
395+
[1]
396+
>>> ug = Graph()
397+
>>> ug.add_pair(0,1)
398+
>>> ug.add_pair(0,2)
399+
>>> ug.add_pair(1,3)
400+
>>> ug.add_pair(1,4)
401+
>>> ug.add_pair(1,5)
402+
>>> ug.add_pair(2,5)
403+
>>> ug.add_pair(5,6)
404+
>>> ug.dfs(0,6)
405+
[0, 2, 5, 6]
406+
>>> ug.dfs(1,6)
407+
[1, 5, 6]
408+
>>> ug.dfs()
409+
[0, 2, 5, 6, 1, 4, 3]
410+
>>> ug.dfs(1,0)
411+
[1, 5, 6, 2, 0]
412+
"""
305413
stack = []
306414
visited = []
307415
if s == -2:
308-
s = next(iter(self.graph))
416+
if self.graph.get(s,None):
417+
pass # -2 is a node
418+
elif len(self.graph) > 0:
419+
s = next(iter(self.graph))
420+
else:
421+
return [] # Graph empty
309422
stack.append(s)
310-
visited.append(s)
311-
ss = s
312-
313-
while True:
314-
# check if there is any non isolated nodes
315-
if len(self.graph[s]) != 0:
316-
ss = s
317-
for node in self.graph[s]:
318-
if visited.count(node[1]) < 1:
319-
if node[1] == d:
320-
visited.append(d)
321-
return visited
322-
else:
323-
stack.append(node[1])
324-
visited.append(node[1])
325-
ss = node[1]
326-
break
327423

328-
# check if all the children are visited
329-
if s == ss:
330-
stack.pop()
331-
if len(stack) != 0:
332-
s = stack[len(stack) - 1]
424+
# Run dfs
425+
while len(stack) > 0:
426+
s = stack.pop()
427+
if visited.count(s) == 1:
428+
continue
333429
else:
334-
s = ss
335-
336-
# check if se have reached the starting point
337-
if len(stack) == 0:
338-
return visited
430+
visited.append(s)
431+
# If reached d, return
432+
if s==d:
433+
break
434+
435+
# add not visited child nodes to stack
436+
for _,ss in self.graph[s]:
437+
if visited.count(ss) < 1:
438+
stack.append(ss)
439+
return visited
339440

340441
# c is the count of nodes you want and if you leave it or pass -1 to the function
341442
# the count will be random from 10 to 10000
@@ -350,10 +451,39 @@ def fill_graph_randomly(self, c=-1):
350451
self.add_pair(i, n, 1)
351452

352453
def bfs(self, s=-2):
454+
"""
455+
Performs breadth first search from s
456+
Returns list.
457+
>>> ug = Graph()
458+
>>> ug.bfs()
459+
[]
460+
>>> ug.add_pair(1,1)
461+
>>> ug.bfs(1)
462+
[1]
463+
>>> ug = Graph()
464+
>>> ug.add_pair(0,1)
465+
>>> ug.add_pair(0,2)
466+
>>> ug.add_pair(1,3)
467+
>>> ug.add_pair(1,4)
468+
>>> ug.add_pair(1,5)
469+
>>> ug.add_pair(2,5)
470+
>>> ug.add_pair(5,6)
471+
>>> ug.bfs(0)
472+
[0, 1, 2, 3, 4, 5, 6]
473+
>>> ug.bfs(1)
474+
[1, 0, 3, 4, 5, 2, 6]
475+
>>> ug.bfs()
476+
[0, 1, 2, 3, 4, 5, 6]
477+
"""
353478
d = deque()
354479
visited = []
355480
if s == -2:
356-
s = next(iter(self.graph))
481+
if self.graph.get(s,None):
482+
pass # -2 is a node
483+
elif len(self.graph) > 0:
484+
s = next(iter(self.graph))
485+
else:
486+
return [] # Graph empty
357487
d.append(s)
358488
visited.append(s)
359489
while d:

0 commit comments

Comments
 (0)