@@ -107,14 +107,12 @@ def __iter__(self) -> Iterator[int]:
107
107
"""
108
108
>>> list(Node(0))
109
109
[0]
110
- >>> list(Node(0, Node(-1), Node(1)))
110
+ >>> list(Node(0, Node(-1), Node(1), None ))
111
111
[-1, 0, 1]
112
112
"""
113
- if self .left :
114
- yield from self .left
113
+ yield from self .left or []
115
114
yield self .value
116
- if self .right :
117
- yield from self .right
115
+ yield from self .right or []
118
116
119
117
def __repr__ (self ) -> str :
120
118
from pprint import pformat
@@ -145,10 +143,10 @@ def __str__(self) -> str:
145
143
return str (self .root )
146
144
147
145
def __reassign_nodes (self , node : Node , new_children : Node | None ) -> None :
148
- if new_children is not None :
146
+ if new_children is not None : # reset its kids
149
147
new_children .parent = node .parent
150
- if node .parent is not None :
151
- if node .is_right :
148
+ if node .parent is not None : # reset its parent
149
+ if node .is_right : # If it is the right child
152
150
node .parent .right = new_children
153
151
else :
154
152
node .parent .left = new_children
@@ -169,37 +167,37 @@ def empty(self) -> bool:
169
167
"""
170
168
return not self .root
171
169
172
- def __insert (self , value : int ) -> None :
170
+ def __insert (self , value ) -> None :
173
171
"""
174
172
Insert a new node in Binary Search Tree with value label
175
173
"""
176
- new_node = Node (value )
177
- if self .empty ():
178
- self .root = new_node
179
- else :
180
- parent_node = self .root
181
- while True :
182
- if value < parent_node .value :
174
+ new_node = Node (value ) # create a new Node
175
+ if self .empty (): # if Tree is empty
176
+ self .root = new_node # set its root
177
+ else : # Tree is not empty
178
+ parent_node = self .root # from root
179
+ if parent_node is None :
180
+ return
181
+ while True : # While we don't get to a leaf
182
+ if value < parent_node .value : # We go left
183
183
if parent_node .left is None :
184
- parent_node .left = new_node
185
- new_node .parent = parent_node
184
+ parent_node .left = new_node # We insert the new node in a leaf
186
185
break
187
186
else :
188
187
parent_node = parent_node .left
188
+ elif parent_node .right is None :
189
+ parent_node .right = new_node
190
+ break
189
191
else :
190
- if parent_node .right is None :
191
- parent_node .right = new_node
192
- new_node .parent = parent_node
193
- break
194
- else :
195
- parent_node = parent_node .right
192
+ parent_node = parent_node .right
193
+ new_node .parent = parent_node
196
194
197
- def insert (self , * values : int ) -> Self :
195
+ def insert (self , * values ) -> Self :
198
196
for value in values :
199
197
self .__insert (value )
200
198
return self
201
199
202
- def search (self , value : int ) -> Node | None :
200
+ def search (self , value ) -> Node | None :
203
201
"""
204
202
>>> tree = BinarySearchTree().insert(10, 20, 30, 40, 50)
205
203
>>> tree.search(10)
@@ -223,32 +221,37 @@ def search(self, value: int) -> Node | None:
223
221
...
224
222
IndexError: Warning: Tree is empty! please use another.
225
223
"""
224
+
226
225
if self .empty ():
227
226
raise IndexError ("Warning: Tree is empty! please use another." )
228
- node = self .root
229
- while node is not None and node .value != value :
230
- node = node .left if value < node .value else node .right
231
- return node
227
+ else :
228
+ node = self .root
229
+ # use lazy evaluation here to avoid NoneType Attribute error
230
+ while node is not None and node .value is not value :
231
+ node = node .left if value < node .value else node .right
232
+ return node
232
233
233
234
def get_max (self , node : Node | None = None ) -> Node | None :
234
235
"""
235
236
We go deep on the right branch
236
237
237
238
>>> BinarySearchTree().insert(10, 20, 30, 40, 50).get_max()
238
239
50
239
- >>> BinarySearchTree().insert(-5, -1, 0, -0.3, -4.5).get_max()
240
- {'0': (-0.3, None)}
240
+ >>> BinarySearchTree().insert(-5, -1, 0.1 , -0.3, -4.5).get_max()
241
+ {'0.1 ': (-0.3, None)}
241
242
>>> BinarySearchTree().insert(1, 78.3, 30, 74.0, 1).get_max()
242
243
{'78.3': ({'30': (1, 74.0)}, None)}
243
244
>>> BinarySearchTree().insert(1, 783, 30, 740, 1).get_max()
244
245
{'783': ({'30': (1, 740)}, None)}
245
246
"""
246
247
if node is None :
247
- if self .empty () :
248
+ if self .root is None :
248
249
return None
249
250
node = self .root
250
- while node .right is not None :
251
- node = node .right
251
+
252
+ if not self .empty ():
253
+ while node .right is not None :
254
+ node = node .right
252
255
return node
253
256
254
257
def get_min (self , node : Node | None = None ) -> Node | None :
@@ -265,47 +268,54 @@ def get_min(self, node: Node | None = None) -> Node | None:
265
268
{'1': (None, {'783': ({'30': (1, 740)}, None)})}
266
269
"""
267
270
if node is None :
268
- if self .empty ():
269
- return None
270
271
node = self .root
271
- while node .left is not None :
272
- node = node .left
272
+ if self .root is None :
273
+ return None
274
+ if not self .empty ():
275
+ node = self .root
276
+ while node .left is not None :
277
+ node = node .left
273
278
return node
274
279
275
280
def remove (self , value : int ) -> None :
281
+ # Look for the node with that label
276
282
node = self .search (value )
277
283
if node is None :
278
- raise ValueError (f"Value { value } not found" )
284
+ msg = f"Value { value } not found"
285
+ raise ValueError (msg )
279
286
280
- if node .left is None and node .right is None :
287
+ if node .left is None and node .right is None : # If it has no children
281
288
self .__reassign_nodes (node , None )
282
- elif node .left is None :
289
+ elif node .left is None : # Has only right children
283
290
self .__reassign_nodes (node , node .right )
284
- elif node .right is None :
291
+ elif node .right is None : # Has only left children
285
292
self .__reassign_nodes (node , node .left )
286
293
else :
287
- predecessor = self .get_max (node .left )
288
- if predecessor :
289
- self .remove (predecessor .value )
290
- node .value = predecessor .value
291
-
292
- def preorder_traverse (self , node : Node | None ) -> Iterable [Node ]:
294
+ predecessor = self .get_max (
295
+ node .left
296
+ ) # Gets the max value of the left branch
297
+ self .remove (predecessor .value ) # type: ignore[union-attr]
298
+ node .value = (
299
+ predecessor .value # type: ignore[union-attr]
300
+ ) # Assigns the value to the node to delete and keep tree structure
301
+
302
+ def preorder_traverse (self , node : Node | None ) -> Iterable :
293
303
if node is not None :
294
- yield node
304
+ yield node # Preorder Traversal
295
305
yield from self .preorder_traverse (node .left )
296
306
yield from self .preorder_traverse (node .right )
297
307
298
308
def traversal_tree (self , traversal_function = None ) -> Any :
299
309
"""
300
- This function traverses the tree.
301
- You can pass a function to traverse the tree as needed by client code
310
+ This function traversal the tree.
311
+ You can pass a function to traversal the tree as needed by client code
302
312
"""
303
313
if traversal_function is None :
304
- return list ( self .preorder_traverse (self .root ) )
314
+ return self .preorder_traverse (self .root )
305
315
else :
306
316
return traversal_function (self .root )
307
317
308
- def inorder (self , arr : list [ int ] , node : Node | None ) -> None :
318
+ def inorder (self , arr : list , node : Node | None ) -> None :
309
319
"""Perform an inorder traversal and append values of the nodes to
310
320
a list named arr"""
311
321
if node :
@@ -316,10 +326,8 @@ def inorder(self, arr: list[int], node: Node | None) -> None:
316
326
def find_kth_smallest (self , k : int , node : Node ) -> int :
317
327
"""Return the kth smallest element in a binary search tree"""
318
328
arr : list [int ] = []
319
- self .inorder (arr , node )
320
- if 0 < k <= len (arr ):
321
- return arr [k - 1 ]
322
- raise IndexError ("k is out of bounds" )
329
+ self .inorder (arr , node ) # append all values to list using inorder traversal
330
+ return arr [k - 1 ]
323
331
324
332
325
333
def inorder (curr_node : Node | None ) -> list [Node ]:
@@ -338,4 +346,11 @@ def postorder(curr_node: Node | None) -> list[Node]:
338
346
"""
339
347
node_list = []
340
348
if curr_node is not None :
341
- node_list = postorder (curr_node .left )
349
+ node_list = postorder (curr_node .left ) + postorder (curr_node .right ) + [curr_node ]
350
+ return node_list
351
+
352
+
353
+ if __name__ == "__main__" :
354
+ import doctest
355
+
356
+ doctest .testmod (verbose = True )
0 commit comments