8
8
python binary_search_tree_recursive.py
9
9
"""
10
10
import unittest
11
+ from typing import Iterator , Optional
11
12
12
13
13
14
class Node :
14
- def __init__ (self , label : int , parent ) :
15
+ def __init__ (self , label : int , parent : Optional [ "Node" ]) -> None :
15
16
self .label = label
16
17
self .parent = parent
17
- self .left = None
18
- self .right = None
18
+ self .left : Optional [ Node ] = None
19
+ self .right : Optional [ Node ] = None
19
20
20
21
21
22
class BinarySearchTree :
22
- def __init__ (self ):
23
- self .root = None
23
+ def __init__ (self ) -> None :
24
+ self .root : Optional [ Node ] = None
24
25
25
- def empty (self ):
26
+ def empty (self ) -> None :
26
27
"""
27
28
Empties the tree
28
29
@@ -46,7 +47,7 @@ def is_empty(self) -> bool:
46
47
"""
47
48
return self .root is None
48
49
49
- def put (self , label : int ):
50
+ def put (self , label : int ) -> None :
50
51
"""
51
52
Put a new node in the tree
52
53
@@ -65,7 +66,9 @@ def put(self, label: int):
65
66
"""
66
67
self .root = self ._put (self .root , label )
67
68
68
- def _put (self , node : Node , label : int , parent : Node = None ) -> Node :
69
+ def _put (
70
+ self , node : Optional [Node ], label : int , parent : Optional [Node ] = None
71
+ ) -> Node :
69
72
if node is None :
70
73
node = Node (label , parent )
71
74
else :
@@ -95,7 +98,7 @@ def search(self, label: int) -> Node:
95
98
"""
96
99
return self ._search (self .root , label )
97
100
98
- def _search (self , node : Node , label : int ) -> Node :
101
+ def _search (self , node : Optional [ Node ] , label : int ) -> Node :
99
102
if node is None :
100
103
raise Exception (f"Node with label { label } does not exist" )
101
104
else :
@@ -106,7 +109,7 @@ def _search(self, node: Node, label: int) -> Node:
106
109
107
110
return node
108
111
109
- def remove (self , label : int ):
112
+ def remove (self , label : int ) -> None :
110
113
"""
111
114
Removes a node in the tree
112
115
@@ -122,22 +125,22 @@ def remove(self, label: int):
122
125
Exception: Node with label 3 does not exist
123
126
"""
124
127
node = self .search (label )
125
- if not node .right and not node .left :
126
- self ._reassign_nodes (node , None )
127
- elif not node .right and node .left :
128
- self ._reassign_nodes (node , node .left )
129
- elif node .right and not node .left :
130
- self ._reassign_nodes (node , node .right )
131
- else :
128
+ if node .right and node .left :
132
129
lowest_node = self ._get_lowest_node (node .right )
133
130
lowest_node .left = node .left
134
131
lowest_node .right = node .right
135
132
node .left .parent = lowest_node
136
133
if node .right :
137
134
node .right .parent = lowest_node
138
135
self ._reassign_nodes (node , lowest_node )
136
+ elif not node .right and node .left :
137
+ self ._reassign_nodes (node , node .left )
138
+ elif node .right and not node .left :
139
+ self ._reassign_nodes (node , node .right )
140
+ else :
141
+ self ._reassign_nodes (node , None )
139
142
140
- def _reassign_nodes (self , node : Node , new_children : Node ) :
143
+ def _reassign_nodes (self , node : Node , new_children : Optional [ Node ]) -> None :
141
144
if new_children :
142
145
new_children .parent = node .parent
143
146
@@ -192,7 +195,7 @@ def get_max_label(self) -> int:
192
195
>>> t.get_max_label()
193
196
10
194
197
"""
195
- if self .is_empty () :
198
+ if self .root is None :
196
199
raise Exception ("Binary search tree is empty" )
197
200
198
201
node = self .root
@@ -216,7 +219,7 @@ def get_min_label(self) -> int:
216
219
>>> t.get_min_label()
217
220
8
218
221
"""
219
- if self .is_empty () :
222
+ if self .root is None :
220
223
raise Exception ("Binary search tree is empty" )
221
224
222
225
node = self .root
@@ -225,7 +228,7 @@ def get_min_label(self) -> int:
225
228
226
229
return node .label
227
230
228
- def inorder_traversal (self ) -> list :
231
+ def inorder_traversal (self ) -> Iterator [ Node ] :
229
232
"""
230
233
Return the inorder traversal of the tree
231
234
@@ -241,13 +244,13 @@ def inorder_traversal(self) -> list:
241
244
"""
242
245
return self ._inorder_traversal (self .root )
243
246
244
- def _inorder_traversal (self , node : Node ) -> list :
247
+ def _inorder_traversal (self , node : Optional [ Node ] ) -> Iterator [ Node ] :
245
248
if node is not None :
246
249
yield from self ._inorder_traversal (node .left )
247
250
yield node
248
251
yield from self ._inorder_traversal (node .right )
249
252
250
- def preorder_traversal (self ) -> list :
253
+ def preorder_traversal (self ) -> Iterator [ Node ] :
251
254
"""
252
255
Return the preorder traversal of the tree
253
256
@@ -263,7 +266,7 @@ def preorder_traversal(self) -> list:
263
266
"""
264
267
return self ._preorder_traversal (self .root )
265
268
266
- def _preorder_traversal (self , node : Node ) -> list :
269
+ def _preorder_traversal (self , node : Optional [ Node ] ) -> Iterator [ Node ] :
267
270
if node is not None :
268
271
yield node
269
272
yield from self ._preorder_traversal (node .left )
@@ -272,7 +275,7 @@ def _preorder_traversal(self, node: Node) -> list:
272
275
273
276
class BinarySearchTreeTest (unittest .TestCase ):
274
277
@staticmethod
275
- def _get_binary_search_tree ():
278
+ def _get_binary_search_tree () -> BinarySearchTree :
276
279
r"""
277
280
8
278
281
/ \
@@ -298,14 +301,15 @@ def _get_binary_search_tree():
298
301
299
302
return t
300
303
301
- def test_put (self ):
304
+ def test_put (self ) -> None :
302
305
t = BinarySearchTree ()
303
306
assert t .is_empty ()
304
307
305
308
t .put (8 )
306
309
r"""
307
310
8
308
311
"""
312
+ assert t .root is not None
309
313
assert t .root .parent is None
310
314
assert t .root .label == 8
311
315
@@ -315,6 +319,7 @@ def test_put(self):
315
319
\
316
320
10
317
321
"""
322
+ assert t .root .right is not None
318
323
assert t .root .right .parent == t .root
319
324
assert t .root .right .label == 10
320
325
@@ -324,6 +329,7 @@ def test_put(self):
324
329
/ \
325
330
3 10
326
331
"""
332
+ assert t .root .left is not None
327
333
assert t .root .left .parent == t .root
328
334
assert t .root .left .label == 3
329
335
@@ -335,6 +341,7 @@ def test_put(self):
335
341
\
336
342
6
337
343
"""
344
+ assert t .root .left .right is not None
338
345
assert t .root .left .right .parent == t .root .left
339
346
assert t .root .left .right .label == 6
340
347
@@ -346,13 +353,14 @@ def test_put(self):
346
353
/ \
347
354
1 6
348
355
"""
356
+ assert t .root .left .left is not None
349
357
assert t .root .left .left .parent == t .root .left
350
358
assert t .root .left .left .label == 1
351
359
352
360
with self .assertRaises (Exception ):
353
361
t .put (1 )
354
362
355
- def test_search (self ):
363
+ def test_search (self ) -> None :
356
364
t = self ._get_binary_search_tree ()
357
365
358
366
node = t .search (6 )
@@ -364,7 +372,7 @@ def test_search(self):
364
372
with self .assertRaises (Exception ):
365
373
t .search (2 )
366
374
367
- def test_remove (self ):
375
+ def test_remove (self ) -> None :
368
376
t = self ._get_binary_search_tree ()
369
377
370
378
t .remove (13 )
@@ -379,6 +387,9 @@ def test_remove(self):
379
387
\
380
388
5
381
389
"""
390
+ assert t .root is not None
391
+ assert t .root .right is not None
392
+ assert t .root .right .right is not None
382
393
assert t .root .right .right .right is None
383
394
assert t .root .right .right .left is None
384
395
@@ -394,6 +405,9 @@ def test_remove(self):
394
405
\
395
406
5
396
407
"""
408
+ assert t .root .left is not None
409
+ assert t .root .left .right is not None
410
+ assert t .root .left .right .left is not None
397
411
assert t .root .left .right .right is None
398
412
assert t .root .left .right .left .label == 4
399
413
@@ -407,6 +421,8 @@ def test_remove(self):
407
421
\
408
422
5
409
423
"""
424
+ assert t .root .left .left is not None
425
+ assert t .root .left .right .right is not None
410
426
assert t .root .left .left .label == 1
411
427
assert t .root .left .right .label == 4
412
428
assert t .root .left .right .right .label == 5
@@ -422,6 +438,7 @@ def test_remove(self):
422
438
/ \ \
423
439
1 5 14
424
440
"""
441
+ assert t .root is not None
425
442
assert t .root .left .label == 4
426
443
assert t .root .left .right .label == 5
427
444
assert t .root .left .left .label == 1
@@ -437,13 +454,15 @@ def test_remove(self):
437
454
/ \
438
455
1 14
439
456
"""
457
+ assert t .root .left is not None
458
+ assert t .root .left .left is not None
440
459
assert t .root .left .label == 5
441
460
assert t .root .left .right is None
442
461
assert t .root .left .left .label == 1
443
462
assert t .root .left .parent == t .root
444
463
assert t .root .left .left .parent == t .root .left
445
464
446
- def test_remove_2 (self ):
465
+ def test_remove_2 (self ) -> None :
447
466
t = self ._get_binary_search_tree ()
448
467
449
468
t .remove (3 )
@@ -456,6 +475,12 @@ def test_remove_2(self):
456
475
/ \ /
457
476
5 7 13
458
477
"""
478
+ assert t .root is not None
479
+ assert t .root .left is not None
480
+ assert t .root .left .left is not None
481
+ assert t .root .left .right is not None
482
+ assert t .root .left .right .left is not None
483
+ assert t .root .left .right .right is not None
459
484
assert t .root .left .label == 4
460
485
assert t .root .left .right .label == 6
461
486
assert t .root .left .left .label == 1
@@ -466,25 +491,25 @@ def test_remove_2(self):
466
491
assert t .root .left .left .parent == t .root .left
467
492
assert t .root .left .right .left .parent == t .root .left .right
468
493
469
- def test_empty (self ):
494
+ def test_empty (self ) -> None :
470
495
t = self ._get_binary_search_tree ()
471
496
t .empty ()
472
497
assert t .root is None
473
498
474
- def test_is_empty (self ):
499
+ def test_is_empty (self ) -> None :
475
500
t = self ._get_binary_search_tree ()
476
501
assert not t .is_empty ()
477
502
478
503
t .empty ()
479
504
assert t .is_empty ()
480
505
481
- def test_exists (self ):
506
+ def test_exists (self ) -> None :
482
507
t = self ._get_binary_search_tree ()
483
508
484
509
assert t .exists (6 )
485
510
assert not t .exists (- 1 )
486
511
487
- def test_get_max_label (self ):
512
+ def test_get_max_label (self ) -> None :
488
513
t = self ._get_binary_search_tree ()
489
514
490
515
assert t .get_max_label () == 14
@@ -493,7 +518,7 @@ def test_get_max_label(self):
493
518
with self .assertRaises (Exception ):
494
519
t .get_max_label ()
495
520
496
- def test_get_min_label (self ):
521
+ def test_get_min_label (self ) -> None :
497
522
t = self ._get_binary_search_tree ()
498
523
499
524
assert t .get_min_label () == 1
@@ -502,20 +527,20 @@ def test_get_min_label(self):
502
527
with self .assertRaises (Exception ):
503
528
t .get_min_label ()
504
529
505
- def test_inorder_traversal (self ):
530
+ def test_inorder_traversal (self ) -> None :
506
531
t = self ._get_binary_search_tree ()
507
532
508
533
inorder_traversal_nodes = [i .label for i in t .inorder_traversal ()]
509
534
assert inorder_traversal_nodes == [1 , 3 , 4 , 5 , 6 , 7 , 8 , 10 , 13 , 14 ]
510
535
511
- def test_preorder_traversal (self ):
536
+ def test_preorder_traversal (self ) -> None :
512
537
t = self ._get_binary_search_tree ()
513
538
514
539
preorder_traversal_nodes = [i .label for i in t .preorder_traversal ()]
515
540
assert preorder_traversal_nodes == [8 , 3 , 1 , 6 , 4 , 5 , 7 , 10 , 14 , 13 ]
516
541
517
542
518
- def binary_search_tree_example ():
543
+ def binary_search_tree_example () -> None :
519
544
r"""
520
545
Example
521
546
8
0 commit comments