Skip to content

Commit 17ab836

Browse files
committed
shubhamvk03
1 parent 8ff00a8 commit 17ab836

File tree

1 file changed

+74
-59
lines changed

1 file changed

+74
-59
lines changed

data_structures/binary_tree/binary_search_tree.py

+74-59
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,12 @@ def __iter__(self) -> Iterator[int]:
107107
"""
108108
>>> list(Node(0))
109109
[0]
110-
>>> list(Node(0, Node(-1), Node(1)))
110+
>>> list(Node(0, Node(-1), Node(1), None))
111111
[-1, 0, 1]
112112
"""
113-
if self.left:
114-
yield from self.left
113+
yield from self.left or []
115114
yield self.value
116-
if self.right:
117-
yield from self.right
115+
yield from self.right or []
118116

119117
def __repr__(self) -> str:
120118
from pprint import pformat
@@ -145,10 +143,10 @@ def __str__(self) -> str:
145143
return str(self.root)
146144

147145
def __reassign_nodes(self, node: Node, new_children: Node | None) -> None:
148-
if new_children is not None:
146+
if new_children is not None: # reset its kids
149147
new_children.parent = node.parent
150-
if node.parent is not None:
151-
if node.is_right:
148+
if node.parent is not None: # reset its parent
149+
if node.is_right: # If it is the right child
152150
node.parent.right = new_children
153151
else:
154152
node.parent.left = new_children
@@ -169,37 +167,37 @@ def empty(self) -> bool:
169167
"""
170168
return not self.root
171169

172-
def __insert(self, value: int) -> None:
170+
def __insert(self, value) -> None:
173171
"""
174172
Insert a new node in Binary Search Tree with value label
175173
"""
176-
new_node = Node(value)
177-
if self.empty():
178-
self.root = new_node
179-
else:
180-
parent_node = self.root
181-
while True:
182-
if value < parent_node.value:
174+
new_node = Node(value) # create a new Node
175+
if self.empty(): # if Tree is empty
176+
self.root = new_node # set its root
177+
else: # Tree is not empty
178+
parent_node = self.root # from root
179+
if parent_node is None:
180+
return
181+
while True: # While we don't get to a leaf
182+
if value < parent_node.value: # We go left
183183
if parent_node.left is None:
184-
parent_node.left = new_node
185-
new_node.parent = parent_node
184+
parent_node.left = new_node # We insert the new node in a leaf
186185
break
187186
else:
188187
parent_node = parent_node.left
188+
elif parent_node.right is None:
189+
parent_node.right = new_node
190+
break
189191
else:
190-
if parent_node.right is None:
191-
parent_node.right = new_node
192-
new_node.parent = parent_node
193-
break
194-
else:
195-
parent_node = parent_node.right
192+
parent_node = parent_node.right
193+
new_node.parent = parent_node
196194

197-
def insert(self, *values: int) -> Self:
195+
def insert(self, *values) -> Self:
198196
for value in values:
199197
self.__insert(value)
200198
return self
201199

202-
def search(self, value: int) -> Node | None:
200+
def search(self, value) -> Node | None:
203201
"""
204202
>>> tree = BinarySearchTree().insert(10, 20, 30, 40, 50)
205203
>>> tree.search(10)
@@ -223,32 +221,37 @@ def search(self, value: int) -> Node | None:
223221
...
224222
IndexError: Warning: Tree is empty! please use another.
225223
"""
224+
226225
if self.empty():
227226
raise IndexError("Warning: Tree is empty! please use another.")
228-
node = self.root
229-
while node is not None and node.value != value:
230-
node = node.left if value < node.value else node.right
231-
return node
227+
else:
228+
node = self.root
229+
# use lazy evaluation here to avoid NoneType Attribute error
230+
while node is not None and node.value is not value:
231+
node = node.left if value < node.value else node.right
232+
return node
232233

233234
def get_max(self, node: Node | None = None) -> Node | None:
234235
"""
235236
We go deep on the right branch
236237
237238
>>> BinarySearchTree().insert(10, 20, 30, 40, 50).get_max()
238239
50
239-
>>> BinarySearchTree().insert(-5, -1, 0, -0.3, -4.5).get_max()
240-
{'0': (-0.3, None)}
240+
>>> BinarySearchTree().insert(-5, -1, 0.1, -0.3, -4.5).get_max()
241+
{'0.1': (-0.3, None)}
241242
>>> BinarySearchTree().insert(1, 78.3, 30, 74.0, 1).get_max()
242243
{'78.3': ({'30': (1, 74.0)}, None)}
243244
>>> BinarySearchTree().insert(1, 783, 30, 740, 1).get_max()
244245
{'783': ({'30': (1, 740)}, None)}
245246
"""
246247
if node is None:
247-
if self.empty():
248+
if self.root is None:
248249
return None
249250
node = self.root
250-
while node.right is not None:
251-
node = node.right
251+
252+
if not self.empty():
253+
while node.right is not None:
254+
node = node.right
252255
return node
253256

254257
def get_min(self, node: Node | None = None) -> Node | None:
@@ -265,47 +268,54 @@ def get_min(self, node: Node | None = None) -> Node | None:
265268
{'1': (None, {'783': ({'30': (1, 740)}, None)})}
266269
"""
267270
if node is None:
268-
if self.empty():
269-
return None
270271
node = self.root
271-
while node.left is not None:
272-
node = node.left
272+
if self.root is None:
273+
return None
274+
if not self.empty():
275+
node = self.root
276+
while node.left is not None:
277+
node = node.left
273278
return node
274279

275280
def remove(self, value: int) -> None:
281+
# Look for the node with that label
276282
node = self.search(value)
277283
if node is None:
278-
raise ValueError(f"Value {value} not found")
284+
msg = f"Value {value} not found"
285+
raise ValueError(msg)
279286

280-
if node.left is None and node.right is None:
287+
if node.left is None and node.right is None: # If it has no children
281288
self.__reassign_nodes(node, None)
282-
elif node.left is None:
289+
elif node.left is None: # Has only right children
283290
self.__reassign_nodes(node, node.right)
284-
elif node.right is None:
291+
elif node.right is None: # Has only left children
285292
self.__reassign_nodes(node, node.left)
286293
else:
287-
predecessor = self.get_max(node.left)
288-
if predecessor:
289-
self.remove(predecessor.value)
290-
node.value = predecessor.value
291-
292-
def preorder_traverse(self, node: Node | None) -> Iterable[Node]:
294+
predecessor = self.get_max(
295+
node.left
296+
) # Gets the max value of the left branch
297+
self.remove(predecessor.value) # type: ignore[union-attr]
298+
node.value = (
299+
predecessor.value # type: ignore[union-attr]
300+
) # Assigns the value to the node to delete and keep tree structure
301+
302+
def preorder_traverse(self, node: Node | None) -> Iterable:
293303
if node is not None:
294-
yield node
304+
yield node # Preorder Traversal
295305
yield from self.preorder_traverse(node.left)
296306
yield from self.preorder_traverse(node.right)
297307

298308
def traversal_tree(self, traversal_function=None) -> Any:
299309
"""
300-
This function traverses the tree.
301-
You can pass a function to traverse the tree as needed by client code
310+
This function traversal the tree.
311+
You can pass a function to traversal the tree as needed by client code
302312
"""
303313
if traversal_function is None:
304-
return list(self.preorder_traverse(self.root))
314+
return self.preorder_traverse(self.root)
305315
else:
306316
return traversal_function(self.root)
307317

308-
def inorder(self, arr: list[int], node: Node | None) -> None:
318+
def inorder(self, arr: list, node: Node | None) -> None:
309319
"""Perform an inorder traversal and append values of the nodes to
310320
a list named arr"""
311321
if node:
@@ -316,10 +326,8 @@ def inorder(self, arr: list[int], node: Node | None) -> None:
316326
def find_kth_smallest(self, k: int, node: Node) -> int:
317327
"""Return the kth smallest element in a binary search tree"""
318328
arr: list[int] = []
319-
self.inorder(arr, node)
320-
if 0 < k <= len(arr):
321-
return arr[k - 1]
322-
raise IndexError("k is out of bounds")
329+
self.inorder(arr, node) # append all values to list using inorder traversal
330+
return arr[k - 1]
323331

324332

325333
def inorder(curr_node: Node | None) -> list[Node]:
@@ -338,4 +346,11 @@ def postorder(curr_node: Node | None) -> list[Node]:
338346
"""
339347
node_list = []
340348
if curr_node is not None:
341-
node_list = postorder(curr_node.left)
349+
node_list = postorder(curr_node.left) + postorder(curr_node.right) + [curr_node]
350+
return node_list
351+
352+
353+
if __name__ == "__main__":
354+
import doctest
355+
356+
doctest.testmod(verbose=True)

0 commit comments

Comments
 (0)