Skip to content

types: Update binary search tree typehints #7197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 15, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 45 additions & 34 deletions data_structures/binary_tree/binary_search_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@
A binary search Tree
"""

from collections.abc import Iterable
from typing import Any


class Node:
def __init__(self, value, parent):
def __init__(self, value: int | None = None):
self.value = value
self.parent = parent # Added in order to delete a node easier
self.left = None
self.right = None
self.parent: Node | None = None # Added in order to delete a node easier
self.left: Node | None = None
self.right: Node | None = None

def __repr__(self):
def __repr__(self) -> str:
from pprint import pformat

if self.left is None and self.right is None:
Expand All @@ -19,16 +22,16 @@ def __repr__(self):


class BinarySearchTree:
def __init__(self, root=None):
def __init__(self, root: Node | None = None):
self.root = root

def __str__(self):
def __str__(self) -> str:
"""
Return a string of all the Nodes using in order traversal
"""
return str(self.root)

def __reassign_nodes(self, node, new_children):
def __reassign_nodes(self, node: Node, new_children: Node | None) -> None:
if new_children is not None: # reset its kids
new_children.parent = node.parent
if node.parent is not None: # reset its parent
Expand All @@ -37,23 +40,27 @@ def __reassign_nodes(self, node, new_children):
else:
node.parent.left = new_children
else:
self.root = new_children
self.root = None

def is_right(self, node):
return node == node.parent.right
def is_right(self, node: Node) -> bool:
if node.parent and node.parent.right:
return node == node.parent.right
return False

def empty(self):
def empty(self) -> bool:
return self.root is None

def __insert(self, value):
def __insert(self, value) -> None:
"""
Insert a new node in Binary Search Tree with value label
"""
new_node = Node(value, None) # create a new Node
new_node = Node(value) # create a new Node
if self.empty(): # if Tree is empty
self.root = new_node # set its root
else: # Tree is not empty
parent_node = self.root # from root
if parent_node is None:
return None
while True: # While we don't get to a leaf
if value < parent_node.value: # We go left
if parent_node.left is None:
Expand All @@ -69,12 +76,11 @@ def __insert(self, value):
parent_node = parent_node.right
new_node.parent = parent_node

def insert(self, *values):
def insert(self, *values) -> None:
for value in values:
self.__insert(value)
return self

def search(self, value):
def search(self, value) -> Node | None:
if self.empty():
raise IndexError("Warning: Tree is empty! please use another.")
else:
Expand All @@ -84,30 +90,35 @@ def search(self, value):
node = node.left if value < node.value else node.right
return node

def get_max(self, node=None):
def get_max(self, node: Node | None = None) -> Node | None:
"""
We go deep on the right branch
"""
if node is None:
if self.root is None:
return None
node = self.root

if not self.empty():
while node.right is not None:
node = node.right
return node

def get_min(self, node=None):
def get_min(self, node: Node | None = None) -> Node | None:
"""
We go deep on the left branch
"""
if node is None:
node = self.root
if self.root is None:
return None
if not self.empty():
node = self.root
while node.left is not None:
node = node.left
return node

def remove(self, value):
def remove(self, value: int) -> None:
node = self.search(value) # Look for the node with that label
if node is not None:
if node.left is None and node.right is None: # If it has no children
Expand All @@ -120,18 +131,18 @@ def remove(self, value):
tmp_node = self.get_max(
node.left
) # Gets the max value of the left branch
self.remove(tmp_node.value)
self.remove(tmp_node.value) # type: ignore
node.value = (
tmp_node.value
tmp_node.value # type: ignore
) # Assigns the value to the node to delete and keep tree structure

def preorder_traverse(self, node):
def preorder_traverse(self, node: Node | None) -> Iterable:
if node is not None:
yield node # Preorder Traversal
yield from self.preorder_traverse(node.left)
yield from self.preorder_traverse(node.right)

def traversal_tree(self, traversal_function=None):
def traversal_tree(self, traversal_function=None) -> Any:
"""
This function traversal the tree.
You can pass a function to traversal the tree as needed by client code
Expand All @@ -141,7 +152,7 @@ def traversal_tree(self, traversal_function=None):
else:
return traversal_function(self.root)

def inorder(self, arr: list, node: Node):
def inorder(self, arr: list, node: Node | None) -> None:
"""Perform an inorder traversal and append values of the nodes to
a list named arr"""
if node:
Expand All @@ -151,22 +162,22 @@ def inorder(self, arr: list, node: Node):

def find_kth_smallest(self, k: int, node: Node) -> int:
"""Return the kth smallest element in a binary search tree"""
arr: list = []
arr: list[int] = []
self.inorder(arr, node) # append all values to list using inorder traversal
return arr[k - 1]


def postorder(curr_node):
def postorder(curr_node: Node | None) -> list[Node]:
"""
postOrder (left, right, self)
"""
node_list = list()
node_list = []
if curr_node is not None:
node_list = postorder(curr_node.left) + postorder(curr_node.right) + [curr_node]
return node_list


def binary_search_tree():
def binary_search_tree() -> None:
r"""
Example
8
Expand All @@ -177,7 +188,8 @@ def binary_search_tree():
/ \ /
4 7 13

>>> t = BinarySearchTree().insert(8, 3, 6, 1, 10, 14, 13, 4, 7)
>>> t = BinarySearchTree()
>>> t.insert(8, 3, 6, 1, 10, 14, 13, 4, 7)
>>> print(" ".join(repr(i.value) for i in t.traversal_tree()))
8 3 1 6 4 7 10 14 13
>>> print(" ".join(repr(i.value) for i in t.traversal_tree(postorder)))
Expand Down Expand Up @@ -206,8 +218,8 @@ def binary_search_tree():
print("The value -1 doesn't exist")

if not t.empty():
print("Max Value: ", t.get_max().value)
print("Min Value: ", t.get_min().value)
print("Max Value: ", t.get_max().value) # type: ignore
print("Min Value: ", t.get_min().value) # type: ignore

for i in testlist:
t.remove(i)
Expand All @@ -217,5 +229,4 @@ def binary_search_tree():
if __name__ == "__main__":
import doctest

doctest.testmod()
# binary_search_tree()
doctest.testmod(verbose=True)