Skip to content

Commit c94e215

Browse files
types: Update binary search tree typehints (#7197)
* types: Update binary search tree typehints * refactor: Don't return `self` in `:meth:insert` * test: Fix failing doctests * Apply suggestions from code review Co-authored-by: Dhruv Manilawala <[email protected]>
1 parent 553624f commit c94e215

File tree

1 file changed

+44
-33
lines changed

1 file changed

+44
-33
lines changed

Diff for: data_structures/binary_tree/binary_search_tree.py

+44-33
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22
A binary search Tree
33
"""
44

5+
from collections.abc import Iterable
6+
from typing import Any
7+
58

69
class Node:
7-
def __init__(self, value, parent):
10+
def __init__(self, value: int | None = None):
811
self.value = value
9-
self.parent = parent # Added in order to delete a node easier
10-
self.left = None
11-
self.right = None
12+
self.parent: Node | None = None # Added in order to delete a node easier
13+
self.left: Node | None = None
14+
self.right: Node | None = None
1215

13-
def __repr__(self):
16+
def __repr__(self) -> str:
1417
from pprint import pformat
1518

1619
if self.left is None and self.right is None:
@@ -19,16 +22,16 @@ def __repr__(self):
1922

2023

2124
class BinarySearchTree:
22-
def __init__(self, root=None):
25+
def __init__(self, root: Node | None = None):
2326
self.root = root
2427

25-
def __str__(self):
28+
def __str__(self) -> str:
2629
"""
2730
Return a string of all the Nodes using in order traversal
2831
"""
2932
return str(self.root)
3033

31-
def __reassign_nodes(self, node, new_children):
34+
def __reassign_nodes(self, node: Node, new_children: Node | None) -> None:
3235
if new_children is not None: # reset its kids
3336
new_children.parent = node.parent
3437
if node.parent is not None: # reset its parent
@@ -37,23 +40,27 @@ def __reassign_nodes(self, node, new_children):
3740
else:
3841
node.parent.left = new_children
3942
else:
40-
self.root = new_children
43+
self.root = None
4144

42-
def is_right(self, node):
43-
return node == node.parent.right
45+
def is_right(self, node: Node) -> bool:
46+
if node.parent and node.parent.right:
47+
return node == node.parent.right
48+
return False
4449

45-
def empty(self):
50+
def empty(self) -> bool:
4651
return self.root is None
4752

48-
def __insert(self, value):
53+
def __insert(self, value) -> None:
4954
"""
5055
Insert a new node in Binary Search Tree with value label
5156
"""
52-
new_node = Node(value, None) # create a new Node
57+
new_node = Node(value) # create a new Node
5358
if self.empty(): # if Tree is empty
5459
self.root = new_node # set its root
5560
else: # Tree is not empty
5661
parent_node = self.root # from root
62+
if parent_node is None:
63+
return None
5764
while True: # While we don't get to a leaf
5865
if value < parent_node.value: # We go left
5966
if parent_node.left is None:
@@ -69,12 +76,11 @@ def __insert(self, value):
6976
parent_node = parent_node.right
7077
new_node.parent = parent_node
7178

72-
def insert(self, *values):
79+
def insert(self, *values) -> None:
7380
for value in values:
7481
self.__insert(value)
75-
return self
7682

77-
def search(self, value):
83+
def search(self, value) -> Node | None:
7884
if self.empty():
7985
raise IndexError("Warning: Tree is empty! please use another.")
8086
else:
@@ -84,30 +90,35 @@ def search(self, value):
8490
node = node.left if value < node.value else node.right
8591
return node
8692

87-
def get_max(self, node=None):
93+
def get_max(self, node: Node | None = None) -> Node | None:
8894
"""
8995
We go deep on the right branch
9096
"""
9197
if node is None:
98+
if self.root is None:
99+
return None
92100
node = self.root
101+
93102
if not self.empty():
94103
while node.right is not None:
95104
node = node.right
96105
return node
97106

98-
def get_min(self, node=None):
107+
def get_min(self, node: Node | None = None) -> Node | None:
99108
"""
100109
We go deep on the left branch
101110
"""
102111
if node is None:
103112
node = self.root
113+
if self.root is None:
114+
return None
104115
if not self.empty():
105116
node = self.root
106117
while node.left is not None:
107118
node = node.left
108119
return node
109120

110-
def remove(self, value):
121+
def remove(self, value: int) -> None:
111122
node = self.search(value) # Look for the node with that label
112123
if node is not None:
113124
if node.left is None and node.right is None: # If it has no children
@@ -120,18 +131,18 @@ def remove(self, value):
120131
tmp_node = self.get_max(
121132
node.left
122133
) # Gets the max value of the left branch
123-
self.remove(tmp_node.value)
134+
self.remove(tmp_node.value) # type: ignore
124135
node.value = (
125-
tmp_node.value
136+
tmp_node.value # type: ignore
126137
) # Assigns the value to the node to delete and keep tree structure
127138

128-
def preorder_traverse(self, node):
139+
def preorder_traverse(self, node: Node | None) -> Iterable:
129140
if node is not None:
130141
yield node # Preorder Traversal
131142
yield from self.preorder_traverse(node.left)
132143
yield from self.preorder_traverse(node.right)
133144

134-
def traversal_tree(self, traversal_function=None):
145+
def traversal_tree(self, traversal_function=None) -> Any:
135146
"""
136147
This function traversal the tree.
137148
You can pass a function to traversal the tree as needed by client code
@@ -141,7 +152,7 @@ def traversal_tree(self, traversal_function=None):
141152
else:
142153
return traversal_function(self.root)
143154

144-
def inorder(self, arr: list, node: Node):
155+
def inorder(self, arr: list, node: Node | None) -> None:
145156
"""Perform an inorder traversal and append values of the nodes to
146157
a list named arr"""
147158
if node:
@@ -151,12 +162,12 @@ def inorder(self, arr: list, node: Node):
151162

152163
def find_kth_smallest(self, k: int, node: Node) -> int:
153164
"""Return the kth smallest element in a binary search tree"""
154-
arr: list = []
165+
arr: list[int] = []
155166
self.inorder(arr, node) # append all values to list using inorder traversal
156167
return arr[k - 1]
157168

158169

159-
def postorder(curr_node):
170+
def postorder(curr_node: Node | None) -> list[Node]:
160171
"""
161172
postOrder (left, right, self)
162173
"""
@@ -166,7 +177,7 @@ def postorder(curr_node):
166177
return node_list
167178

168179

169-
def binary_search_tree():
180+
def binary_search_tree() -> None:
170181
r"""
171182
Example
172183
8
@@ -177,7 +188,8 @@ def binary_search_tree():
177188
/ \ /
178189
4 7 13
179190
180-
>>> t = BinarySearchTree().insert(8, 3, 6, 1, 10, 14, 13, 4, 7)
191+
>>> t = BinarySearchTree()
192+
>>> t.insert(8, 3, 6, 1, 10, 14, 13, 4, 7)
181193
>>> print(" ".join(repr(i.value) for i in t.traversal_tree()))
182194
8 3 1 6 4 7 10 14 13
183195
>>> print(" ".join(repr(i.value) for i in t.traversal_tree(postorder)))
@@ -206,8 +218,8 @@ def binary_search_tree():
206218
print("The value -1 doesn't exist")
207219

208220
if not t.empty():
209-
print("Max Value: ", t.get_max().value)
210-
print("Min Value: ", t.get_min().value)
221+
print("Max Value: ", t.get_max().value) # type: ignore
222+
print("Min Value: ", t.get_min().value) # type: ignore
211223

212224
for i in testlist:
213225
t.remove(i)
@@ -217,5 +229,4 @@ def binary_search_tree():
217229
if __name__ == "__main__":
218230
import doctest
219231

220-
doctest.testmod()
221-
# binary_search_tree()
232+
doctest.testmod(verbose=True)

0 commit comments

Comments
 (0)