Skip to content

Created Union-Find algorithm #186

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 26, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
77 changes: 77 additions & 0 deletions data_structures/UnionFind/tests_union_find.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from union_find import UnionFind
import unittest


class TestUnionFind(unittest.TestCase):
def test_init_with_valid_size(self):
uf = UnionFind(5)
self.assertEqual(uf.size, 5)

def test_init_with_invalid_size(self):
with self.assertRaises(ValueError):
uf = UnionFind(0)

with self.assertRaises(ValueError):
uf = UnionFind(-5)

def test_union_with_valid_values(self):
uf = UnionFind(10)

for i in range(11):
for j in range(11):
uf.union(i, j)

def test_union_with_invalid_values(self):
uf = UnionFind(10)

with self.assertRaises(ValueError):
uf.union(-1, 1)

with self.assertRaises(ValueError):
uf.union(11, 1)

def test_same_set_with_valid_values(self):
uf = UnionFind(10)

for i in range(11):
for j in range(11):
if i == j:
self.assertTrue(uf.same_set(i, j))
else:
self.assertFalse(uf.same_set(i, j))

uf.union(1, 2)
self.assertTrue(uf.same_set(1, 2))

uf.union(3, 4)
self.assertTrue(uf.same_set(3, 4))

self.assertFalse(uf.same_set(1, 3))
self.assertFalse(uf.same_set(1, 4))
self.assertFalse(uf.same_set(2, 3))
self.assertFalse(uf.same_set(2, 4))

uf.union(1, 3)
self.assertTrue(uf.same_set(1, 3))
self.assertTrue(uf.same_set(1, 4))
self.assertTrue(uf.same_set(2, 3))
self.assertTrue(uf.same_set(2, 4))

uf.union(4, 10)
self.assertTrue(uf.same_set(1, 10))
self.assertTrue(uf.same_set(2, 10))
self.assertTrue(uf.same_set(3, 10))
self.assertTrue(uf.same_set(4, 10))

def test_same_set_with_invalid_values(self):
uf = UnionFind(10)

with self.assertRaises(ValueError):
uf.same_set(-1, 1)

with self.assertRaises(ValueError):
uf.same_set(11, 0)


if __name__ == '__main__':
unittest.main()
87 changes: 87 additions & 0 deletions data_structures/UnionFind/union_find.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
class UnionFind():
"""
https://en.wikipedia.org/wiki/Disjoint-set_data_structure

The union-find is a disjoint-set data structure

You can merge two sets and tell if one set belongs to
another one.

It's used on the Kruskal Algorithm
(https://en.wikipedia.org/wiki/Kruskal%27s_algorithm)

The elements are in range [0, size]
"""
def __init__(self, size):
if size <= 0:
raise ValueError("size should be greater than 0")

self.size = size

# The below plus 1 is because we are using elements
# in range [0, size]. It makes more sense.

# Every set begins with only itself
self.root = [i for i in range(size+1)]

# This is used for heuristic union by rank
self.weight = [0 for i in range(size+1)]

def union(self, u, v):
"""
Union of the sets u and v.
Complexity: log(n).
Amortized complexity: < 5 (it's very fast).
"""

self._validate_element_range(u, "u")
self._validate_element_range(v, "v")

if u == v:
return

# Using union by rank will guarantee the
# log(n) complexity
rootu = self._root(u)
rootv = self._root(v)
weight_u = self.weight[rootu]
weight_v = self.weight[rootv]
if weight_u >= weight_v:
self.root[rootv] = rootu
if weight_u == weight_v:
self.weight[rootu] += 1
else:
self.root[rootu] = rootv

def same_set(self, u, v):
"""
Return true if the elements u and v belongs to
the same set
"""

self._validate_element_range(u, "u")
self._validate_element_range(v, "v")

return self._root(u) == self._root(v)

def _root(self, u):
"""
Get the element set root.
This uses the heuristic path compression
See wikipedia article for more details.
"""

if u != self.root[u]:
self.root[u] = self._root(self.root[u])

return self.root[u]

def _validate_element_range(self, u, element_name):
"""
Raises ValueError if element is not in range
"""
if u < 0 or u > self.size:
msg = ("element {0} with value {1} "
"should be in range [0~{2}]")\
.format(element_name, u, self.size)
raise ValueError(msg)