Skip to content

[mypy] Add/fix type annotations for avl tree in data structures #4214

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 87 additions & 68 deletions data_structures/binary_tree/avl_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,84 +8,85 @@

import math
import random
from typing import Any, List, Optional


class my_queue:
def __init__(self):
self.data = []
self.head = 0
self.tail = 0
def __init__(self) -> None:
self.data: List[Any] = []
self.head: int = 0
self.tail: int = 0

def is_empty(self):
def is_empty(self) -> bool:
return self.head == self.tail

def push(self, data):
def push(self, data: Any) -> None:
self.data.append(data)
self.tail = self.tail + 1

def pop(self):
def pop(self) -> Any:
ret = self.data[self.head]
self.head = self.head + 1
return ret

def count(self):
def count(self) -> int:
return self.tail - self.head

def print(self):
def print(self) -> None:
print(self.data)
print("**************")
print(self.data[self.head : self.tail])


class my_node:
def __init__(self, data):
def __init__(self, data: Any) -> None:
self.data = data
self.left = None
self.right = None
self.height = 1
self.left: Optional[my_node] = None
self.right: Optional[my_node] = None
self.height: int = 1

def get_data(self):
def get_data(self) -> Any:
return self.data

def get_left(self):
def get_left(self) -> Optional["my_node"]:
return self.left

def get_right(self):
def get_right(self) -> Optional["my_node"]:
return self.right

def get_height(self):
def get_height(self) -> int:
return self.height

def set_data(self, data):
def set_data(self, data: Any) -> None:
self.data = data
return

def set_left(self, node):
def set_left(self, node: Optional["my_node"]) -> None:
self.left = node
return

def set_right(self, node):
def set_right(self, node: Optional["my_node"]) -> None:
self.right = node
return

def set_height(self, height):
def set_height(self, height: int) -> None:
self.height = height
return


def get_height(node):
def get_height(node: Optional["my_node"]) -> int:
if node is None:
return 0
return node.get_height()


def my_max(a, b):
def my_max(a: int, b: int) -> int:
if a > b:
return a
return b


def right_rotation(node):
def right_rotation(node: my_node) -> my_node:
r"""
A B
/ \ / \
Expand All @@ -98,6 +99,7 @@ def right_rotation(node):
"""
print("left rotation node:", node.get_data())
ret = node.get_left()
assert ret is not None
node.set_left(ret.get_right())
ret.set_right(node)
h1 = my_max(get_height(node.get_right()), get_height(node.get_left())) + 1
Expand All @@ -107,12 +109,13 @@ def right_rotation(node):
return ret


def left_rotation(node):
def left_rotation(node: my_node) -> my_node:
"""
a mirror symmetry rotation of the left_rotation
"""
print("right rotation node:", node.get_data())
ret = node.get_right()
assert ret is not None
node.set_right(ret.get_left())
ret.set_left(node)
h1 = my_max(get_height(node.get_right()), get_height(node.get_left())) + 1
Expand All @@ -122,7 +125,7 @@ def left_rotation(node):
return ret


def lr_rotation(node):
def lr_rotation(node: my_node) -> my_node:
r"""
A A Br
/ \ / \ / \
Expand All @@ -133,33 +136,41 @@ def lr_rotation(node):
UB Bl
RR = right_rotation LR = left_rotation
"""
node.set_left(left_rotation(node.get_left()))
left_child = node.get_left()
assert left_child is not None
node.set_left(left_rotation(left_child))
return right_rotation(node)


def rl_rotation(node):
node.set_right(right_rotation(node.get_right()))
def rl_rotation(node: my_node) -> my_node:
right_child = node.get_right()
assert right_child is not None
node.set_right(right_rotation(right_child))
return left_rotation(node)


def insert_node(node, data):
def insert_node(node: Optional["my_node"], data: Any) -> Optional["my_node"]:
if node is None:
return my_node(data)
if data < node.get_data():
node.set_left(insert_node(node.get_left(), data))
if (
get_height(node.get_left()) - get_height(node.get_right()) == 2
): # an unbalance detected
left_child = node.get_left()
assert left_child is not None
if (
data < node.get_left().get_data()
data < left_child.get_data()
): # new node is the left child of the left child
node = right_rotation(node)
else:
node = lr_rotation(node)
else:
node.set_right(insert_node(node.get_right(), data))
if get_height(node.get_right()) - get_height(node.get_left()) == 2:
if data < node.get_right().get_data():
right_child = node.get_right()
assert right_child is not None
if data < right_child.get_data():
node = rl_rotation(node)
else:
node = left_rotation(node)
Expand All @@ -168,52 +179,59 @@ def insert_node(node, data):
return node


def get_rightMost(root):
while root.get_right() is not None:
root = root.get_right()
def get_rightMost(root: my_node) -> Any:
while True:
right_child = root.get_right()
if right_child is None:
break
root = right_child
return root.get_data()


def get_leftMost(root):
while root.get_left() is not None:
root = root.get_left()
def get_leftMost(root: my_node) -> Any:
while True:
left_child = root.get_left()
if left_child is None:
break
root = left_child
return root.get_data()


def del_node(root, data):
def del_node(root: my_node, data: Any) -> Optional["my_node"]:
left_child = root.get_left()
right_child = root.get_right()
if root.get_data() == data:
if root.get_left() is not None and root.get_right() is not None:
temp_data = get_leftMost(root.get_right())
if left_child is not None and right_child is not None:
temp_data = get_leftMost(right_child)
root.set_data(temp_data)
root.set_right(del_node(root.get_right(), temp_data))
elif root.get_left() is not None:
root = root.get_left()
root.set_right(del_node(right_child, temp_data))
elif left_child is not None:
root = left_child
elif right_child is not None:
root = right_child
else:
root = root.get_right()
return None
elif root.get_data() > data:
if root.get_left() is None:
if left_child is None:
print("No such data")
return root
else:
root.set_left(del_node(root.get_left(), data))
elif root.get_data() < data:
if root.get_right() is None:
root.set_left(del_node(left_child, data))
else: # root.get_data() < data
if right_child is None:
return root
else:
root.set_right(del_node(root.get_right(), data))
if root is None:
return root
if get_height(root.get_right()) - get_height(root.get_left()) == 2:
if get_height(root.get_right().get_right()) > get_height(
root.get_right().get_left()
):
root.set_right(del_node(right_child, data))

if get_height(right_child) - get_height(left_child) == 2:
assert right_child is not None
if get_height(right_child.get_right()) > get_height(right_child.get_left()):
root = left_rotation(root)
else:
root = rl_rotation(root)
elif get_height(root.get_right()) - get_height(root.get_left()) == -2:
if get_height(root.get_left().get_left()) > get_height(
root.get_left().get_right()
):
elif get_height(right_child) - get_height(left_child) == -2:
assert left_child is not None
if get_height(left_child.get_left()) > get_height(left_child.get_right()):
root = right_rotation(root)
else:
root = lr_rotation(root)
Expand Down Expand Up @@ -256,25 +274,26 @@ class AVLtree:
*************************************
"""

def __init__(self):
self.root = None
def __init__(self) -> None:
self.root: Optional[my_node] = None

def get_height(self):
# print("yyy")
def get_height(self) -> int:
return get_height(self.root)

def insert(self, data):
def insert(self, data: Any) -> None:
print("insert:" + str(data))
self.root = insert_node(self.root, data)

def del_node(self, data):
def del_node(self, data: Any) -> None:
print("delete:" + str(data))
if self.root is None:
print("Tree is empty!")
return
self.root = del_node(self.root, data)

def __str__(self): # a level traversale, gives a more intuitive look on the tree
def __str__(
self,
) -> str: # a level traversale, gives a more intuitive look on the tree
output = ""
q = my_queue()
q.push(self.root)
Expand Down Expand Up @@ -308,7 +327,7 @@ def __str__(self): # a level traversale, gives a more intuitive look on the tre
return output


def _test():
def _test() -> None:
import doctest

doctest.testmod()
Expand Down