Skip to content

Commit 059a7c0

Browse files
committed
Solving the Top k most frequent words problem using a max-heap
1 parent c909da9 commit 059a7c0

File tree

2 files changed

+86
-6
lines changed

2 files changed

+86
-6
lines changed

Diff for: data_structures/heap/heap.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,28 @@
11
from __future__ import annotations
22

3+
from abc import abstractmethod
34
from collections.abc import Iterable
5+
from typing import Generic, Protocol, TypeVar
46

57

6-
class Heap:
8+
class Comparable(Protocol):
    """Structural type for values that support ``<``, ``>`` and ``==``.

    Used as the bound of ``T`` so that a generic container can compare
    its elements with these three operators.
    """

    @abstractmethod
    def __lt__(self: T, other: T) -> bool:
        """Return whether ``self`` orders strictly before ``other``."""

    @abstractmethod
    def __gt__(self: T, other: T) -> bool:
        """Return whether ``self`` orders strictly after ``other``."""

    @abstractmethod
    def __eq__(self: T, other: object) -> bool:
        """Return whether ``self`` is equal to ``other``."""


# Element type for generic containers: any type providing the Comparable operators.
T = TypeVar("T", bound=Comparable)
23+
24+
25+
class Heap(Generic[T]):
726
"""A Max Heap Implementation
827
928
>>> unsorted = [103, 9, 1, 7, 11, 15, 25, 201, 209, 107, 5]
@@ -27,7 +46,7 @@ class Heap:
2746
"""
2847

2948
def __init__(self) -> None:
30-
self.h: list[float] = []
49+
self.h: list[T] = []
3150
self.heap_size: int = 0
3251

3352
def __repr__(self) -> str:
@@ -79,7 +98,7 @@ def max_heapify(self, index: int) -> None:
7998
# fix the subsequent violation recursively if any
8099
self.max_heapify(violation)
81100

82-
def build_max_heap(self, collection: Iterable[float]) -> None:
101+
def build_max_heap(self, collection: Iterable[T]) -> None:
83102
"""build max heap from an unsorted array"""
84103
self.h = list(collection)
85104
self.heap_size = len(self.h)
@@ -88,7 +107,7 @@ def build_max_heap(self, collection: Iterable[float]) -> None:
88107
for i in range(self.heap_size // 2 - 1, -1, -1):
89108
self.max_heapify(i)
90109

91-
def extract_max(self) -> float:
110+
def extract_max(self) -> T:
92111
"""get and remove max from heap"""
93112
if self.heap_size >= 2:
94113
me = self.h[0]
@@ -102,7 +121,7 @@ def extract_max(self) -> float:
102121
else:
103122
raise Exception("Empty heap")
104123

105-
def insert(self, value: float) -> None:
124+
def insert(self, value: T) -> None:
106125
"""insert a new value into the max heap"""
107126
self.h.append(value)
108127
idx = (self.heap_size - 1) // 2
@@ -144,7 +163,7 @@ def heap_sort(self) -> None:
144163
]:
145164
print(f"unsorted array: {unsorted}")
146165

147-
heap = Heap()
166+
heap: Heap[int] = Heap()
148167
heap.build_max_heap(unsorted)
149168
print(f"after build heap: {heap}")
150169

Diff for: strings/top_k_frequent_words.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""
2+
Finds the top K most frequent words from the provided word list.
3+
"""
4+
5+
6+
from collections import Counter
7+
from functools import total_ordering
8+
9+
from data_structures.heap.heap import Heap
10+
11+
12+
@total_ordering
class WordCount:
    """Pair a word with its occurrence count, ordered by count alone.

    Equality and ordering look only at ``count``; ``word`` is ignored, so two
    different words with the same count compare equal. ``total_ordering``
    derives ``<=``, ``>`` and ``>=`` from ``__eq__`` and ``__lt__``.
    """

    def __init__(self, word: str, count: int) -> None:
        self.word = word
        self.count = count

    def __repr__(self) -> str:
        return f"{type(self).__name__}(word={self.word!r}, count={self.count!r})"

    def __eq__(self, other: object) -> bool:
        # Returning NotImplemented lets Python try the reflected operation
        # instead of raising here for foreign types.
        if not isinstance(other, WordCount):
            return NotImplemented
        return self.count == other.count

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, WordCount):
            return NotImplemented
        return self.count < other.count
27+
28+
29+
def top_k_frequent_words(words: list[str], k: int) -> list[str]:
    """
    Returns the k most frequently occurring words in non-increasing order of occurrence.
    In this context, a word is defined as an element in the provided list.

    In case k is greater than the number of distinct words, a value of k equal
    to the number of distinct words will be considered, instead.

    Words with equal counts compare as equal, so their relative order in the
    result is decided by the heap, not by the words themselves.

    >>> top_k_frequent_words(['a', 'b', 'c', 'a', 'c', 'c'], 3)
    ['c', 'a', 'b']
    >>> top_k_frequent_words(['a', 'b', 'c', 'a', 'c', 'c'], 2)
    ['c', 'a']
    >>> top_k_frequent_words(['a', 'b', 'c', 'a', 'c', 'c'], 1)
    ['c']
    >>> top_k_frequent_words(['a', 'b', 'c', 'a', 'c', 'c'], 0)
    []
    >>> top_k_frequent_words([], 1)
    []
    >>> top_k_frequent_words(['a', 'a'], 2)
    ['a']
    """
    heap: Heap[WordCount] = Heap()
    count_by_word = Counter(words)
    heap.build_max_heap(
        [WordCount(word, count) for word, count in count_by_word.items()]
    )
    # Never extract more entries than there are distinct words.
    return [heap.extract_max().word for _ in range(min(k, len(count_by_word)))]
56+
57+
58+
if __name__ == "__main__":
    # Run the doctest examples embedded in this module's docstrings.
    import doctest

    doctest.testmod()

0 commit comments

Comments
 (0)