Skip to content

[mypy] Fix type annotations in wavelet_tree.py #5641

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 28, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions data_structures/binary_tree/wavelet_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __repr__(self) -> str:
return f"min_value: {self.minn}, max_value: {self.maxx}"


def build_tree(arr: list[int]) -> Node:
def build_tree(arr: list[int]) -> Node | None:
"""
Builds the tree for arr and returns the root
of the constructed tree
Expand All @@ -51,7 +51,10 @@ def build_tree(arr: list[int]) -> Node:
then recursively build trees for left_arr and right_arr
"""
pivot = (root.minn + root.maxx) // 2
left_arr, right_arr = [], []

left_arr: list[int] = []
right_arr: list[int] = []

for index, num in enumerate(arr):
if num <= pivot:
left_arr.append(num)
Expand All @@ -63,7 +66,7 @@ def build_tree(arr: list[int]) -> Node:
return root


def rank_till_index(node: Node, num: int, index: int) -> int:
def rank_till_index(node: Node | None, num: int, index: int) -> int:
"""
Returns the number of occurrences of num in interval [0, index] in the list

Expand All @@ -79,7 +82,7 @@ def rank_till_index(node: Node, num: int, index: int) -> int:
>>> rank_till_index(root, 0, 9)
1
"""
if index < 0:
if index < 0 or node is None:
return 0
# Leaf node cases
if node.minn == node.maxx:
Expand All @@ -93,7 +96,7 @@ def rank_till_index(node: Node, num: int, index: int) -> int:
return rank_till_index(node.right, num, index - node.map_left[index])


def rank(node: Node, num: int, start: int, end: int) -> int:
def rank(node: Node | None, num: int, start: int, end: int) -> int:
"""
Returns the number of occurrences of num in interval [start, end] in the list

Expand All @@ -114,7 +117,7 @@ def rank(node: Node, num: int, start: int, end: int) -> int:
return rank_till_end - rank_before_start


def quantile(node: Node, index: int, start: int, end: int) -> int:
def quantile(node: Node | None, index: int, start: int, end: int) -> int:
"""
Returns the index'th smallest element in interval [start, end] in the list
index is 0-indexed
Expand All @@ -129,7 +132,7 @@ def quantile(node: Node, index: int, start: int, end: int) -> int:
>>> quantile(root, 4, 2, 5)
-1
"""
if index > (end - start) or start > end:
if index > (end - start) or start > end or node is None:
return -1
# Leaf node case
if node.minn == node.maxx:
Expand All @@ -155,10 +158,10 @@ def quantile(node: Node, index: int, start: int, end: int) -> int:


def range_counting(
node: Node, start: int, end: int, start_num: int, end_num: int
node: Node | None, start: int, end: int, start_num: int, end_num: int
) -> int:
"""
Returns the number of elememts in range [start_num, end_num]
Returns the number of elements in range [start_num, end_num]
in interval [start, end] in the list

>>> root = build_tree(test_array)
Expand All @@ -175,6 +178,7 @@ def range_counting(
"""
if (
start > end
or node is None
or start_num > end_num
or node.minn > end_num
or node.maxx < start_num
Expand Down