Skip to content

Commit 7180f08

Browse files
Created Union-Find algorithm
1 parent 8e6db7a commit 7180f08

File tree

3 files changed

+164
-0
lines changed

3 files changed

+164
-0
lines changed

Diff for: data_structures/UnionFind/__init__.py

Whitespace-only changes.

Diff for: data_structures/UnionFind/tests_union_find.py

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from union_find import UnionFind
2+
import unittest
3+
4+
5+
class TestUnionFind(unittest.TestCase):
6+
def test_init_with_valid_size(self):
7+
uf = UnionFind(5)
8+
self.assertEqual(uf.size, 5)
9+
10+
def test_init_with_invalid_size(self):
11+
with self.assertRaises(ValueError):
12+
uf = UnionFind(0)
13+
14+
with self.assertRaises(ValueError):
15+
uf = UnionFind(-5)
16+
17+
def test_union_with_valid_values(self):
18+
uf = UnionFind(10)
19+
20+
for i in range(11):
21+
for j in range(11):
22+
uf.union(i, j)
23+
24+
def test_union_with_invalid_values(self):
25+
uf = UnionFind(10)
26+
27+
with self.assertRaises(ValueError):
28+
uf.union(-1, 1)
29+
30+
with self.assertRaises(ValueError):
31+
uf.union(11, 1)
32+
33+
def test_same_set_with_valid_values(self):
34+
uf = UnionFind(10)
35+
36+
for i in range(11):
37+
for j in range(11):
38+
if i == j:
39+
self.assertTrue(uf.same_set(i, j))
40+
else:
41+
self.assertFalse(uf.same_set(i, j))
42+
43+
uf.union(1, 2)
44+
self.assertTrue(uf.same_set(1, 2))
45+
46+
uf.union(3, 4)
47+
self.assertTrue(uf.same_set(3, 4))
48+
49+
self.assertFalse(uf.same_set(1, 3))
50+
self.assertFalse(uf.same_set(1, 4))
51+
self.assertFalse(uf.same_set(2, 3))
52+
self.assertFalse(uf.same_set(2, 4))
53+
54+
uf.union(1, 3)
55+
self.assertTrue(uf.same_set(1, 3))
56+
self.assertTrue(uf.same_set(1, 4))
57+
self.assertTrue(uf.same_set(2, 3))
58+
self.assertTrue(uf.same_set(2, 4))
59+
60+
uf.union(4, 10)
61+
self.assertTrue(uf.same_set(1, 10))
62+
self.assertTrue(uf.same_set(2, 10))
63+
self.assertTrue(uf.same_set(3, 10))
64+
self.assertTrue(uf.same_set(4, 10))
65+
66+
def test_same_set_with_invalid_values(self):
67+
uf = UnionFind(10)
68+
69+
with self.assertRaises(ValueError):
70+
uf.same_set(-1, 1)
71+
72+
with self.assertRaises(ValueError):
73+
uf.same_set(11, 0)
74+
75+
76+
if __name__ == '__main__':
77+
unittest.main()

Diff for: data_structures/UnionFind/union_find.py

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
class UnionFind():
2+
"""
3+
https://en.wikipedia.org/wiki/Disjoint-set_data_structure
4+
5+
The union-find is a disjoint-set data structure
6+
7+
You can merge two sets and tell if one set belongs to
8+
another one.
9+
10+
It's used on the Kruskal Algorithm
11+
(https://en.wikipedia.org/wiki/Kruskal%27s_algorithm)
12+
13+
The elements are in range [0, size]
14+
"""
15+
def __init__(self, size):
16+
if size <= 0:
17+
raise ValueError("size should be greater than 0")
18+
19+
self.size = size
20+
21+
# The below plus 1 is because we are using elements
22+
# in range [0, size]. It makes more sense.
23+
24+
# Every set begins with only itself
25+
self.root = [i for i in range(size+1)]
26+
27+
# This is used for heuristic union by rank
28+
self.weight = [0 for i in range(size+1)]
29+
30+
def union(self, u, v):
31+
"""
32+
Union of the sets u and v.
33+
Complexity: log(n).
34+
Amortized complexity: < 5 (it's very fast).
35+
"""
36+
37+
self._validate_element_range(u, "u")
38+
self._validate_element_range(v, "v")
39+
40+
if u == v:
41+
return
42+
43+
# Using union by rank will guarantee the
44+
# log(n) complexity
45+
rootu = self._root(u)
46+
rootv = self._root(v)
47+
weight_u = self.weight[rootu]
48+
weight_v = self.weight[rootv]
49+
if weight_u >= weight_v:
50+
self.root[rootv] = rootu
51+
if weight_u == weight_v:
52+
self.weight[rootu] += 1
53+
else:
54+
self.root[rootu] = rootv
55+
56+
def same_set(self, u, v):
57+
"""
58+
Return true if the elements u and v belongs to
59+
the same set
60+
"""
61+
62+
self._validate_element_range(u, "u")
63+
self._validate_element_range(v, "v")
64+
65+
return self._root(u) == self._root(v)
66+
67+
def _root(self, u):
68+
"""
69+
Get the element set root.
70+
This uses the heuristic path compression
71+
See wikipedia article for more details.
72+
"""
73+
74+
if u != self.root[u]:
75+
self.root[u] = self._root(self.root[u])
76+
77+
return self.root[u]
78+
79+
def _validate_element_range(self, u, element_name):
80+
"""
81+
Raises ValueError if element is not in range
82+
"""
83+
if u < 0 or u > self.size:
84+
msg = ("element {0} with value {1} "
85+
"should be in range [0~{2}]")\
86+
.format(element_name, u, self.size)
87+
raise ValueError(msg)

0 commit comments

Comments
 (0)