Skip to content

[mypy] Fix type annotations in data_structures/binary_tree/red_black_tree.py #5739

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 4, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 69 additions & 52 deletions data_structures/binary_tree/red_black_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def rotate_left(self) -> RedBlackTree:
"""
parent = self.parent
right = self.right
if right is None:
return self
self.right = right.left
if self.right:
self.right.parent = self
Expand All @@ -69,6 +71,8 @@ def rotate_right(self) -> RedBlackTree:
returns the new root to this subtree.
Performing one rotation can be done in O(1).
"""
if self.left is None:
return self
parent = self.parent
left = self.left
self.left = left.right
Expand Down Expand Up @@ -123,23 +127,30 @@ def _insert_repair(self) -> None:
if color(uncle) == 0:
if self.is_left() and self.parent.is_right():
self.parent.rotate_right()
self.right._insert_repair()
if self.right:
self.right._insert_repair()
elif self.is_right() and self.parent.is_left():
self.parent.rotate_left()
self.left._insert_repair()
if self.left:
self.left._insert_repair()
elif self.is_left():
self.grandparent.rotate_right()
self.parent.color = 0
self.parent.right.color = 1
if self.grandparent:
self.grandparent.rotate_right()
self.parent.color = 0
if self.parent.right:
self.parent.right.color = 1
else:
self.grandparent.rotate_left()
self.parent.color = 0
self.parent.left.color = 1
if self.grandparent:
self.grandparent.rotate_left()
self.parent.color = 0
if self.parent.left:
self.parent.left.color = 1
else:
self.parent.color = 0
uncle.color = 0
self.grandparent.color = 1
self.grandparent._insert_repair()
if uncle and self.grandparent:
uncle.color = 0
self.grandparent.color = 1
self.grandparent._insert_repair()

def remove(self, label: int) -> RedBlackTree:
"""Remove label from this tree."""
Expand All @@ -149,8 +160,9 @@ def remove(self, label: int) -> RedBlackTree:
# so we replace this node with the greatest one less than
# it and remove that.
value = self.left.get_max()
self.label = value
self.left.remove(value)
if value is not None:
self.label = value
self.left.remove(value)
else:
# This node has at most one non-None child, so we don't
# need to replace
Expand All @@ -160,10 +172,11 @@ def remove(self, label: int) -> RedBlackTree:
# The only way this happens to a node with one child
# is if both children are None leaves.
# We can just remove this node and call it a day.
if self.is_left():
self.parent.left = None
else:
self.parent.right = None
if self.parent:
if self.is_left():
self.parent.left = None
else:
self.parent.right = None
else:
# The node is black
if child is None:
Expand All @@ -188,7 +201,7 @@ def remove(self, label: int) -> RedBlackTree:
self.left.parent = self
if self.right:
self.right.parent = self
elif self.label > label:
elif self.label is not None and self.label > label:
if self.left:
self.left.remove(label)
else:
Expand All @@ -198,6 +211,13 @@ def remove(self, label: int) -> RedBlackTree:

def _remove_repair(self) -> None:
"""Repair the coloring of the tree that may have been messed up."""
if (
self.parent is None
or self.sibling is None
or self.parent.sibling is None
or self.grandparent is None
):
return
if color(self.sibling) == 1:
self.sibling.color = 0
self.parent.color = 1
Expand Down Expand Up @@ -231,7 +251,8 @@ def _remove_repair(self) -> None:
):
self.sibling.rotate_right()
self.sibling.color = 0
self.sibling.right.color = 1
if self.sibling.right:
self.sibling.right.color = 1
if (
self.is_right()
and color(self.sibling) == 0
Expand All @@ -240,7 +261,8 @@ def _remove_repair(self) -> None:
):
self.sibling.rotate_left()
self.sibling.color = 0
self.sibling.left.color = 1
if self.sibling.left:
self.sibling.left.color = 1
if (
self.is_left()
and color(self.sibling) == 0
Expand Down Expand Up @@ -275,29 +297,25 @@ def check_color_properties(self) -> bool:
"""
# I assume property 1 to hold because there is nothing that can
# make the color be anything other than 0 or 1.

# Property 2
if self.color:
# The root was red
print("Property 2")
return False

# Property 3 does not need to be checked, because None is assumed
# to be black and is all the leaves.

# Property 4
if not self.check_coloring():
print("Property 4")
return False

# Property 5
if self.black_height() is None:
print("Property 5")
return False
# All properties were met
return True

def check_coloring(self) -> None:
def check_coloring(self) -> bool:
"""A helper function to recursively check Property 4 of a
Red-Black Tree. See check_color_properties for more info.
"""
Expand All @@ -310,12 +328,12 @@ def check_coloring(self) -> None:
return False
return True

def black_height(self) -> int:
def black_height(self) -> int | None:
"""Returns the number of black nodes from this node to the
leaves of the tree, or None if there isn't one such value (the
tree is color incorrectly).
"""
if self is None:
if self is None or self.left is None or self.right is None:
# If we're already at a leaf, there is no path
return 1
left = RedBlackTree.black_height(self.left)
Expand All @@ -332,21 +350,21 @@ def black_height(self) -> int:

# Here are functions which are general to all binary search trees

def __contains__(self, label) -> bool:
def __contains__(self, label: int) -> bool:
"""Search through the tree for label, returning True iff it is
found somewhere in the tree.
Guaranteed to run in O(log(n)) time.
"""
return self.search(label) is not None

def search(self, label: int) -> RedBlackTree:
def search(self, label: int) -> RedBlackTree | None:
"""Search through the tree for label, returning its node if
it's found, and None otherwise.
This method is guaranteed to run in O(log(n)) time.
"""
if self.label == label:
return self
elif label > self.label:
elif self.label is not None and label > self.label:
if self.right is None:
return None
else:
Expand All @@ -357,12 +375,12 @@ def search(self, label: int) -> RedBlackTree:
else:
return self.left.search(label)

def floor(self, label: int) -> int:
def floor(self, label: int) -> int | None:
"""Returns the largest element in this tree which is at most label.
This method is guaranteed to run in O(log(n)) time."""
if self.label == label:
return self.label
elif self.label > label:
elif self.label is not None and self.label > label:
if self.left:
return self.left.floor(label)
else:
Expand All @@ -374,13 +392,13 @@ def floor(self, label: int) -> int:
return attempt
return self.label

def ceil(self, label: int) -> int:
def ceil(self, label: int) -> int | None:
"""Returns the smallest element in this tree which is at least label.
This method is guaranteed to run in O(log(n)) time.
"""
if self.label == label:
return self.label
elif self.label < label:
elif self.label is not None and self.label < label:
if self.right:
return self.right.ceil(label)
else:
Expand All @@ -392,7 +410,7 @@ def ceil(self, label: int) -> int:
return attempt
return self.label

def get_max(self) -> int:
def get_max(self) -> int | None:
"""Returns the largest element in this tree.
This method is guaranteed to run in O(log(n)) time.
"""
Expand All @@ -402,7 +420,7 @@ def get_max(self) -> int:
else:
return self.label

def get_min(self) -> int:
def get_min(self) -> int | None:
"""Returns the smallest element in this tree.
This method is guaranteed to run in O(log(n)) time.
"""
Expand All @@ -413,15 +431,15 @@ def get_min(self) -> int:
return self.label

@property
def grandparent(self) -> RedBlackTree:
def grandparent(self) -> RedBlackTree | None:
"""Get the current node's grandparent, or None if it doesn't exist."""
if self.parent is None:
return None
else:
return self.parent.parent

@property
def sibling(self) -> RedBlackTree:
def sibling(self) -> RedBlackTree | None:
"""Get the current node's sibling, or None if it doesn't exist."""
if self.parent is None:
return None
Expand All @@ -432,11 +450,15 @@ def sibling(self) -> RedBlackTree:

def is_left(self) -> bool:
"""Returns true iff this node is the left child of its parent."""
return self.parent and self.parent.left is self
if self.parent is None:
return False
return self.parent.left is self.parent.left is self

def is_right(self) -> bool:
"""Returns true iff this node is the right child of its parent."""
return self.parent and self.parent.right is self
if self.parent is None:
return False
return self.parent.right is self

def __bool__(self) -> bool:
return True
Expand All @@ -452,21 +474,21 @@ def __len__(self) -> int:
ln += len(self.right)
return ln

def preorder_traverse(self) -> Iterator[int]:
def preorder_traverse(self) -> Iterator[int | None]:
yield self.label
if self.left:
yield from self.left.preorder_traverse()
if self.right:
yield from self.right.preorder_traverse()

def inorder_traverse(self) -> Iterator[int]:
def inorder_traverse(self) -> Iterator[int | None]:
if self.left:
yield from self.left.inorder_traverse()
yield self.label
if self.right:
yield from self.right.inorder_traverse()

def postorder_traverse(self) -> Iterator[int]:
def postorder_traverse(self) -> Iterator[int | None]:
if self.left:
yield from self.left.postorder_traverse()
if self.right:
Expand All @@ -488,15 +510,17 @@ def __repr__(self) -> str:
indent=1,
)

def __eq__(self, other) -> bool:
def __eq__(self, other: object) -> bool:
"""Test if two trees are equal."""
if not isinstance(other, RedBlackTree):
return NotImplemented
if self.label == other.label:
return self.left == other.left and self.right == other.right
else:
return False


def color(node) -> int:
def color(node: RedBlackTree | None) -> int:
"""Returns the color of a node, allowing for None leaves."""
if node is None:
return 0
Expand Down Expand Up @@ -699,19 +723,12 @@ def main() -> None:
>>> pytests()
"""
print_results("Rotating right and left", test_rotations())

print_results("Inserting", test_insert())

print_results("Searching", test_insert_and_search())

print_results("Deleting", test_insert_delete())

print_results("Floor and ceil", test_floor_ceil())

print_results("Tree traversal", test_tree_traversal())

print_results("Tree traversal", test_tree_chaining())

print("Testing tree balancing...")
print("This should only be a few seconds.")
test_insertion_speed()
Expand Down