Skip to content

Commit 2595cf0

Browse files
[mypy] Add/fix type annotations for binary trees in data structures (#4085)
* fix mypy: data_structures:binary_tree * mypy --strict for binary_trees in data_structures * fix pre-commit Co-authored-by: LiHao <[email protected]>
1 parent 97b6ca2 commit 2595cf0

File tree

3 files changed

+84
-57
lines changed

3 files changed

+84
-57
lines changed

Diff for: data_structures/binary_tree/binary_search_tree_recursive.py

+62-37
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,22 @@
88
python binary_search_tree_recursive.py
99
"""
1010
import unittest
11+
from typing import Iterator, Optional
1112

1213

1314
class Node:
14-
def __init__(self, label: int, parent):
15+
def __init__(self, label: int, parent: Optional["Node"]) -> None:
1516
self.label = label
1617
self.parent = parent
17-
self.left = None
18-
self.right = None
18+
self.left: Optional[Node] = None
19+
self.right: Optional[Node] = None
1920

2021

2122
class BinarySearchTree:
22-
def __init__(self):
23-
self.root = None
23+
def __init__(self) -> None:
24+
self.root: Optional[Node] = None
2425

25-
def empty(self):
26+
def empty(self) -> None:
2627
"""
2728
Empties the tree
2829
@@ -46,7 +47,7 @@ def is_empty(self) -> bool:
4647
"""
4748
return self.root is None
4849

49-
def put(self, label: int):
50+
def put(self, label: int) -> None:
5051
"""
5152
Put a new node in the tree
5253
@@ -65,7 +66,9 @@ def put(self, label: int):
6566
"""
6667
self.root = self._put(self.root, label)
6768

68-
def _put(self, node: Node, label: int, parent: Node = None) -> Node:
69+
def _put(
70+
self, node: Optional[Node], label: int, parent: Optional[Node] = None
71+
) -> Node:
6972
if node is None:
7073
node = Node(label, parent)
7174
else:
@@ -95,7 +98,7 @@ def search(self, label: int) -> Node:
9598
"""
9699
return self._search(self.root, label)
97100

98-
def _search(self, node: Node, label: int) -> Node:
101+
def _search(self, node: Optional[Node], label: int) -> Node:
99102
if node is None:
100103
raise Exception(f"Node with label {label} does not exist")
101104
else:
@@ -106,7 +109,7 @@ def _search(self, node: Node, label: int) -> Node:
106109

107110
return node
108111

109-
def remove(self, label: int):
112+
def remove(self, label: int) -> None:
110113
"""
111114
Removes a node in the tree
112115
@@ -122,22 +125,22 @@ def remove(self, label: int):
122125
Exception: Node with label 3 does not exist
123126
"""
124127
node = self.search(label)
125-
if not node.right and not node.left:
126-
self._reassign_nodes(node, None)
127-
elif not node.right and node.left:
128-
self._reassign_nodes(node, node.left)
129-
elif node.right and not node.left:
130-
self._reassign_nodes(node, node.right)
131-
else:
128+
if node.right and node.left:
132129
lowest_node = self._get_lowest_node(node.right)
133130
lowest_node.left = node.left
134131
lowest_node.right = node.right
135132
node.left.parent = lowest_node
136133
if node.right:
137134
node.right.parent = lowest_node
138135
self._reassign_nodes(node, lowest_node)
136+
elif not node.right and node.left:
137+
self._reassign_nodes(node, node.left)
138+
elif node.right and not node.left:
139+
self._reassign_nodes(node, node.right)
140+
else:
141+
self._reassign_nodes(node, None)
139142

140-
def _reassign_nodes(self, node: Node, new_children: Node):
143+
def _reassign_nodes(self, node: Node, new_children: Optional[Node]) -> None:
141144
if new_children:
142145
new_children.parent = node.parent
143146

@@ -192,7 +195,7 @@ def get_max_label(self) -> int:
192195
>>> t.get_max_label()
193196
10
194197
"""
195-
if self.is_empty():
198+
if self.root is None:
196199
raise Exception("Binary search tree is empty")
197200

198201
node = self.root
@@ -216,7 +219,7 @@ def get_min_label(self) -> int:
216219
>>> t.get_min_label()
217220
8
218221
"""
219-
if self.is_empty():
222+
if self.root is None:
220223
raise Exception("Binary search tree is empty")
221224

222225
node = self.root
@@ -225,7 +228,7 @@ def get_min_label(self) -> int:
225228

226229
return node.label
227230

228-
def inorder_traversal(self) -> list:
231+
def inorder_traversal(self) -> Iterator[Node]:
229232
"""
230233
Return the inorder traversal of the tree
231234
@@ -241,13 +244,13 @@ def inorder_traversal(self) -> list:
241244
"""
242245
return self._inorder_traversal(self.root)
243246

244-
def _inorder_traversal(self, node: Node) -> list:
247+
def _inorder_traversal(self, node: Optional[Node]) -> Iterator[Node]:
245248
if node is not None:
246249
yield from self._inorder_traversal(node.left)
247250
yield node
248251
yield from self._inorder_traversal(node.right)
249252

250-
def preorder_traversal(self) -> list:
253+
def preorder_traversal(self) -> Iterator[Node]:
251254
"""
252255
Return the preorder traversal of the tree
253256
@@ -263,7 +266,7 @@ def preorder_traversal(self) -> list:
263266
"""
264267
return self._preorder_traversal(self.root)
265268

266-
def _preorder_traversal(self, node: Node) -> list:
269+
def _preorder_traversal(self, node: Optional[Node]) -> Iterator[Node]:
267270
if node is not None:
268271
yield node
269272
yield from self._preorder_traversal(node.left)
@@ -272,7 +275,7 @@ def _preorder_traversal(self, node: Node) -> list:
272275

273276
class BinarySearchTreeTest(unittest.TestCase):
274277
@staticmethod
275-
def _get_binary_search_tree():
278+
def _get_binary_search_tree() -> BinarySearchTree:
276279
r"""
277280
8
278281
/ \
@@ -298,14 +301,15 @@ def _get_binary_search_tree():
298301

299302
return t
300303

301-
def test_put(self):
304+
def test_put(self) -> None:
302305
t = BinarySearchTree()
303306
assert t.is_empty()
304307

305308
t.put(8)
306309
r"""
307310
8
308311
"""
312+
assert t.root is not None
309313
assert t.root.parent is None
310314
assert t.root.label == 8
311315

@@ -315,6 +319,7 @@ def test_put(self):
315319
\
316320
10
317321
"""
322+
assert t.root.right is not None
318323
assert t.root.right.parent == t.root
319324
assert t.root.right.label == 10
320325

@@ -324,6 +329,7 @@ def test_put(self):
324329
/ \
325330
3 10
326331
"""
332+
assert t.root.left is not None
327333
assert t.root.left.parent == t.root
328334
assert t.root.left.label == 3
329335

@@ -335,6 +341,7 @@ def test_put(self):
335341
\
336342
6
337343
"""
344+
assert t.root.left.right is not None
338345
assert t.root.left.right.parent == t.root.left
339346
assert t.root.left.right.label == 6
340347

@@ -346,13 +353,14 @@ def test_put(self):
346353
/ \
347354
1 6
348355
"""
356+
assert t.root.left.left is not None
349357
assert t.root.left.left.parent == t.root.left
350358
assert t.root.left.left.label == 1
351359

352360
with self.assertRaises(Exception):
353361
t.put(1)
354362

355-
def test_search(self):
363+
def test_search(self) -> None:
356364
t = self._get_binary_search_tree()
357365

358366
node = t.search(6)
@@ -364,7 +372,7 @@ def test_search(self):
364372
with self.assertRaises(Exception):
365373
t.search(2)
366374

367-
def test_remove(self):
375+
def test_remove(self) -> None:
368376
t = self._get_binary_search_tree()
369377

370378
t.remove(13)
@@ -379,6 +387,9 @@ def test_remove(self):
379387
\
380388
5
381389
"""
390+
assert t.root is not None
391+
assert t.root.right is not None
392+
assert t.root.right.right is not None
382393
assert t.root.right.right.right is None
383394
assert t.root.right.right.left is None
384395

@@ -394,6 +405,9 @@ def test_remove(self):
394405
\
395406
5
396407
"""
408+
assert t.root.left is not None
409+
assert t.root.left.right is not None
410+
assert t.root.left.right.left is not None
397411
assert t.root.left.right.right is None
398412
assert t.root.left.right.left.label == 4
399413

@@ -407,6 +421,8 @@ def test_remove(self):
407421
\
408422
5
409423
"""
424+
assert t.root.left.left is not None
425+
assert t.root.left.right.right is not None
410426
assert t.root.left.left.label == 1
411427
assert t.root.left.right.label == 4
412428
assert t.root.left.right.right.label == 5
@@ -422,6 +438,7 @@ def test_remove(self):
422438
/ \ \
423439
1 5 14
424440
"""
441+
assert t.root is not None
425442
assert t.root.left.label == 4
426443
assert t.root.left.right.label == 5
427444
assert t.root.left.left.label == 1
@@ -437,13 +454,15 @@ def test_remove(self):
437454
/ \
438455
1 14
439456
"""
457+
assert t.root.left is not None
458+
assert t.root.left.left is not None
440459
assert t.root.left.label == 5
441460
assert t.root.left.right is None
442461
assert t.root.left.left.label == 1
443462
assert t.root.left.parent == t.root
444463
assert t.root.left.left.parent == t.root.left
445464

446-
def test_remove_2(self):
465+
def test_remove_2(self) -> None:
447466
t = self._get_binary_search_tree()
448467

449468
t.remove(3)
@@ -456,6 +475,12 @@ def test_remove_2(self):
456475
/ \ /
457476
5 7 13
458477
"""
478+
assert t.root is not None
479+
assert t.root.left is not None
480+
assert t.root.left.left is not None
481+
assert t.root.left.right is not None
482+
assert t.root.left.right.left is not None
483+
assert t.root.left.right.right is not None
459484
assert t.root.left.label == 4
460485
assert t.root.left.right.label == 6
461486
assert t.root.left.left.label == 1
@@ -466,25 +491,25 @@ def test_remove_2(self):
466491
assert t.root.left.left.parent == t.root.left
467492
assert t.root.left.right.left.parent == t.root.left.right
468493

469-
def test_empty(self):
494+
def test_empty(self) -> None:
470495
t = self._get_binary_search_tree()
471496
t.empty()
472497
assert t.root is None
473498

474-
def test_is_empty(self):
499+
def test_is_empty(self) -> None:
475500
t = self._get_binary_search_tree()
476501
assert not t.is_empty()
477502

478503
t.empty()
479504
assert t.is_empty()
480505

481-
def test_exists(self):
506+
def test_exists(self) -> None:
482507
t = self._get_binary_search_tree()
483508

484509
assert t.exists(6)
485510
assert not t.exists(-1)
486511

487-
def test_get_max_label(self):
512+
def test_get_max_label(self) -> None:
488513
t = self._get_binary_search_tree()
489514

490515
assert t.get_max_label() == 14
@@ -493,7 +518,7 @@ def test_get_max_label(self):
493518
with self.assertRaises(Exception):
494519
t.get_max_label()
495520

496-
def test_get_min_label(self):
521+
def test_get_min_label(self) -> None:
497522
t = self._get_binary_search_tree()
498523

499524
assert t.get_min_label() == 1
@@ -502,20 +527,20 @@ def test_get_min_label(self):
502527
with self.assertRaises(Exception):
503528
t.get_min_label()
504529

505-
def test_inorder_traversal(self):
530+
def test_inorder_traversal(self) -> None:
506531
t = self._get_binary_search_tree()
507532

508533
inorder_traversal_nodes = [i.label for i in t.inorder_traversal()]
509534
assert inorder_traversal_nodes == [1, 3, 4, 5, 6, 7, 8, 10, 13, 14]
510535

511-
def test_preorder_traversal(self):
536+
def test_preorder_traversal(self) -> None:
512537
t = self._get_binary_search_tree()
513538

514539
preorder_traversal_nodes = [i.label for i in t.preorder_traversal()]
515540
assert preorder_traversal_nodes == [8, 3, 1, 6, 4, 5, 7, 10, 14, 13]
516541

517542

518-
def binary_search_tree_example():
543+
def binary_search_tree_example() -> None:
519544
r"""
520545
Example
521546
8

Diff for: data_structures/binary_tree/lazy_segment_tree.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import math
4+
from typing import List, Union
45

56

67
class SegmentTree:
@@ -37,7 +38,7 @@ def right(self, idx: int) -> int:
3738
return idx * 2 + 1
3839

3940
def build(
40-
self, idx: int, left_element: int, right_element: int, A: list[int]
41+
self, idx: int, left_element: int, right_element: int, A: List[int]
4142
) -> None:
4243
if left_element == right_element:
4344
self.segment_tree[idx] = A[left_element - 1]
@@ -88,7 +89,7 @@ def update(
8889
# query with O(lg n)
8990
def query(
9091
self, idx: int, left_element: int, right_element: int, a: int, b: int
91-
) -> int:
92+
) -> Union[int, float]:
9293
"""
9394
query(1, 1, size, a, b) for query max of [a,b]
9495
>>> A = [1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]
@@ -118,8 +119,8 @@ def query(
118119
q2 = self.query(self.right(idx), mid + 1, right_element, a, b)
119120
return max(q1, q2)
120121

121-
def __str__(self) -> None:
122-
return [self.query(1, 1, self.size, i, i) for i in range(1, self.size + 1)]
122+
def __str__(self) -> str:
123+
return str([self.query(1, 1, self.size, i, i) for i in range(1, self.size + 1)])
123124

124125

125126
if __name__ == "__main__":

0 commit comments

Comments
 (0)