Skip to content

Add type hints and docstrings to heap.py #3013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Dec 26, 2020
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 104 additions & 74 deletions data_structures/heap/heap.py
Original file line number Diff line number Diff line change
@@ -1,99 +1,131 @@
#!/usr/bin/python3
from typing import Iterable, List, Union


class Heap:
"""
"""A Max Heap Implementation

>>> unsorted = [103, 9, 1, 7, 11, 15, 25, 201, 209, 107, 5]
>>> h = Heap()
>>> h.build_heap(unsorted)
>>> h.display()
>>> h.build_max_heap(unsorted)
>>> print(h)
[209, 201, 25, 103, 107, 15, 1, 9, 7, 11, 5]
>>>
>>> h.get_max()
>>> h.extract_max()
209
>>> h.display()
>>> print(h)
[201, 107, 25, 103, 11, 15, 1, 9, 7, 5]
>>>
>>> h.insert(100)
>>> h.display()
>>> print(h)
[201, 107, 25, 103, 100, 15, 1, 9, 7, 5, 11]
>>>
>>> h.heap_sort()
>>> h.display()
>>> print(h)
[1, 5, 7, 9, 11, 15, 25, 100, 103, 107, 201]
>>>
"""

def __init__(self):
self.h = []
self.curr_size = 0
self.h: List[Union[int, float]] = []
self.heap_size: int = 0

def __repr__(self) -> str:
return str(self.h)

def parent_index(self, child_idx: int) -> Union[int, None]:
""" return the parent index of given child """
if child_idx > 0:
return (child_idx - 1) // 2
return None

def get_left_child_index(self, i):
left_child_index = 2 * i + 1
if left_child_index < self.curr_size:
def left_child_idx(self, parent_idx: int) -> Union[int, None]:
"""
return the left child index if the left child exists.
if not, return None.
"""
left_child_index = 2 * parent_idx + 1
if left_child_index < self.heap_size:
return left_child_index
return None

def get_right_child(self, i):
right_child_index = 2 * i + 2
if right_child_index < self.curr_size:
def right_child_idx(self, parent_idx: int) -> Union[int, None]:
"""
return the right child index if the right child exists.
if not, return None.
"""
right_child_index = 2 * parent_idx + 2
if right_child_index < self.heap_size:
return right_child_index
return None

def max_heapify(self, index):
if index < self.curr_size:
largest = index
lc = self.get_left_child_index(index)
rc = self.get_right_child(index)
if lc is not None and self.h[lc] > self.h[largest]:
largest = lc
if rc is not None and self.h[rc] > self.h[largest]:
largest = rc
if largest != index:
self.h[largest], self.h[index] = self.h[index], self.h[largest]
self.max_heapify(largest)

def build_heap(self, collection):
self.curr_size = len(collection)
def max_heapify(self, index: int):
"""
correct a single violation of the heap property in a subtree's root.
"""
if index < self.heap_size:
violation: int = index
left_child = self.left_child_idx(index)
right_child = self.right_child_idx(index)
# check which child is larger than its parent
if left_child is not None and self.h[left_child] > self.h[violation]:
violation = left_child
if right_child is not None and self.h[right_child] > self.h[violation]:
violation = right_child
# if violation indeed exists
if violation != index:
# swap to fix the violation
self.h[violation], self.h[index] = self.h[index], self.h[violation]
# fix the subsequent violation recursively if any
self.max_heapify(violation)

def build_max_heap(self, collection: Iterable[Union[int, float]]):
""" build max heap from an unsorted array"""
self.h = list(collection)
if self.curr_size <= 1:
return
for i in range(self.curr_size // 2 - 1, -1, -1):
self.max_heapify(i)

def get_max(self):
if self.curr_size >= 2:
self.heap_size = len(self.h)
if self.heap_size > 1:
# max_heapify from right to left but exclude leaves (last level)
for i in range(self.heap_size // 2 - 1, -1, -1):
self.max_heapify(i)

def max(self) -> Union[int, float]:
""" return the max in the heap """
if self.heap_size >= 1:
return self.h[0]
else:
raise Exception("Empty heap")

def extract_max(self) -> Union[int, float]:
""" get and remove max from heap """
if self.heap_size >= 2:
me = self.h[0]
self.h[0] = self.h.pop(-1)
self.curr_size -= 1
self.heap_size -= 1
self.max_heapify(0)
return me
elif self.curr_size == 1:
self.curr_size -= 1
elif self.heap_size == 1:
self.heap_size -= 1
return self.h.pop(-1)
return None
else:
raise Exception("Empty heap")

def insert(self, value: Union[int, float]):
""" insert a new value into the max heap """
self.h.append(value)
idx = (self.heap_size - 1) // 2
self.heap_size += 1
while idx >= 0:
self.max_heapify(idx)
idx = (idx - 1) // 2

def heap_sort(self):
size = self.curr_size
size = self.heap_size
for j in range(size - 1, 0, -1):
self.h[0], self.h[j] = self.h[j], self.h[0]
self.curr_size -= 1
self.heap_size -= 1
self.max_heapify(0)
self.curr_size = size

def insert(self, data):
self.h.append(data)
curr = (self.curr_size - 1) // 2
self.curr_size += 1
while curr >= 0:
self.max_heapify(curr)
curr = (curr - 1) // 2
self.heap_size = size

def display(self):
print(self.h)


def main():
def demo():
for unsorted in [
[],
[0],
Expand All @@ -110,26 +142,24 @@ def main():
[103, 9, 1, 7, 11, 15, 25, 201, 209, 107, 5],
[-45, -2, -5],
]:
print("source unsorted list: %s" % unsorted)
print(f"unsorted array: {unsorted}")

h = Heap()
h.build_heap(unsorted)
print("after build heap: ", end=" ")
h.display()
heap = Heap()
heap.build_max_heap(unsorted)
print(f"after build heap: {heap}")

print("max value: %s" % h.get_max())
print("delete max value: ", end=" ")
h.display()
print(f"max value: {heap.extract_max()}")
print(f"after max value removed: {heap}")

h.insert(100)
print("after insert new value 100: ", end=" ")
h.display()
heap.insert(100)
print(f"after new value 100 inserted: {heap}")

h.heap_sort()
print("heap sort: ", end=" ")
h.display()
print()
heap.heap_sort()
print(f"heap-sorted array: {heap}\n")


if __name__ == "__main__":
main()
# demo()
import doctest

doctest.testmod()