1
1
from __future__ import annotations
2
2
3
+ from abc import abstractmethod
3
4
from collections .abc import Iterable
5
+ from typing import Generic , Protocol , TypeVar
4
6
5
7
6
- class Heap :
8
+ class Comparable (Protocol ):
9
+ @abstractmethod
10
+ def __lt__ (self : T , other : T ) -> bool :
11
+ pass
12
+
13
+ @abstractmethod
14
+ def __gt__ (self : T , other : T ) -> bool :
15
+ pass
16
+
17
+ @abstractmethod
18
+ def __eq__ (self : T , other : object ) -> bool :
19
+ pass
20
+
21
+
22
+ T = TypeVar ("T" , bound = Comparable )
23
+
24
+
25
+ class Heap (Generic [T ]):
7
26
"""A Max Heap Implementation
8
27
9
28
>>> unsorted = [103, 9, 1, 7, 11, 15, 25, 201, 209, 107, 5]
@@ -27,7 +46,7 @@ class Heap:
27
46
"""
28
47
29
48
def __init__ (self ) -> None :
30
- self .h : list [float ] = []
49
+ self .h : list [T ] = []
31
50
self .heap_size : int = 0
32
51
33
52
def __repr__ (self ) -> str :
@@ -79,7 +98,7 @@ def max_heapify(self, index: int) -> None:
79
98
# fix the subsequent violation recursively if any
80
99
self .max_heapify (violation )
81
100
82
- def build_max_heap (self , collection : Iterable [float ]) -> None :
101
+ def build_max_heap (self , collection : Iterable [T ]) -> None :
83
102
"""build max heap from an unsorted array"""
84
103
self .h = list (collection )
85
104
self .heap_size = len (self .h )
@@ -88,7 +107,7 @@ def build_max_heap(self, collection: Iterable[float]) -> None:
88
107
for i in range (self .heap_size // 2 - 1 , - 1 , - 1 ):
89
108
self .max_heapify (i )
90
109
91
- def extract_max (self ) -> float :
110
+ def extract_max (self ) -> T :
92
111
"""get and remove max from heap"""
93
112
if self .heap_size >= 2 :
94
113
me = self .h [0 ]
@@ -102,7 +121,7 @@ def extract_max(self) -> float:
102
121
else :
103
122
raise Exception ("Empty heap" )
104
123
105
- def insert (self , value : float ) -> None :
124
+ def insert (self , value : T ) -> None :
106
125
"""insert a new value into the max heap"""
107
126
self .h .append (value )
108
127
idx = (self .heap_size - 1 ) // 2
@@ -144,7 +163,7 @@ def heap_sort(self) -> None:
144
163
]:
145
164
print (f"unsorted array: { unsorted } " )
146
165
147
- heap = Heap ()
166
+ heap : Heap [ int ] = Heap ()
148
167
heap .build_max_heap (unsorted )
149
168
print (f"after build heap: { heap } " )
150
169
0 commit comments