2
2
A binary search Tree
3
3
"""
4
4
5
+ from collections .abc import Iterable
6
+ from typing import Any
7
+
5
8
6
9
class Node :
7
- def __init__ (self , value , parent ):
10
+ def __init__ (self , value : int | None = None ):
8
11
self .value = value
9
- self .parent = parent # Added in order to delete a node easier
10
- self .left = None
11
- self .right = None
12
+ self .parent : Node | None = None # Added in order to delete a node easier
13
+ self .left : Node | None = None
14
+ self .right : Node | None = None
12
15
13
- def __repr__ (self ):
16
+ def __repr__ (self ) -> str :
14
17
from pprint import pformat
15
18
16
19
if self .left is None and self .right is None :
@@ -19,16 +22,16 @@ def __repr__(self):
19
22
20
23
21
24
class BinarySearchTree :
22
- def __init__ (self , root = None ):
25
+ def __init__ (self , root : Node | None = None ):
23
26
self .root = root
24
27
25
- def __str__ (self ):
28
+ def __str__ (self ) -> str :
26
29
"""
27
30
Return a string of all the Nodes using in order traversal
28
31
"""
29
32
return str (self .root )
30
33
31
- def __reassign_nodes (self , node , new_children ) :
34
+ def __reassign_nodes (self , node : Node , new_children : Node | None ) -> None :
32
35
if new_children is not None : # reset its kids
33
36
new_children .parent = node .parent
34
37
if node .parent is not None : # reset its parent
@@ -37,23 +40,27 @@ def __reassign_nodes(self, node, new_children):
37
40
else :
38
41
node .parent .left = new_children
39
42
else :
40
- self .root = new_children
43
+ self .root = None
41
44
42
- def is_right (self , node ):
43
- return node == node .parent .right
45
+ def is_right (self , node : Node ) -> bool :
46
+ if node .parent and node .parent .right :
47
+ return node == node .parent .right
48
+ return False
44
49
45
- def empty (self ):
50
+ def empty (self ) -> bool :
46
51
return self .root is None
47
52
48
- def __insert (self , value ):
53
+ def __insert (self , value ) -> None :
49
54
"""
50
55
Insert a new node in Binary Search Tree with value label
51
56
"""
52
- new_node = Node (value , None ) # create a new Node
57
+ new_node = Node (value ) # create a new Node
53
58
if self .empty (): # if Tree is empty
54
59
self .root = new_node # set its root
55
60
else : # Tree is not empty
56
61
parent_node = self .root # from root
62
+ if parent_node is None :
63
+ return None
57
64
while True : # While we don't get to a leaf
58
65
if value < parent_node .value : # We go left
59
66
if parent_node .left is None :
@@ -69,12 +76,11 @@ def __insert(self, value):
69
76
parent_node = parent_node .right
70
77
new_node .parent = parent_node
71
78
72
- def insert (self , * values ):
79
+ def insert (self , * values ) -> None :
73
80
for value in values :
74
81
self .__insert (value )
75
- return self
76
82
77
- def search (self , value ):
83
+ def search (self , value ) -> Node | None :
78
84
if self .empty ():
79
85
raise IndexError ("Warning: Tree is empty! please use another." )
80
86
else :
@@ -84,30 +90,35 @@ def search(self, value):
84
90
node = node .left if value < node .value else node .right
85
91
return node
86
92
87
- def get_max (self , node = None ):
93
+ def get_max (self , node : Node | None = None ) -> Node | None :
88
94
"""
89
95
We go deep on the right branch
90
96
"""
91
97
if node is None :
98
+ if self .root is None :
99
+ return None
92
100
node = self .root
101
+
93
102
if not self .empty ():
94
103
while node .right is not None :
95
104
node = node .right
96
105
return node
97
106
98
- def get_min (self , node = None ):
107
+ def get_min (self , node : Node | None = None ) -> Node | None :
99
108
"""
100
109
We go deep on the left branch
101
110
"""
102
111
if node is None :
103
112
node = self .root
113
+ if self .root is None :
114
+ return None
104
115
if not self .empty ():
105
116
node = self .root
106
117
while node .left is not None :
107
118
node = node .left
108
119
return node
109
120
110
- def remove (self , value ) :
121
+ def remove (self , value : int ) -> None :
111
122
node = self .search (value ) # Look for the node with that label
112
123
if node is not None :
113
124
if node .left is None and node .right is None : # If it has no children
@@ -120,18 +131,18 @@ def remove(self, value):
120
131
tmp_node = self .get_max (
121
132
node .left
122
133
) # Gets the max value of the left branch
123
- self .remove (tmp_node .value )
134
+ self .remove (tmp_node .value ) # type: ignore
124
135
node .value = (
125
- tmp_node .value
136
+ tmp_node .value # type: ignore
126
137
) # Assigns the value to the node to delete and keep tree structure
127
138
128
- def preorder_traverse (self , node ) :
139
+ def preorder_traverse (self , node : Node | None ) -> Iterable :
129
140
if node is not None :
130
141
yield node # Preorder Traversal
131
142
yield from self .preorder_traverse (node .left )
132
143
yield from self .preorder_traverse (node .right )
133
144
134
- def traversal_tree (self , traversal_function = None ):
145
+ def traversal_tree (self , traversal_function = None ) -> Any :
135
146
"""
136
147
This function traversal the tree.
137
148
You can pass a function to traversal the tree as needed by client code
@@ -141,7 +152,7 @@ def traversal_tree(self, traversal_function=None):
141
152
else :
142
153
return traversal_function (self .root )
143
154
144
- def inorder (self , arr : list , node : Node ) :
155
+ def inorder (self , arr : list , node : Node | None ) -> None :
145
156
"""Perform an inorder traversal and append values of the nodes to
146
157
a list named arr"""
147
158
if node :
@@ -151,12 +162,12 @@ def inorder(self, arr: list, node: Node):
151
162
152
163
def find_kth_smallest (self , k : int , node : Node ) -> int :
153
164
"""Return the kth smallest element in a binary search tree"""
154
- arr : list = []
165
+ arr : list [ int ] = []
155
166
self .inorder (arr , node ) # append all values to list using inorder traversal
156
167
return arr [k - 1 ]
157
168
158
169
159
- def postorder (curr_node ) :
170
+ def postorder (curr_node : Node | None ) -> list [ Node ] :
160
171
"""
161
172
postOrder (left, right, self)
162
173
"""
@@ -166,7 +177,7 @@ def postorder(curr_node):
166
177
return node_list
167
178
168
179
169
- def binary_search_tree ():
180
+ def binary_search_tree () -> None :
170
181
r"""
171
182
Example
172
183
8
@@ -177,7 +188,8 @@ def binary_search_tree():
177
188
/ \ /
178
189
4 7 13
179
190
180
- >>> t = BinarySearchTree().insert(8, 3, 6, 1, 10, 14, 13, 4, 7)
191
+ >>> t = BinarySearchTree()
192
+ >>> t.insert(8, 3, 6, 1, 10, 14, 13, 4, 7)
181
193
>>> print(" ".join(repr(i.value) for i in t.traversal_tree()))
182
194
8 3 1 6 4 7 10 14 13
183
195
>>> print(" ".join(repr(i.value) for i in t.traversal_tree(postorder)))
@@ -206,8 +218,8 @@ def binary_search_tree():
206
218
print ("The value -1 doesn't exist" )
207
219
208
220
if not t .empty ():
209
- print ("Max Value: " , t .get_max ().value )
210
- print ("Min Value: " , t .get_min ().value )
221
+ print ("Max Value: " , t .get_max ().value ) # type: ignore
222
+ print ("Min Value: " , t .get_min ().value ) # type: ignore
211
223
212
224
for i in testlist :
213
225
t .remove (i )
@@ -217,5 +229,4 @@ def binary_search_tree():
217
229
if __name__ == "__main__" :
218
230
import doctest
219
231
220
- doctest .testmod ()
221
- # binary_search_tree()
232
+ doctest .testmod (verbose = True )
0 commit comments