diff --git a/data_structures/binary_tree/avl_tree.py b/data_structures/binary_tree/avl_tree.py index 3362610b9303..6227ba876957 100644 --- a/data_structures/binary_tree/avl_tree.py +++ b/data_structures/binary_tree/avl_tree.py @@ -8,84 +8,86 @@ import math import random +import unittest +from typing import Any class my_queue: - def __init__(self): + def __init__(self) -> None: self.data = [] self.head = 0 self.tail = 0 - def is_empty(self): + def is_empty(self) -> bool: return self.head == self.tail - def push(self, data): + def push(self, data: Any) -> None: self.data.append(data) self.tail = self.tail + 1 - def pop(self): + def pop(self) -> Any: ret = self.data[self.head] self.head = self.head + 1 return ret - def count(self): + def count(self) -> int: return self.tail - self.head - def print(self): + def print(self) -> None: print(self.data) print("**************") print(self.data[self.head : self.tail]) class my_node: - def __init__(self, data): + def __init__(self, data: Any) -> None: self.data = data self.left = None self.right = None self.height = 1 - def get_data(self): + def get_data(self) -> Any: return self.data - def get_left(self): + def get_left(self) -> "my_node": return self.left - def get_right(self): + def get_right(self) -> "my_node": return self.right - def get_height(self): + def get_height(self) -> int: return self.height - def set_data(self, data): + def set_data(self, data: Any) -> None: self.data = data return - def set_left(self, node): + def set_left(self, node: "my_node") -> None: self.left = node return - def set_right(self, node): + def set_right(self, node: "my_node") -> None: self.right = node return - def set_height(self, height): + def set_height(self, height: int) -> None: self.height = height return -def get_height(node): +def get_height(node: "my_node") -> int: if node is None: return 0 return node.get_height() -def my_max(a, b): +def my_max(a: Any, b: Any) -> Any: if a > b: return a return b -def right_rotation(node): +def right_rotation(node: "my_node") -> "my_node": r""" A B / \ / \ @@ -107,7 +109,7 @@ def right_rotation(node): return ret -def left_rotation(node): +def left_rotation(node: "my_node") -> "my_node": """ a mirror symmetry rotation of the left_rotation """ @@ -122,7 +124,7 @@ def left_rotation(node): return ret -def lr_rotation(node): +def lr_rotation(node: "my_node") -> "my_node": r""" A A Br / \ / \ / \ @@ -137,12 +139,12 @@ def lr_rotation(node): return right_rotation(node) -def rl_rotation(node): +def rl_rotation(node: "my_node") -> "my_node": node.set_right(right_rotation(node.get_right())) return left_rotation(node) -def insert_node(node, data): +def insert_node(node: "my_node", data: Any) -> "my_node": if node is None: return my_node(data) if data < node.get_data(): @@ -168,19 +170,19 @@ def insert_node(node, data): return node -def get_rightMost(root): +def get_rightMost(root: "my_node") -> "my_node": while root.get_right() is not None: root = root.get_right() return root.get_data() -def get_leftMost(root): +def get_leftMost(root: "my_node") -> "my_node": while root.get_left() is not None: root = root.get_left() return root.get_data() -def del_node(root, data): +def del_node(root: "my_node", data: Any) -> "my_node": if root.get_data() == data: if root.get_left() is not None and root.get_right() is not None: temp_data = get_leftMost(root.get_right()) @@ -204,14 +206,14 @@ def del_node(root, data): if root is None: return root if get_height(root.get_right()) - get_height(root.get_left()) == 2: - if get_height(root.get_right().get_right()) > get_height( + if get_height(root.get_right().get_right()) >= get_height( root.get_right().get_left() ): root = left_rotation(root) else: root = rl_rotation(root) elif get_height(root.get_right()) - get_height(root.get_left()) == -2: - if get_height(root.get_left().get_left()) > get_height( + if get_height(root.get_left().get_left()) >= get_height( root.get_left().get_right() ): root = right_rotation(root) @@ -256,25 +258,27 @@ class AVLtree: ************************************* """ - def __init__(self): + def __init__(self) -> None: self.root = None - def get_height(self): - # print("yyy") + def get_height(self) -> int: return get_height(self.root) - def insert(self, data): + def insert(self, data: Any) -> None: print("insert:" + str(data)) self.root = insert_node(self.root, data) - def del_node(self, data): + def del_node(self, data: Any) -> None: print("delete:" + str(data)) if self.root is None: print("Tree is empty!") return self.root = del_node(self.root, data) - def __str__(self): # a level traversale, gives a more intuitive look on the tree + def __str__(self) -> str: + """ + A level traversale, gives a more intuitive look on the tree + """ output = "" q = my_queue() q.push(self.root) @@ -308,21 +312,51 @@ def __str__(self): # a level traversale, gives a more intuitive look on the tre return output -def _test(): - import doctest - - doctest.testmod() +class Test(unittest.TestCase): + def _is_balance(self, avl: AVLtree): + if avl.root is None: + return True + dfs = [avl.root] + while dfs: + now = dfs.pop() + if now.left: + left_height = now.left.height + dfs.append(now.left) + else: + left_height = 0 + if now.right: + right_height = now.right.height + dfs.append(now.right) + else: + right_height = 0 + if abs(left_height - right_height) > 1: + return False + return True + + def test_delete(self): + avl = AVLtree() + for i in [8, 7, 4, 3, 9, 10, 11, 13, 6, 0, 2, 12, 1, 14, 5]: + avl.insert(i) + self.assertTrue(self._is_balance(avl)) + + for v in [8, 7, 4, 3, 9, 10, 11, 13, 6, 0, 2, 12, 1, 14, 5]: + avl.del_node(v) + print(avl) + self.assertTrue(self._is_balance(avl)) + + def test_delete_random(self): + avl = AVLtree() + random.seed(0) + values = list(range(1000)) + random.shuffle(values) + for i in values: + avl.insert(i) + self.assertTrue(self._is_balance(avl)) + random.shuffle(values) + for i in values: + avl.del_node(i) + self.assertTrue(self._is_balance(avl)) if __name__ == "__main__": - _test() - t = AVLtree() - lst = list(range(10)) - random.shuffle(lst) - for i in lst: - t.insert(i) - print(str(t)) - random.shuffle(lst) - for i in lst: - t.del_node(i) - print(str(t)) + unittest.main()