Skip to content

Commit d1132c4

Browse files
MarkHershey, github-actions, and dhruvmanila
authored and committed
[mypy] Add type hints and docstrings to heap.py (TheAlgorithms#3013)
* Add type hints and docstrings to heap.py
  - Add type hints
  - Add docstrings
  - Add explanatory comments
  - Improve code readability
  - Change to use f-string
* Fix import sorting
* fixup! Format Python code with psf/black push
* Fix static type error
* Fix failing test
* Fix type hints
* Add return annotation

Co-authored-by: github-actions <${GITHUB_ACTOR}@users.noreply.github.com>
Co-authored-by: Dhruv Manilawala <[email protected]>
1 parent b45fd09 commit d1132c4

File tree

1 file changed

+107
-79
lines changed

1 file changed

+107
-79
lines changed

Diff for: data_structures/heap/heap.py

+107 -79
Original file line number | Diff line number | Diff line change
@@ -1,101 +1,138 @@
1-
#!/usr/bin/python3
1+
from typing import Iterable, List, Optional
22

33

44
class Heap:
5-
"""
5+
"""A Max Heap Implementation
6+
67
>>> unsorted = [103, 9, 1, 7, 11, 15, 25, 201, 209, 107, 5]
78
>>> h = Heap()
8-
>>> h.build_heap(unsorted)
9-
>>> h.display()
9+
>>> h.build_max_heap(unsorted)
10+
>>> print(h)
1011
[209, 201, 25, 103, 107, 15, 1, 9, 7, 11, 5]
1112
>>>
12-
>>> h.get_max()
13+
>>> h.extract_max()
1314
209
14-
>>> h.display()
15+
>>> print(h)
1516
[201, 107, 25, 103, 11, 15, 1, 9, 7, 5]
1617
>>>
1718
>>> h.insert(100)
18-
>>> h.display()
19+
>>> print(h)
1920
[201, 107, 25, 103, 100, 15, 1, 9, 7, 5, 11]
2021
>>>
2122
>>> h.heap_sort()
22-
>>> h.display()
23+
>>> print(h)
2324
[1, 5, 7, 9, 11, 15, 25, 100, 103, 107, 201]
24-
>>>
2525
"""
2626

27-
def __init__(self):
28-
self.h = []
29-
self.curr_size = 0
27+
def __init__(self) -> None:
28+
self.h: List[float] = []
29+
self.heap_size: int = 0
30+
31+
def __repr__(self) -> str:
32+
return str(self.h)
3033

31-
def get_left_child_index(self, i):
32-
left_child_index = 2 * i + 1
33-
if left_child_index < self.curr_size:
34+
def parent_index(self, child_idx: int) -> Optional[int]:
35+
""" return the parent index of given child """
36+
if child_idx > 0:
37+
return (child_idx - 1) // 2
38+
return None
39+
40+
def left_child_idx(self, parent_idx: int) -> Optional[int]:
41+
"""
42+
return the left child index if the left child exists.
43+
if not, return None.
44+
"""
45+
left_child_index = 2 * parent_idx + 1
46+
if left_child_index < self.heap_size:
3447
return left_child_index
3548
return None
3649

37-
def get_right_child(self, i):
38-
right_child_index = 2 * i + 2
39-
if right_child_index < self.curr_size:
50+
def right_child_idx(self, parent_idx: int) -> Optional[int]:
51+
"""
52+
return the right child index if the right child exists.
53+
if not, return None.
54+
"""
55+
right_child_index = 2 * parent_idx + 2
56+
if right_child_index < self.heap_size:
4057
return right_child_index
4158
return None
4259

43-
def max_heapify(self, index):
44-
if index < self.curr_size:
45-
largest = index
46-
lc = self.get_left_child_index(index)
47-
rc = self.get_right_child(index)
48-
if lc is not None and self.h[lc] > self.h[largest]:
49-
largest = lc
50-
if rc is not None and self.h[rc] > self.h[largest]:
51-
largest = rc
52-
if largest != index:
53-
self.h[largest], self.h[index] = self.h[index], self.h[largest]
54-
self.max_heapify(largest)
55-
56-
def build_heap(self, collection):
57-
self.curr_size = len(collection)
60+
def max_heapify(self, index: int) -> None:
61+
"""
62+
correct a single violation of the heap property in a subtree's root.
63+
"""
64+
if index < self.heap_size:
65+
violation: int = index
66+
left_child = self.left_child_idx(index)
67+
right_child = self.right_child_idx(index)
68+
# check which child is larger than its parent
69+
if left_child is not None and self.h[left_child] > self.h[violation]:
70+
violation = left_child
71+
if right_child is not None and self.h[right_child] > self.h[violation]:
72+
violation = right_child
73+
# if violation indeed exists
74+
if violation != index:
75+
# swap to fix the violation
76+
self.h[violation], self.h[index] = self.h[index], self.h[violation]
77+
# fix the subsequent violation recursively if any
78+
self.max_heapify(violation)
79+
80+
def build_max_heap(self, collection: Iterable[float]) -> None:
81+
""" build max heap from an unsorted array"""
5882
self.h = list(collection)
59-
if self.curr_size <= 1:
60-
return
61-
for i in range(self.curr_size // 2 - 1, -1, -1):
62-
self.max_heapify(i)
63-
64-
def get_max(self):
65-
if self.curr_size >= 2:
83+
self.heap_size = len(self.h)
84+
if self.heap_size > 1:
85+
# max_heapify from right to left but exclude leaves (last level)
86+
for i in range(self.heap_size // 2 - 1, -1, -1):
87+
self.max_heapify(i)
88+
89+
def max(self) -> float:
90+
""" return the max in the heap """
91+
if self.heap_size >= 1:
92+
return self.h[0]
93+
else:
94+
raise Exception("Empty heap")
95+
96+
def extract_max(self) -> float:
97+
""" get and remove max from heap """
98+
if self.heap_size >= 2:
6699
me = self.h[0]
67100
self.h[0] = self.h.pop(-1)
68-
self.curr_size -= 1
101+
self.heap_size -= 1
69102
self.max_heapify(0)
70103
return me
71-
elif self.curr_size == 1:
72-
self.curr_size -= 1
104+
elif self.heap_size == 1:
105+
self.heap_size -= 1
73106
return self.h.pop(-1)
74-
return None
75-
76-
def heap_sort(self):
77-
size = self.curr_size
107+
else:
108+
raise Exception("Empty heap")
109+
110+
def insert(self, value: float) -> None:
111+
""" insert a new value into the max heap """
112+
self.h.append(value)
113+
idx = (self.heap_size - 1) // 2
114+
self.heap_size += 1
115+
while idx >= 0:
116+
self.max_heapify(idx)
117+
idx = (idx - 1) // 2
118+
119+
def heap_sort(self) -> None:
120+
size = self.heap_size
78121
for j in range(size - 1, 0, -1):
79122
self.h[0], self.h[j] = self.h[j], self.h[0]
80-
self.curr_size -= 1
123+
self.heap_size -= 1
81124
self.max_heapify(0)
82-
self.curr_size = size
125+
self.heap_size = size
83126

84-
def insert(self, data):
85-
self.h.append(data)
86-
curr = (self.curr_size - 1) // 2
87-
self.curr_size += 1
88-
while curr >= 0:
89-
self.max_heapify(curr)
90-
curr = (curr - 1) // 2
91127

92-
def display(self):
93-
print(self.h)
128+
if __name__ == "__main__":
129+
import doctest
94130

131+
# run doc test
132+
doctest.testmod()
95133

96-
def main():
134+
# demo
97135
for unsorted in [
98-
[],
99136
[0],
100137
[2],
101138
[3, 5],
@@ -110,26 +147,17 @@ def main():
110147
[103, 9, 1, 7, 11, 15, 25, 201, 209, 107, 5],
111148
[-45, -2, -5],
112149
]:
113-
print("source unsorted list: %s" % unsorted)
150+
print(f"unsorted array: {unsorted}")
114151

115-
h = Heap()
116-
h.build_heap(unsorted)
117-
print("after build heap: ", end=" ")
118-
h.display()
152+
heap = Heap()
153+
heap.build_max_heap(unsorted)
154+
print(f"after build heap: {heap}")
119155

120-
print("max value: %s" % h.get_max())
121-
print("delete max value: ", end=" ")
122-
h.display()
156+
print(f"max value: {heap.extract_max()}")
157+
print(f"after max value removed: {heap}")
123158

124-
h.insert(100)
125-
print("after insert new value 100: ", end=" ")
126-
h.display()
159+
heap.insert(100)
160+
print(f"after new value 100 inserted: {heap}")
127161

128-
h.heap_sort()
129-
print("heap sort: ", end=" ")
130-
h.display()
131-
print()
132-
133-
134-
if __name__ == "__main__":
135-
main()
162+
heap.heap_sort()
163+
print(f"heap-sorted array: {heap}\n")

0 commit comments

Comments
 (0)