Skip to content

Commit 5f2d6a6

Browse files
committed
[mypy] Makes LRU_Cache generic over key and value types for other/lru_cache
+ no reason to force int -> int
1 parent 0a0d577 commit 5f2d6a6

File tree

1 file changed

+29
-26
lines changed

1 file changed

+29
-26
lines changed

Diff for: other/lru_cache.py

+29-26
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from __future__ import annotations
22

3-
from typing import Callable
3+
from typing import Callable, Generic, TypeVar
44

5+
T = TypeVar("T")
6+
U = TypeVar("U")
57

6-
class DoubleLinkedListNode:
8+
9+
class DoubleLinkedListNode(Generic[T, U]):
710
"""
811
Double Linked List Node built specifically for LRU Cache
912
@@ -12,19 +15,19 @@ class DoubleLinkedListNode:
1215
Node: key: 1, val: 1, has next: False, has prev: False
1316
"""
1417

15-
def __init__(self, key: int | None, val: int | None):
18+
def __init__(self, key: T | None, val: U | None):
1619
self.key = key
1720
self.val = val
18-
self.next: DoubleLinkedListNode | None = None
19-
self.prev: DoubleLinkedListNode | None = None
21+
self.next: DoubleLinkedListNode[T, U] | None = None
22+
self.prev: DoubleLinkedListNode[T, U] | None = None
2023

2124
def __repr__(self) -> str:
2225
return "Node: key: {}, val: {}, has next: {}, has prev: {}".format(
2326
self.key, self.val, self.next is not None, self.prev is not None
2427
)
2528

2629

27-
class DoubleLinkedList:
30+
class DoubleLinkedList(Generic[T, U]):
2831
"""
2932
Double Linked List built specifically for LRU Cache
3033
@@ -92,8 +95,8 @@ class DoubleLinkedList:
9295
"""
9396

9497
def __init__(self) -> None:
95-
self.head = DoubleLinkedListNode(None, None)
96-
self.rear = DoubleLinkedListNode(None, None)
98+
self.head: DoubleLinkedListNode[T, U] = DoubleLinkedListNode(None, None)
99+
self.rear: DoubleLinkedListNode[T, U] = DoubleLinkedListNode(None, None)
97100
self.head.next, self.rear.prev = self.rear, self.head
98101

99102
def __repr__(self) -> str:
@@ -105,7 +108,7 @@ def __repr__(self) -> str:
105108
rep.append(str(self.rear))
106109
return ",\n ".join(rep)
107110

108-
def add(self, node: DoubleLinkedListNode) -> None:
111+
def add(self, node: DoubleLinkedListNode[T, U]) -> None:
109112
"""
110113
Adds the given node to the end of the list (before rear)
111114
"""
@@ -120,7 +123,7 @@ def add(self, node: DoubleLinkedListNode) -> None:
120123
self.rear.prev = node
121124
node.next = self.rear
122125

123-
def remove(self, node: DoubleLinkedListNode) -> DoubleLinkedListNode | None:
126+
def remove(self, node: DoubleLinkedListNode[T, U]) -> DoubleLinkedListNode[T, U] | None:
124127
"""
125128
Removes and returns the given node from the list
126129
@@ -140,8 +143,7 @@ def remove(self, node: DoubleLinkedListNode) -> DoubleLinkedListNode | None:
140143
return node
141144

142145

143-
# class LRUCache(Generic[T]):
144-
class LRUCache:
146+
class LRUCache(Generic[T, U]):
145147
"""
146148
LRU Cache to store a given capacity of data. Can be used as a stand-alone object
147149
or as a function decorator.
@@ -206,17 +208,15 @@ class LRUCache:
206208
"""
207209

208210
# class variable to map the decorator functions to their respective instance
209-
decorator_function_to_instance_map: dict[Callable, LRUCache] = {}
211+
decorator_function_to_instance_map: dict[Callable[[T], U], LRUCache[T, U]] = {}
210212

211213
def __init__(self, capacity: int):
212-
self.list = DoubleLinkedList()
214+
self.list: DoubleLinkedList[T, U] = DoubleLinkedList()
213215
self.capacity = capacity
214216
self.num_keys = 0
215217
self.hits = 0
216218
self.miss = 0
217-
# self.cache: dict[int, int] = {}
218-
# self.cache: dict[int, T] = {}
219-
self.cache: dict[int, DoubleLinkedListNode] = {}
219+
self.cache: dict[T, DoubleLinkedListNode[T, U]] = {}
220220

221221
def __repr__(self) -> str:
222222
"""
@@ -229,7 +229,7 @@ def __repr__(self) -> str:
229229
f"capacity={self.capacity}, current size={self.num_keys})"
230230
)
231231

232-
def __contains__(self, key: int) -> bool:
232+
def __contains__(self, key: T) -> bool:
233233
"""
234234
>>> cache = LRUCache(1)
235235
@@ -244,15 +244,15 @@ def __contains__(self, key: int) -> bool:
244244

245245
return key in self.cache
246246

247-
def get(self, key: int) -> int | None:
247+
def get(self, key: T) -> U | None:
248248
"""
249249
Returns the value for the input key and updates the Double Linked List.
250250
Returns None if key is not present in cache
251251
"""
252252

253253
if key in self.cache:
254254
self.hits += 1
255-
value_node = self.cache[key]
255+
value_node: DoubleLinkedListNode[T, U] = self.cache[key]
256256
node = self.list.remove(self.cache[key])
257257
assert node == value_node
258258

@@ -263,7 +263,7 @@ def get(self, key: int) -> int | None:
263263
self.miss += 1
264264
return None
265265

266-
def set(self, key: int, value: int) -> None:
266+
def set(self, key: T, value: U) -> None:
267267
"""
268268
Sets the value for the input key and updates the Double Linked List
269269
"""
@@ -277,7 +277,9 @@ def set(self, key: int, value: int) -> None:
277277
# explain to type checker via assertions
278278
assert first_node is not None
279279
assert first_node.key is not None
280-
assert self.list.remove(first_node) is not None # node guaranteed to be in list assert node.key is not None
280+
assert (
281+
self.list.remove(first_node) is not None
282+
) # node guaranteed to be in list assert node.key is not None
281283

282284
del self.cache[first_node.key]
283285
self.num_keys -= 1
@@ -293,14 +295,15 @@ def set(self, key: int, value: int) -> None:
293295
self.list.add(node)
294296

295297
@staticmethod
296-
def decorator(size: int = 128) -> Callable[[Callable[[int], int]], Callable[..., int]]:
298+
def decorator(size: int = 128) -> Callable[[Callable[[T], U]], Callable[..., U]]:
297299
"""
298300
Decorator version of LRU Cache
299301
300-
Decorated function must be function of int -> int
302+
Decorated function must be function of T -> U
301303
"""
302-
def cache_decorator_inner(func: Callable[[int], int]) -> Callable[..., int]:
303-
def cache_decorator_wrapper(*args: int) -> int:
304+
305+
def cache_decorator_inner(func: Callable[[T], U]) -> Callable[..., U]:
306+
def cache_decorator_wrapper(*args: T) -> U:
304307
if func not in LRUCache.decorator_function_to_instance_map:
305308
LRUCache.decorator_function_to_instance_map[func] = LRUCache(size)
306309

0 commit comments

Comments
 (0)