Skip to content

Commit aa82e72

Browse files
committed
[mypy] Makes LRU_Cache generic over key and value types for other/lru_cache
+ no reason to force int -> int
1 parent 0a0d577 commit aa82e72

File tree

1 file changed

+31
-26
lines changed

1 file changed

+31
-26
lines changed

Diff for: other/lru_cache.py

+31-26
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from __future__ import annotations
22

3-
from typing import Callable
3+
from typing import Callable, Generic, TypeVar
44

5+
T = TypeVar("T")
6+
U = TypeVar("U")
57

6-
class DoubleLinkedListNode:
8+
9+
class DoubleLinkedListNode(Generic[T, U]):
710
"""
811
Double Linked List Node built specifically for LRU Cache
912
@@ -12,19 +15,19 @@ class DoubleLinkedListNode:
1215
Node: key: 1, val: 1, has next: False, has prev: False
1316
"""
1417

15-
def __init__(self, key: int | None, val: int | None):
18+
def __init__(self, key: T | None, val: U | None):
1619
self.key = key
1720
self.val = val
18-
self.next: DoubleLinkedListNode | None = None
19-
self.prev: DoubleLinkedListNode | None = None
21+
self.next: DoubleLinkedListNode[T, U] | None = None
22+
self.prev: DoubleLinkedListNode[T, U] | None = None
2023

2124
def __repr__(self) -> str:
2225
return "Node: key: {}, val: {}, has next: {}, has prev: {}".format(
2326
self.key, self.val, self.next is not None, self.prev is not None
2427
)
2528

2629

27-
class DoubleLinkedList:
30+
class DoubleLinkedList(Generic[T, U]):
2831
"""
2932
Double Linked List built specifically for LRU Cache
3033
@@ -92,8 +95,8 @@ class DoubleLinkedList:
9295
"""
9396

9497
def __init__(self) -> None:
95-
self.head = DoubleLinkedListNode(None, None)
96-
self.rear = DoubleLinkedListNode(None, None)
98+
self.head: DoubleLinkedListNode[T, U] = DoubleLinkedListNode(None, None)
99+
self.rear: DoubleLinkedListNode[T, U] = DoubleLinkedListNode(None, None)
97100
self.head.next, self.rear.prev = self.rear, self.head
98101

99102
def __repr__(self) -> str:
@@ -105,7 +108,7 @@ def __repr__(self) -> str:
105108
rep.append(str(self.rear))
106109
return ",\n ".join(rep)
107110

108-
def add(self, node: DoubleLinkedListNode) -> None:
111+
def add(self, node: DoubleLinkedListNode[T, U]) -> None:
109112
"""
110113
Adds the given node to the end of the list (before rear)
111114
"""
@@ -120,7 +123,9 @@ def add(self, node: DoubleLinkedListNode) -> None:
120123
self.rear.prev = node
121124
node.next = self.rear
122125

123-
def remove(self, node: DoubleLinkedListNode) -> DoubleLinkedListNode | None:
126+
def remove(
127+
self, node: DoubleLinkedListNode[T, U]
128+
) -> DoubleLinkedListNode[T, U] | None:
124129
"""
125130
Removes and returns the given node from the list
126131
@@ -140,8 +145,7 @@ def remove(self, node: DoubleLinkedListNode) -> DoubleLinkedListNode | None:
140145
return node
141146

142147

143-
# class LRUCache(Generic[T]):
144-
class LRUCache:
148+
class LRUCache(Generic[T, U]):
145149
"""
146150
LRU Cache to store a given capacity of data. Can be used as a stand-alone object
147151
or as a function decorator.
@@ -206,17 +210,15 @@ class LRUCache:
206210
"""
207211

208212
# class variable to map the decorator functions to their respective instance
209-
decorator_function_to_instance_map: dict[Callable, LRUCache] = {}
213+
decorator_function_to_instance_map: dict[Callable[[T], U], LRUCache[T, U]] = {}
210214

211215
def __init__(self, capacity: int):
212-
self.list = DoubleLinkedList()
216+
self.list: DoubleLinkedList[T, U] = DoubleLinkedList()
213217
self.capacity = capacity
214218
self.num_keys = 0
215219
self.hits = 0
216220
self.miss = 0
217-
# self.cache: dict[int, int] = {}
218-
# self.cache: dict[int, T] = {}
219-
self.cache: dict[int, DoubleLinkedListNode] = {}
221+
self.cache: dict[T, DoubleLinkedListNode[T, U]] = {}
220222

221223
def __repr__(self) -> str:
222224
"""
@@ -229,7 +231,7 @@ def __repr__(self) -> str:
229231
f"capacity={self.capacity}, current size={self.num_keys})"
230232
)
231233

232-
def __contains__(self, key: int) -> bool:
234+
def __contains__(self, key: T) -> bool:
233235
"""
234236
>>> cache = LRUCache(1)
235237
@@ -244,15 +246,15 @@ def __contains__(self, key: int) -> bool:
244246

245247
return key in self.cache
246248

247-
def get(self, key: int) -> int | None:
249+
def get(self, key: T) -> U | None:
248250
"""
249251
Returns the value for the input key and updates the Double Linked List.
250252
Returns None if key is not present in cache
251253
"""
252254

253255
if key in self.cache:
254256
self.hits += 1
255-
value_node = self.cache[key]
257+
value_node: DoubleLinkedListNode[T, U] = self.cache[key]
256258
node = self.list.remove(self.cache[key])
257259
assert node == value_node
258260

@@ -263,7 +265,7 @@ def get(self, key: int) -> int | None:
263265
self.miss += 1
264266
return None
265267

266-
def set(self, key: int, value: int) -> None:
268+
def set(self, key: T, value: U) -> None:
267269
"""
268270
Sets the value for the input key and updates the Double Linked List
269271
"""
@@ -277,7 +279,9 @@ def set(self, key: int, value: int) -> None:
277279
# explain to type checker via assertions
278280
assert first_node is not None
279281
assert first_node.key is not None
280-
assert self.list.remove(first_node) is not None # node guaranteed to be in list assert node.key is not None
282+
assert (
283+
self.list.remove(first_node) is not None
284+
) # node guaranteed to be in list assert node.key is not None
281285

282286
del self.cache[first_node.key]
283287
self.num_keys -= 1
@@ -293,14 +297,15 @@ def set(self, key: int, value: int) -> None:
293297
self.list.add(node)
294298

295299
@staticmethod
296-
def decorator(size: int = 128) -> Callable[[Callable[[int], int]], Callable[..., int]]:
300+
def decorator(size: int = 128) -> Callable[[Callable[[T], U]], Callable[..., U]]:
297301
"""
298302
Decorator version of LRU Cache
299303
300-
Decorated function must be function of int -> int
304+
Decorated function must be function of T -> U
301305
"""
302-
def cache_decorator_inner(func: Callable[[int], int]) -> Callable[..., int]:
303-
def cache_decorator_wrapper(*args: int) -> int:
306+
307+
def cache_decorator_inner(func: Callable[[T], U]) -> Callable[..., U]:
308+
def cache_decorator_wrapper(*args: T) -> U:
304309
if func not in LRUCache.decorator_function_to_instance_map:
305310
LRUCache.decorator_function_to_instance_map[func] = LRUCache(size)
306311

0 commit comments

Comments
 (0)