Skip to content

Commit 0590d73

Browse files
authored
[mypy] Fix type annotations in wavelet_tree.py (#5641)
* [mypy] Fix type annotations for wavelet_tree.py * fix a typo
1 parent 61e1dd2 commit 0590d73

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

Diff for: data_structures/binary_tree/wavelet_tree.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __repr__(self) -> str:
3131
return f"min_value: {self.minn}, max_value: {self.maxx}"
3232

3333

34-
def build_tree(arr: list[int]) -> Node:
34+
def build_tree(arr: list[int]) -> Node | None:
3535
"""
3636
Builds the tree for arr and returns the root
3737
of the constructed tree
@@ -51,7 +51,10 @@ def build_tree(arr: list[int]) -> Node:
5151
then recursively build trees for left_arr and right_arr
5252
"""
5353
pivot = (root.minn + root.maxx) // 2
54-
left_arr, right_arr = [], []
54+
55+
left_arr: list[int] = []
56+
right_arr: list[int] = []
57+
5558
for index, num in enumerate(arr):
5659
if num <= pivot:
5760
left_arr.append(num)
@@ -63,7 +66,7 @@ def build_tree(arr: list[int]) -> Node:
6366
return root
6467

6568

66-
def rank_till_index(node: Node, num: int, index: int) -> int:
69+
def rank_till_index(node: Node | None, num: int, index: int) -> int:
6770
"""
6871
Returns the number of occurrences of num in interval [0, index] in the list
6972
@@ -79,7 +82,7 @@ def rank_till_index(node: Node, num: int, index: int) -> int:
7982
>>> rank_till_index(root, 0, 9)
8083
1
8184
"""
82-
if index < 0:
85+
if index < 0 or node is None:
8386
return 0
8487
# Leaf node cases
8588
if node.minn == node.maxx:
@@ -93,7 +96,7 @@ def rank_till_index(node: Node, num: int, index: int) -> int:
9396
return rank_till_index(node.right, num, index - node.map_left[index])
9497

9598

96-
def rank(node: Node, num: int, start: int, end: int) -> int:
99+
def rank(node: Node | None, num: int, start: int, end: int) -> int:
97100
"""
98101
Returns the number of occurrences of num in interval [start, end] in the list
99102
@@ -114,7 +117,7 @@ def rank(node: Node, num: int, start: int, end: int) -> int:
114117
return rank_till_end - rank_before_start
115118

116119

117-
def quantile(node: Node, index: int, start: int, end: int) -> int:
120+
def quantile(node: Node | None, index: int, start: int, end: int) -> int:
118121
"""
119122
Returns the index'th smallest element in interval [start, end] in the list
120123
index is 0-indexed
@@ -129,7 +132,7 @@ def quantile(node: Node, index: int, start: int, end: int) -> int:
129132
>>> quantile(root, 4, 2, 5)
130133
-1
131134
"""
132-
if index > (end - start) or start > end:
135+
if index > (end - start) or start > end or node is None:
133136
return -1
134137
# Leaf node case
135138
if node.minn == node.maxx:
@@ -155,10 +158,10 @@ def quantile(node: Node, index: int, start: int, end: int) -> int:
155158

156159

157160
def range_counting(
158-
node: Node, start: int, end: int, start_num: int, end_num: int
161+
node: Node | None, start: int, end: int, start_num: int, end_num: int
159162
) -> int:
160163
"""
161-
Returns the number of elememts in range [start_num, end_num]
164+
Returns the number of elements in range [start_num, end_num]
162165
in interval [start, end] in the list
163166
164167
>>> root = build_tree(test_array)
@@ -175,6 +178,7 @@ def range_counting(
175178
"""
176179
if (
177180
start > end
181+
or node is None
178182
or start_num > end_num
179183
or node.minn > end_num
180184
or node.maxx < start_num

0 commit comments

Comments
 (0)