Skip to content

Commit 022eb49

Browse files
luoheng23stokhos
authored andcommitted
Add disjoint set (TheAlgorithms#1194)
* Add disjoint set * disjoint set: add doctest, make code more Pythonic * disjoint set: replace x.p with x.parent * disjoint set: add test and refercence
1 parent 9275a85 commit 022eb49

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""
2+
disjoint set
3+
Reference: https://en.wikipedia.org/wiki/Disjoint-set_data_structure
4+
"""
5+
6+
7+
class Node:
8+
def __init__(self, data):
9+
self.data = data
10+
11+
12+
def make_set(x):
13+
"""
14+
make x as a set.
15+
"""
16+
# rank is the distance from x to its' parent
17+
# root's rank is 0
18+
x.rank = 0
19+
x.parent = x
20+
21+
22+
def union_set(x, y):
23+
"""
24+
union two sets.
25+
set with bigger rank should be parent, so that the
26+
disjoint set tree will be more flat.
27+
"""
28+
x, y = find_set(x), find_set(y)
29+
if x.rank > y.rank:
30+
y.parent = x
31+
else:
32+
x.parent = y
33+
if x.rank == y.rank:
34+
y.rank += 1
35+
36+
37+
def find_set(x):
38+
"""
39+
return the parent of x
40+
"""
41+
if x != x.parent:
42+
x.parent = find_set(x.parent)
43+
return x.parent
44+
45+
46+
def find_python_set(node: Node) -> set:
47+
"""
48+
Return a Python Standard Library set that contains i.
49+
"""
50+
sets = ({0, 1, 2}, {3, 4, 5})
51+
for s in sets:
52+
if node.data in s:
53+
return s
54+
raise ValueError(f"{node.data} is not in {sets}")
55+
56+
57+
def test_disjoint_set():
58+
"""
59+
>>> test_disjoint_set()
60+
"""
61+
vertex = [Node(i) for i in range(6)]
62+
for v in vertex:
63+
make_set(v)
64+
65+
union_set(vertex[0], vertex[1])
66+
union_set(vertex[1], vertex[2])
67+
union_set(vertex[3], vertex[4])
68+
union_set(vertex[3], vertex[5])
69+
70+
for node0 in vertex:
71+
for node1 in vertex:
72+
if find_python_set(node0).isdisjoint(find_python_set(node1)):
73+
assert find_set(node0) != find_set(node1)
74+
else:
75+
assert find_set(node0) == find_set(node1)
76+
77+
78+
if __name__ == "__main__":
79+
test_disjoint_set()

0 commit comments

Comments
 (0)