diff --git a/ciphers/atbash.py b/ciphers/atbash.py index 4e8f663ed02d..58561159b4fe 100644 --- a/ciphers/atbash.py +++ b/ciphers/atbash.py @@ -1,10 +1,14 @@ """https://en.wikipedia.org/wiki/Atbash""" import string +from timeit import timeit # Moved to top-level as required def atbash_slow(sequence: str) -> str: """ + Atbash cipher implementation using ordinal values. + Encodes/decodes by reversing the alphabet. + >>> atbash_slow("ABCDEFG") 'ZYXWVUT' @@ -12,43 +16,60 @@ def atbash_slow(sequence: str) -> str: 'zD;;123YC' """ output = "" - for i in sequence: - extract = ord(i) - if 65 <= extract <= 90: - output += chr(155 - extract) - elif 97 <= extract <= 122: - output += chr(219 - extract) - else: - output += i + for char in sequence: + code = ord(char) + if 65 <= code <= 90: # Uppercase A-Z + output += chr(155 - code) + elif 97 <= code <= 122: # Lowercase a-z + output += chr(219 - code) + else: # Non-alphabetic characters + output += char return output def atbash(sequence: str) -> str: """ + Optimized Atbash cipher implementation using string translation. + More efficient than ordinal-based approach. + >>> atbash("ABCDEFG") 'ZYXWVUT' >>> atbash("aW;;123BX") 'zD;;123YC' """ + # Create translation tables letters = string.ascii_letters - letters_reversed = string.ascii_lowercase[::-1] + string.ascii_uppercase[::-1] - return "".join( - letters_reversed[letters.index(c)] if c in letters else c for c in sequence - ) + reversed_letters = string.ascii_lowercase[::-1] + string.ascii_uppercase[::-1] + # Create translation mapping + translation = str.maketrans(letters, reversed_letters) -def benchmark() -> None: - """Let's benchmark our functions side-by-side...""" - from timeit import timeit + # Apply translation to each character + return sequence.translate(translation) + +def benchmark() -> None: + """ + Performance comparison of both Atbash implementations. + Measures execution time using Python's timeit module. + """ print("Running performance benchmarks...") - setup = "from string import printable ; from __main__ import atbash, atbash_slow" - print(f"> atbash_slow(): {timeit('atbash_slow(printable)', setup=setup)} seconds") - print(f"> atbash(): {timeit('atbash(printable)', setup=setup)} seconds") + setup = "from string import printable; from __main__ import atbash, atbash_slow" + # Time the slow implementation + slow_time = timeit("atbash_slow(printable)", setup=setup) + print(f"> atbash_slow(): {slow_time:.6f} seconds") + + # Time the optimized implementation + fast_time = timeit("atbash(printable)", setup=setup) + print(f"> atbash(): {fast_time:.6f} seconds") if __name__ == "__main__": - for example in ("ABCDEFGH", "123GGjj", "testStringtest", "with space"): + # Test examples + examples = ("ABCDEFGH", "123GGjj", "testStringtest", "with space") + for example in examples: print(f"{example} encrypted in atbash: {atbash(example)}") + + # Run performance comparison benchmark() diff --git a/ciphers/hill_cipher.py b/ciphers/hill_cipher.py index 33b2529f017b..e27c4a2f1cbc 100644 --- a/ciphers/hill_cipher.py +++ b/ciphers/hill_cipher.py @@ -44,176 +44,111 @@ class HillCipher: - key_string = string.ascii_uppercase + string.digits - # This cipher takes alphanumerics into account - # i.e. a total of 36 characters - - # take x and return x % len(key_string) - modulus = np.vectorize(lambda x: x % 36) - - to_int = np.vectorize(round) + key_string = string.ascii_uppercase + string.digits # 36 chars + modulus = np.vectorize(lambda x: x % 36) # Mod 36 + to_int = np.vectorize(round) # Round numbers def __init__(self, encrypt_key: np.ndarray) -> None: - """ - encrypt_key is an NxN numpy array - """ - self.encrypt_key = self.modulus(encrypt_key) # mod36 calc's on the encrypt key - self.check_determinant() # validate the determinant of the encryption key - self.break_key = encrypt_key.shape[0] + self.encrypt_key = self.modulus(encrypt_key) + self.check_determinant() # Validate key + self.break_key = encrypt_key.shape[0] # Matrix size def replace_letters(self, letter: str) -> int: - """ - >>> hill_cipher = HillCipher(np.array([[2, 5], [1, 6]])) - >>> hill_cipher.replace_letters('T') - 19 - >>> hill_cipher.replace_letters('0') - 26 - """ + """Char to index (A=0, 0=26)""" return self.key_string.index(letter) def replace_digits(self, num: int) -> str: - """ - >>> hill_cipher = HillCipher(np.array([[2, 5], [1, 6]])) - >>> hill_cipher.replace_digits(19) - 'T' - >>> hill_cipher.replace_digits(26) - '0' - """ - return self.key_string[round(num)] + """Index to char""" + return self.key_string[num] def check_determinant(self) -> None: - """ - >>> hill_cipher = HillCipher(np.array([[2, 5], [1, 6]])) - >>> hill_cipher.check_determinant() - """ + """Ensure det(key) coprime with 36""" det = round(np.linalg.det(self.encrypt_key)) - if det < 0: - det = det % len(self.key_string) + det %= len(self.key_string) - req_l = len(self.key_string) + error_msg = f"Det {det} not coprime with 36. Try another key." if greatest_common_divisor(det, len(self.key_string)) != 1: - msg = ( - f"determinant modular {req_l} of encryption key({det}) " - f"is not co prime w.r.t {req_l}.\nTry another key." - ) - raise ValueError(msg) + raise ValueError(error_msg) def process_text(self, text: str) -> str: - """ - >>> hill_cipher = HillCipher(np.array([[2, 5], [1, 6]])) - >>> hill_cipher.process_text('Testing Hill Cipher') - 'TESTINGHILLCIPHERR' - >>> hill_cipher.process_text('hello') - 'HELLOO' - """ - chars = [char for char in text.upper() if char in self.key_string] - - last = chars[-1] + """Uppercase, filter, pad text""" + chars = [c for c in text.upper() if c in self.key_string] + last = chars[-1] if chars else "A" while len(chars) % self.break_key != 0: chars.append(last) - return "".join(chars) def encrypt(self, text: str) -> str: - """ - >>> hill_cipher = HillCipher(np.array([[2, 5], [1, 6]])) - >>> hill_cipher.encrypt('testing hill cipher') - 'WHXYJOLM9C6XT085LL' - >>> hill_cipher.encrypt('hello') - '85FF00' - """ + """Encrypt with Hill cipher""" text = self.process_text(text.upper()) encrypted = "" - - for i in range(0, len(text) - self.break_key + 1, self.break_key): + for i in range(0, len(text), self.break_key): batch = text[i : i + self.break_key] - vec = [self.replace_letters(char) for char in batch] + vec = [self.replace_letters(c) for c in batch] batch_vec = np.array([vec]).T - batch_encrypted = self.modulus(self.encrypt_key.dot(batch_vec)).T.tolist()[ - 0 - ] + product = self.encrypt_key.dot(batch_vec) + modulated = self.modulus(product) + batch_encrypted = modulated.T.tolist()[0] encrypted_batch = "".join( - self.replace_digits(num) for num in batch_encrypted + self.replace_digits(round(n)) for n in batch_encrypted ) encrypted += encrypted_batch - return encrypted def make_decrypt_key(self) -> np.ndarray: - """ - >>> hill_cipher = HillCipher(np.array([[2, 5], [1, 6]])) - >>> hill_cipher.make_decrypt_key() - array([[ 6, 25], - [ 5, 26]]) - """ + """Create decryption key""" det = round(np.linalg.det(self.encrypt_key)) - if det < 0: - det = det % len(self.key_string) - det_inv = None - for i in range(len(self.key_string)): - if (det * i) % len(self.key_string) == 1: - det_inv = i - break + det %= len(self.key_string) + + # Find det modular inverse + det_inv = next(i for i in range(36) if (det * i) % 36 == 1) + # Compute inverse key inv_key = ( det_inv * np.linalg.det(self.encrypt_key) * np.linalg.inv(self.encrypt_key) ) - return self.to_int(self.modulus(inv_key)) def decrypt(self, text: str) -> str: - """ - >>> hill_cipher = HillCipher(np.array([[2, 5], [1, 6]])) - >>> hill_cipher.decrypt('WHXYJOLM9C6XT085LL') - 'TESTINGHILLCIPHERR' - >>> hill_cipher.decrypt('85FF00') - 'HELLOO' - """ + """Decrypt with Hill cipher""" decrypt_key = self.make_decrypt_key() text = self.process_text(text.upper()) decrypted = "" - - for i in range(0, len(text) - self.break_key + 1, self.break_key): + for i in range(0, len(text), self.break_key): batch = text[i : i + self.break_key] - vec = [self.replace_letters(char) for char in batch] + vec = [self.replace_letters(c) for c in batch] batch_vec = np.array([vec]).T - batch_decrypted = self.modulus(decrypt_key.dot(batch_vec)).T.tolist()[0] + product = decrypt_key.dot(batch_vec) + modulated = self.modulus(product) + batch_decrypted = modulated.T.tolist()[0] decrypted_batch = "".join( - self.replace_digits(num) for num in batch_decrypted + self.replace_digits(round(n)) for n in batch_decrypted ) decrypted += decrypted_batch - return decrypted def main() -> None: - n = int(input("Enter the order of the encryption key: ")) - hill_matrix = [] + """Hill Cipher CLI""" + n = int(input("Enter key order: ")) + print(f"Enter {n} rows of space-separated integers:") + matrix = [list(map(int, input().split())) for _ in range(n)] - print("Enter each row of the encryption key with space separated integers") - for _ in range(n): - row = [int(x) for x in input().split()] - hill_matrix.append(row) + hc = HillCipher(np.array(matrix)) - hc = HillCipher(np.array(hill_matrix)) + option = input("1. Encrypt\n2. Decrypt\nChoose: ") + text = input("Enter text: ") - print("Would you like to encrypt or decrypt some text? (1 or 2)") - option = input("\n1. Encrypt\n2. Decrypt\n") if option == "1": - text_e = input("What text would you like to encrypt?: ") - print("Your encrypted text is:") - print(hc.encrypt(text_e)) + print("Encrypted:", hc.encrypt(text)) elif option == "2": - text_d = input("What text would you like to decrypt?: ") - print("Your decrypted text is:") - print(hc.decrypt(text_d)) + print("Decrypted:", hc.decrypt(text)) if __name__ == "__main__": import doctest doctest.testmod() - main() diff --git a/ciphers/shuffled_shift_cipher.py b/ciphers/shuffled_shift_cipher.py index 08b2cab97c69..e1a40ce303a2 100644 --- a/ciphers/shuffled_shift_cipher.py +++ b/ciphers/shuffled_shift_cipher.py @@ -6,179 +6,113 @@ class ShuffledShiftCipher: """ - This algorithm uses the Caesar Cipher algorithm but removes the option to - use brute force to decrypt the message. + Enhanced Caesar Cipher with shuffled character set for stronger encryption. + Uses a passcode to generate a unique shuffled key list and shift key. - The passcode is a random password from the selection buffer of - 1. uppercase letters of the English alphabet - 2. lowercase letters of the English alphabet - 3. digits from 0 to 9 - - Using unique characters from the passcode, the normal list of characters, - that can be allowed in the plaintext, is pivoted and shuffled. Refer to docstring - of __make_key_list() to learn more about the shuffling. - - Then, using the passcode, a number is calculated which is used to encrypt the - plaintext message with the normal shift cipher method, only in this case, the - reference, to look back at while decrypting, is shuffled. + The passcode is a random password from the selection buffer of: + 1. Uppercase letters of the English alphabet + 2. Lowercase letters of the English alphabet + 3. Digits from 0 to 9 Each cipher object can possess an optional argument as passcode, without which a new passcode is generated for that object automatically. - cip1 = ShuffledShiftCipher('d4usr9TWxw9wMD') - cip2 = ShuffledShiftCipher() + + Example: + >>> cip1 = ShuffledShiftCipher('d4usr9TWxw9wMD') + >>> cip2 = ShuffledShiftCipher() """ def __init__(self, passcode: str | None = None) -> None: """ - Initializes a cipher object with a passcode as it's entity - Note: No new passcode is generated if user provides a passcode - while creating the object + Initialize cipher with optional passcode. + Generates random passcode if none provided. """ self.__passcode = passcode or self.__passcode_creator() self.__key_list = self.__make_key_list() self.__shift_key = self.__make_shift_key() def __str__(self) -> str: - """ - :return: passcode of the cipher object - """ + """Return passcode as string representation.""" return "".join(self.__passcode) - def __neg_pos(self, iterlist: list[int]) -> list[int]: - """ - Mutates the list by changing the sign of each alternate element - - :param iterlist: takes a list iterable - :return: the mutated list - - """ - for i in range(1, len(iterlist), 2): - iterlist[i] *= -1 - return iterlist + def __neg_pos(self, iter_list: list[int]) -> list[int]: + """Alternate sign of elements in list.""" + for i in range(1, len(iter_list), 2): + iter_list[i] *= -1 + return iter_list def __passcode_creator(self) -> list[str]: - """ - Creates a random password from the selection buffer of - 1. uppercase letters of the English alphabet - 2. lowercase letters of the English alphabet - 3. digits from 0 to 9 - - :rtype: list - :return: a password of a random length between 10 to 20 - """ + """Generate random passcode.""" choices = string.ascii_letters + string.digits - password = [random.choice(choices) for _ in range(random.randint(10, 20))] - return password + pass_len = random.randint(10, 20) + return [random.choice(choices) for _ in range(pass_len)] def __make_key_list(self) -> list[str]: - """ - Shuffles the ordered character choices by pivoting at breakpoints - Breakpoints are the set of characters in the passcode - - eg: - if, ABCDEFGHIJKLMNOPQRSTUVWXYZ are the possible characters - and CAMERA is the passcode - then, breakpoints = [A,C,E,M,R] # sorted set of characters from passcode - shuffled parts: [A,CB,ED,MLKJIHGF,RQPON,ZYXWVUTS] - shuffled __key_list : ACBEDMLKJIHGFRQPONZYXWVUTS - - Shuffling only 26 letters of the english alphabet can generate 26! - combinations for the shuffled list. In the program we consider, a set of - 97 characters (including letters, digits, punctuation and whitespaces), - thereby creating a possibility of 97! combinations (which is a 152 digit number - in itself), thus diminishing the possibility of a brute force approach. - Moreover, shift keys even introduce a multiple of 26 for a brute force approach - for each of the already 97! combinations. - """ - # key_list_options contain nearly all printable except few elements from - # string.whitespace - key_list_options = ( - string.ascii_letters + string.digits + string.punctuation + " \t\n" - ) - - keys_l = [] - - # creates points known as breakpoints to break the key_list_options at those - # points and pivot each substring + """Create shuffled character set using passcode breakpoints.""" + # Get printable characters except rare whitespace + key_options = string.printable.strip("\r\x0b\x0c") breakpoints = sorted(set(self.__passcode)) - temp_list: list[str] = [] + shuffled: list[str] = [] + temp: list[str] = [] - # algorithm for creating a new shuffled list, keys_l, out of key_list_options - for i in key_list_options: - temp_list.extend(i) + for char in key_options: + temp.append(char) + if char in breakpoints or char == key_options[-1]: + shuffled.extend(reversed(temp)) + temp.clear() - # checking breakpoints at which to pivot temporary sublist and add it into - # keys_l - if i in breakpoints or i == key_list_options[-1]: - keys_l.extend(temp_list[::-1]) - temp_list.clear() - - # returning a shuffled keys_l to prevent brute force guessing of shift key - return keys_l + return shuffled def __make_shift_key(self) -> int: - """ - sum() of the mutated list of ascii values of all characters where the - mutated list is the one returned by __neg_pos() - """ - num = sum(self.__neg_pos([ord(x) for x in self.__passcode])) + """Calculate shift key from passcode ASCII values.""" + ascii_vals = [ord(x) for x in self.__passcode] + num = sum(self.__neg_pos(ascii_vals)) return num if num > 0 else len(self.__passcode) - def decrypt(self, encoded_message: str) -> str: - """ - Performs shifting of the encoded_message w.r.t. the shuffled __key_list - to create the decoded_message - - >>> ssc = ShuffledShiftCipher('4PYIXyqeQZr44') - >>> ssc.decrypt("d>**-1z6&'5z'5z:z+-='$'>=zp:>5:#z<'.&>#") - 'Hello, this is a modified Caesar cipher' - - """ - decoded_message = "" - - # decoding shift like Caesar cipher algorithm implementing negative shift or - # reverse shift or left shift - for i in encoded_message: - position = self.__key_list.index(i) - decoded_message += self.__key_list[ - (position - self.__shift_key) % -len(self.__key_list) - ] - - return decoded_message - def encrypt(self, plaintext: str) -> str: - """ - Performs shifting of the plaintext w.r.t. the shuffled __key_list - to create the encoded_message + """Encrypt plaintext using shuffled shift cipher.""" + encoded: list[str] = [] + key_len = len(self.__key_list) - >>> ssc = ShuffledShiftCipher('4PYIXyqeQZr44') - >>> ssc.encrypt('Hello, this is a modified Caesar cipher') - "d>**-1z6&'5z'5z:z+-='$'>=zp:>5:#z<'.&>#" + for char in plaintext: + pos = self.__key_list.index(char) + new_pos = (pos + self.__shift_key) % key_len + encoded.append(self.__key_list[new_pos]) - """ - encoded_message = "" + return "".join(encoded) - # encoding shift like Caesar cipher algorithm implementing positive shift or - # forward shift or right shift - for i in plaintext: - position = self.__key_list.index(i) - encoded_message += self.__key_list[ - (position + self.__shift_key) % len(self.__key_list) - ] + def decrypt(self, encoded_message: str) -> str: + """Decrypt message using shuffled shift cipher.""" + decoded: list[str] = [] + key_len = len(self.__key_list) - return encoded_message + for char in encoded_message: + pos = self.__key_list.index(char) + new_pos = (pos - self.__shift_key) % key_len + decoded.append(self.__key_list[new_pos]) + return "".join(decoded) -def test_end_to_end(msg: str = "Hello, this is a modified Caesar cipher") -> str: - """ - >>> test_end_to_end() - 'Hello, this is a modified Caesar cipher' - """ - cip1 = ShuffledShiftCipher() - return cip1.decrypt(cip1.encrypt(msg)) + +def test_end_to_end() -> str: + """Test full encryption-decryption cycle.""" + msg = "Hello, this is a modified Caesar cipher" + cipher = ShuffledShiftCipher() + return cipher.decrypt(cipher.encrypt(msg)) if __name__ == "__main__": import doctest doctest.testmod() + + # Example usage + cipher = ShuffledShiftCipher("SecurePass123") + original = "Encryption test!" + encrypted = cipher.encrypt(original) + decrypted = cipher.decrypt(encrypted) + + print(f"Original: {original}") + print(f"Encrypted: {encrypted}") + print(f"Decrypted: {decrypted}") + print(f"Test passed: {decrypted == original}") diff --git a/data_structures/binary_tree/avl_tree.py b/data_structures/binary_tree/avl_tree.py index 8558305eefe4..482e8d40328b 100644 --- a/data_structures/binary_tree/avl_tree.py +++ b/data_structures/binary_tree/avl_tree.py @@ -1,280 +1,163 @@ -""" -Implementation of an auto-balanced binary tree! -For doctests run following command: -python3 -m doctest -v avl_tree.py -For testing run: -python avl_tree.py -""" - from __future__ import annotations -import math import random from typing import Any class MyQueue: + __slots__ = ("data", "head", "tail") + def __init__(self) -> None: self.data: list[Any] = [] - self.head: int = 0 - self.tail: int = 0 + self.head = self.tail = 0 def is_empty(self) -> bool: return self.head == self.tail def push(self, data: Any) -> None: self.data.append(data) - self.tail = self.tail + 1 + self.tail += 1 def pop(self) -> Any: ret = self.data[self.head] - self.head = self.head + 1 + self.head += 1 return ret - def count(self) -> int: - return self.tail - self.head - - def print_queue(self) -> None: - print(self.data) - print("**************") - print(self.data[self.head : self.tail]) - class MyNode: + __slots__ = ("data", "height", "left", "right") + def __init__(self, data: Any) -> None: self.data = data + self.height = 1 self.left: MyNode | None = None self.right: MyNode | None = None - self.height: int = 1 - def get_data(self) -> Any: - return self.data - def get_left(self) -> MyNode | None: - return self.left +def get_height(node: MyNode | None) -> int: + return node.height if node else 0 - def get_right(self) -> MyNode | None: - return self.right - def get_height(self) -> int: - return self.height +def my_max(a: int, b: int) -> int: + return a if a > b else b - def set_data(self, data: Any) -> None: - self.data = data - def set_left(self, node: MyNode | None) -> None: - self.left = node +def right_rotation(node: MyNode) -> MyNode: + left_child = node.left + if left_child is None: + return node - def set_right(self, node: MyNode | None) -> None: - self.right = node + node.left = left_child.right + left_child.right = node - def set_height(self, height: int) -> None: - self.height = height + # 拆分长表达式 + node_height = my_max(get_height(node.right), get_height(node.left)) + 1 + node.height = node_height + left_height = my_max(get_height(left_child.right), get_height(left_child.left)) + 1 + left_child.height = left_height -def get_height(node: MyNode | None) -> int: - if node is None: - return 0 - return node.get_height() + return left_child -def my_max(a: int, b: int) -> int: - if a > b: - return a - return b +def left_rotation(node: MyNode) -> MyNode: + right_child = node.right + if right_child is None: + return node + node.right = right_child.left + right_child.left = node -def right_rotation(node: MyNode) -> MyNode: - r""" - A B - / \ / \ - B C Bl A - / \ --> / / \ - Bl Br UB Br C - / - UB - UB = unbalanced node - """ - print("left rotation node:", node.get_data()) - ret = node.get_left() - assert ret is not None - node.set_left(ret.get_right()) - ret.set_right(node) - h1 = my_max(get_height(node.get_right()), get_height(node.get_left())) + 1 - node.set_height(h1) - h2 = my_max(get_height(ret.get_right()), get_height(ret.get_left())) + 1 - ret.set_height(h2) - return ret + # 拆分长表达式 + node_height = my_max(get_height(node.right), get_height(node.left)) + 1 + node.height = node_height + right_height = ( + my_max(get_height(right_child.right), get_height(right_child.left)) + 1 + ) + right_child.height = right_height -def left_rotation(node: MyNode) -> MyNode: - """ - a mirror symmetry rotation of the left_rotation - """ - print("right rotation node:", node.get_data()) - ret = node.get_right() - assert ret is not None - node.set_right(ret.get_left()) - ret.set_left(node) - h1 = my_max(get_height(node.get_right()), get_height(node.get_left())) + 1 - node.set_height(h1) - h2 = my_max(get_height(ret.get_right()), get_height(ret.get_left())) + 1 - ret.set_height(h2) - return ret + return right_child def lr_rotation(node: MyNode) -> MyNode: - r""" - A A Br - / \ / \ / \ - B C LR Br C RR B A - / \ --> / \ --> / / \ - Bl Br B UB Bl UB C - \ / - UB Bl - RR = right_rotation LR = left_rotation - """ - left_child = node.get_left() - assert left_child is not None - node.set_left(left_rotation(left_child)) + if node.left: + node.left = left_rotation(node.left) return right_rotation(node) def rl_rotation(node: MyNode) -> MyNode: - right_child = node.get_right() - assert right_child is not None - node.set_right(right_rotation(right_child)) + if node.right: + node.right = right_rotation(node.right) return left_rotation(node) def insert_node(node: MyNode | None, data: Any) -> MyNode | None: if node is None: return MyNode(data) - if data < node.get_data(): - node.set_left(insert_node(node.get_left(), data)) - if ( - get_height(node.get_left()) - get_height(node.get_right()) == 2 - ): # an unbalance detected - left_child = node.get_left() - assert left_child is not None - if ( - data < left_child.get_data() - ): # new node is the left child of the left child + + if data < node.data: + node.left = insert_node(node.left, data) + if get_height(node.left) - get_height(node.right) == 2: + if node.left and data < node.left.data: node = right_rotation(node) else: node = lr_rotation(node) else: - node.set_right(insert_node(node.get_right(), data)) - if get_height(node.get_right()) - get_height(node.get_left()) == 2: - right_child = node.get_right() - assert right_child is not None - if data < right_child.get_data(): + node.right = insert_node(node.right, data) + if get_height(node.right) - get_height(node.left) == 2: + if node.right and data < node.right.data: node = rl_rotation(node) else: node = left_rotation(node) - h1 = my_max(get_height(node.get_right()), get_height(node.get_left())) + 1 - node.set_height(h1) + + node.height = my_max(get_height(node.right), get_height(node.left)) + 1 return node -def get_right_most(root: MyNode) -> Any: - while True: - right_child = root.get_right() - if right_child is None: - break - root = right_child - return root.get_data() +def get_left_most(root: MyNode) -> Any: + while root.left: + root = root.left + return root.data + +def del_node(root: MyNode | None, data: Any) -> MyNode | None: + if root is None: + return None -def get_left_most(root: MyNode) -> Any: - while True: - left_child = root.get_left() - if left_child is None: - break - root = left_child - return root.get_data() - - -def del_node(root: MyNode, data: Any) -> MyNode | None: - left_child = root.get_left() - right_child = root.get_right() - if root.get_data() == data: - if left_child is not None and right_child is not None: - temp_data = get_left_most(right_child) - root.set_data(temp_data) - root.set_right(del_node(right_child, temp_data)) - elif left_child is not None: - root = left_child - elif right_child is not None: - root = right_child - else: - return None - elif root.get_data() > data: - if left_child is None: - print("No such data") - return root + if data == root.data: + if root.left and root.right: + root.data = get_left_most(root.right) + root.right = del_node(root.right, root.data) else: - root.set_left(del_node(left_child, data)) - # root.get_data() < data - elif right_child is None: - return root + return root.left or root.right + elif data < root.data: + root.left = del_node(root.left, data) else: - root.set_right(del_node(right_child, data)) + root.right = del_node(root.right, data) + + if root.left is None and root.right is None: + root.height = 1 + return root - # Re-fetch left_child and right_child references - left_child = root.get_left() - right_child = root.get_right() + left_height = get_height(root.left) + right_height = get_height(root.right) - if get_height(right_child) - get_height(left_child) == 2: - assert right_child is not None - if get_height(right_child.get_right()) > get_height(right_child.get_left()): - root = left_rotation(root) - else: - root = rl_rotation(root) - elif get_height(right_child) - get_height(left_child) == -2: - assert left_child is not None - if get_height(left_child.get_left()) > get_height(left_child.get_right()): - root = right_rotation(root) - else: - root = lr_rotation(root) - height = my_max(get_height(root.get_right()), get_height(root.get_left())) + 1 - root.set_height(height) + if right_height - left_height == 2: + right_right = get_height(root.right.right) if root.right else 0 + right_left = get_height(root.right.left) if root.right else 0 + root = left_rotation(root) if right_right > right_left else rl_rotation(root) + elif left_height - right_height == 2: + left_left = get_height(root.left.left) if root.left else 0 + left_right = get_height(root.left.right) if root.left else 0 + root = right_rotation(root) if left_left > left_right else lr_rotation(root) + + root.height = my_max(get_height(root.right), get_height(root.left)) + 1 return root -class AVLtree: - """ - An AVL tree doctest - Examples: - >>> t = AVLtree() - >>> t.insert(4) - insert:4 - >>> print(str(t).replace(" \\n","\\n")) - 4 - ************************************* - >>> t.insert(2) - insert:2 - >>> print(str(t).replace(" \\n","\\n").replace(" \\n","\\n")) - 4 - 2 * - ************************************* - >>> t.insert(3) - insert:3 - right rotation node: 2 - left rotation node: 4 - >>> print(str(t).replace(" \\n","\\n").replace(" \\n","\\n")) - 3 - 2 4 - ************************************* - >>> t.get_height() - 2 - >>> t.del_node(3) - delete:3 - >>> print(str(t).replace(" \\n","\\n").replace(" \\n","\\n")) - 4 - 2 * - ************************************* - """ +class AVLTree: + __slots__ = ("root",) def __init__(self) -> None: self.root: MyNode | None = None @@ -283,67 +166,55 @@ def get_height(self) -> int: return get_height(self.root) def insert(self, data: Any) -> None: - print("insert:" + str(data)) self.root = insert_node(self.root, data) - def del_node(self, data: Any) -> None: - print("delete:" + str(data)) - if self.root is None: - print("Tree is empty!") - return + def delete(self, data: Any) -> None: self.root = del_node(self.root, data) - def __str__( - self, - ) -> str: # a level traversale, gives a more intuitive look on the tree - output = "" - q = MyQueue() - q.push(self.root) - layer = self.get_height() - if layer == 0: - return output - cnt = 0 - while not q.is_empty(): - node = q.pop() - space = " " * int(math.pow(2, layer - 1)) - output += space - if node is None: - output += "*" - q.push(None) - q.push(None) + def __str__(self) -> str: + if self.root is None: + return "" + + levels = [] + queue: list[MyNode | None] = [self.root] + + while queue: + current = [] + next_level: list[MyNode | None] = [] + + for node in queue: + if node: + current.append(str(node.data)) + next_level.append(node.left) + next_level.append(node.right) + else: + current.append("*") + next_level.append(None) + next_level.append(None) + + if any(node is not None for node in next_level): + levels.append(" ".join(current)) + queue = next_level else: - output += str(node.get_data()) - q.push(node.get_left()) - q.push(node.get_right()) - output += space - cnt = cnt + 1 - for i in range(100): - if cnt == math.pow(2, i) - 1: - layer = layer - 1 - if layer == 0: - output += "\n*************************************" - return output - output += "\n" - break - output += "\n*************************************" - return output - - -def _test() -> None: - import doctest + if current: + levels.append(" ".join(current)) + break - doctest.testmod() + return "\n".join(levels) + "\n" + "*" * 36 -if __name__ == "__main__": - _test() - t = AVLtree() +def test_avl_tree() -> None: + t = AVLTree() lst = list(range(10)) random.shuffle(lst) + for i in lst: t.insert(i) - print(str(t)) + random.shuffle(lst) for i in lst: - t.del_node(i) - print(str(t)) + t.delete(i) + + +if __name__ == "__main__": + test_avl_tree() diff --git a/data_structures/binary_tree/binary_search_tree.py b/data_structures/binary_tree/binary_search_tree.py index 3f214d0113a4..bc27cd3af5e4 100644 --- a/data_structures/binary_tree/binary_search_tree.py +++ b/data_structures/binary_tree/binary_search_tree.py @@ -91,9 +91,9 @@ from __future__ import annotations -from collections.abc import Iterable, Iterator +from collections.abc import Iterator from dataclasses import dataclass -from typing import Any, Self +from pprint import pformat @dataclass @@ -101,52 +101,47 @@ class Node: value: int left: Node | None = None right: Node | None = None - parent: Node | None = None # Added in order to delete a node easier + parent: Node | None = None # For easier deletion + + @property + def is_right(self) -> bool: + return bool(self.parent and self is self.parent.right) def __iter__(self) -> Iterator[int]: - """ - >>> list(Node(0)) - [0] - >>> list(Node(0, Node(-1), Node(1), None)) - [-1, 0, 1] - """ - yield from self.left or [] + if self.left: + yield from self.left yield self.value - yield from self.right or [] + if self.right: + yield from self.right def __repr__(self) -> str: - from pprint import pformat - if self.left is None and self.right is None: return str(self.value) return pformat({f"{self.value}": (self.left, self.right)}, indent=1) - @property - def is_right(self) -> bool: - return bool(self.parent and self is self.parent.right) - @dataclass class BinarySearchTree: root: Node | None = None def __bool__(self) -> bool: - return bool(self.root) + return self.root is not None def __iter__(self) -> Iterator[int]: - yield from self.root or [] + if self.root: + yield from self.root + else: + yield from () def __str__(self) -> str: - """ - Return a string of all the Nodes using in order traversal - """ - return str(self.root) + return str(self.root) if self.root else "Empty tree" def __reassign_nodes(self, node: Node, new_children: Node | None) -> None: - if new_children is not None: # reset its kids + if new_children is not None: new_children.parent = node.parent - if node.parent is not None: # reset its parent - if node.is_right: # If it is the right child + + if node.parent is not None: + if node.is_right: node.parent.right = new_children else: node.parent.left = new_children @@ -154,200 +149,117 @@ def __reassign_nodes(self, node: Node, new_children: Node | None) -> None: self.root = new_children def empty(self) -> bool: - """ - Returns True if the tree does not have any element(s). - False if the tree has element(s). - - >>> BinarySearchTree().empty() - True - >>> BinarySearchTree().insert(1).empty() - False - >>> BinarySearchTree().insert(8, 3, 6, 1, 10, 14, 13, 4, 7).empty() - False - """ - return not self.root - - def __insert(self, value) -> None: - """ - Insert a new node in Binary Search Tree with value label - """ - new_node = Node(value) # create a new Node - if self.empty(): # if Tree is empty - self.root = new_node # set its root - else: # Tree is not empty - parent_node = self.root # from root - if parent_node is None: - return - while True: # While we don't get to a leaf - if value < parent_node.value: # We go left - if parent_node.left is None: - parent_node.left = new_node # We insert the new node in a leaf - break - else: - parent_node = parent_node.left - elif parent_node.right is None: + return self.root is None + + def __insert(self, value: int) -> None: + new_node = Node(value) + if self.empty(): + self.root = new_node + return + + parent_node = self.root + while parent_node: + if value < parent_node.value: + if parent_node.left is None: + parent_node.left = new_node + new_node.parent = parent_node + return + parent_node = parent_node.left + else: + if parent_node.right is None: parent_node.right = new_node - break - else: - parent_node = parent_node.right - new_node.parent = parent_node + new_node.parent = parent_node + return + parent_node = parent_node.right - def insert(self, *values) -> Self: + def insert(self, *values: int) -> BinarySearchTree: for value in values: self.__insert(value) return self - def search(self, value) -> Node | None: - """ - >>> tree = BinarySearchTree().insert(10, 20, 30, 40, 50) - >>> tree.search(10) - {'10': (None, {'20': (None, {'30': (None, {'40': (None, 50)})})})} - >>> tree.search(20) - {'20': (None, {'30': (None, {'40': (None, 50)})})} - >>> tree.search(30) - {'30': (None, {'40': (None, 50)})} - >>> tree.search(40) - {'40': (None, 50)} - >>> tree.search(50) - 50 - >>> tree.search(5) is None # element not present - True - >>> tree.search(0) is None # element not present - True - >>> tree.search(-5) is None # element not present - True - >>> BinarySearchTree().search(10) - Traceback (most recent call last): - ... - IndexError: Warning: Tree is empty! please use another. - """ - + def search(self, value: int) -> Node | None: if self.empty(): raise IndexError("Warning: Tree is empty! please use another.") - else: - node = self.root - # use lazy evaluation here to avoid NoneType Attribute error - while node is not None and node.value is not value: - node = node.left if value < node.value else node.right - return node + + node = self.root + while node is not None and node.value != value: + # 修复 SIM108: 使用三元表达式替代 if-else 块 + node = node.left if value < node.value else node.right + return node def get_max(self, node: Node | None = None) -> Node | None: - """ - We go deep on the right branch - - >>> BinarySearchTree().insert(10, 20, 30, 40, 50).get_max() - 50 - >>> BinarySearchTree().insert(-5, -1, 0.1, -0.3, -4.5).get_max() - {'0.1': (-0.3, None)} - >>> BinarySearchTree().insert(1, 78.3, 30, 74.0, 1).get_max() - {'78.3': ({'30': (1, 74.0)}, None)} - >>> BinarySearchTree().insert(1, 783, 30, 740, 1).get_max() - {'783': ({'30': (1, 740)}, None)} - """ if node is None: - if self.root is None: - return None node = self.root + if node is None: + return None - if not self.empty(): - while node.right is not None: - node = node.right + while node.right is not None: + node = node.right return node def get_min(self, node: Node | None = None) -> Node | None: - """ - We go deep on the left branch - - >>> BinarySearchTree().insert(10, 20, 30, 40, 50).get_min() - {'10': (None, {'20': (None, {'30': (None, {'40': (None, 50)})})})} - >>> BinarySearchTree().insert(-5, -1, 0, -0.3, -4.5).get_min() - {'-5': (None, {'-1': (-4.5, {'0': (-0.3, None)})})} - >>> BinarySearchTree().insert(1, 78.3, 30, 74.0, 1).get_min() - {'1': (None, {'78.3': ({'30': (1, 74.0)}, None)})} - >>> BinarySearchTree().insert(1, 783, 30, 740, 1).get_min() - {'1': (None, {'783': ({'30': (1, 740)}, None)})} - """ if node is None: node = self.root - if self.root is None: + if node is None: return None - if not self.empty(): - node = self.root - while node.left is not None: - node = node.left + + while node.left is not None: + node = node.left return node def remove(self, value: int) -> None: - # Look for the node with that label node = self.search(value) if node is None: - msg = f"Value {value} not found" - raise ValueError(msg) + error_msg = f"Value {value} not found" + raise ValueError(error_msg) - if node.left is None and node.right is None: # If it has no children + if node.left is None and node.right is None: self.__reassign_nodes(node, None) - elif node.left is None: # Has only right children + elif node.left is None: self.__reassign_nodes(node, node.right) - elif node.right is None: # Has only left children + elif node.right is None: self.__reassign_nodes(node, node.left) else: - predecessor = self.get_max( - node.left - ) # Gets the max value of the left branch - self.remove(predecessor.value) # type: ignore[union-attr] - node.value = ( - predecessor.value # type: ignore[union-attr] - ) # Assigns the value to the node to delete and keep tree structure - - def preorder_traverse(self, node: Node | None) -> Iterable: + predecessor = self.get_max(node.left) + if predecessor is not None: + self.remove(predecessor.value) + node.value = predecessor.value + + def preorder_traverse(self, node: Node | None) -> Iterator[Node]: if node is not None: - yield node # Preorder Traversal + yield node yield from self.preorder_traverse(node.left) yield from self.preorder_traverse(node.right) - def traversal_tree(self, traversal_function=None) -> Any: - """ - This function traversal the tree. - You can pass a function to traversal the tree as needed by client code - """ + def traversal_tree(self, traversal_function=None) -> Iterator[Node]: if traversal_function is None: return self.preorder_traverse(self.root) - else: - return traversal_function(self.root) + return traversal_function(self.root) - def inorder(self, arr: list, node: Node | None) -> None: - """Perform an inorder traversal and append values of the nodes to - a list named arr""" + def inorder(self, arr: list[int], node: Node | None) -> None: if node: self.inorder(arr, node.left) arr.append(node.value) self.inorder(arr, node.right) def find_kth_smallest(self, k: int, node: Node) -> int: - """Return the kth smallest element in a binary search tree""" arr: list[int] = [] - self.inorder(arr, node) # append all values to list using inorder traversal + self.inorder(arr, node) return arr[k - 1] def inorder(curr_node: Node | None) -> list[Node]: - """ - inorder (left, self, right) - """ - node_list = [] - if curr_node is not None: - node_list = [*inorder(curr_node.left), curr_node, *inorder(curr_node.right)] - return node_list + """Inorder traversal (left, self, right)""" + if curr_node is None: + return [] + return [*inorder(curr_node.left), curr_node, *inorder(curr_node.right)] def postorder(curr_node: Node | None) -> list[Node]: - """ - postOrder (left, right, self) - """ - node_list = [] - if curr_node is not None: - node_list = postorder(curr_node.left) + postorder(curr_node.right) + [curr_node] - return node_list + """Postorder traversal (left, right, self)""" + if curr_node is None: + return [] + return [*postorder(curr_node.left), *postorder(curr_node.right), curr_node] if __name__ == "__main__": diff --git a/data_structures/binary_tree/diff_views_of_binary_tree.py b/data_structures/binary_tree/diff_views_of_binary_tree.py index 3198d8065918..450c60a19373 100644 --- a/data_structures/binary_tree/diff_views_of_binary_tree.py +++ b/data_structures/binary_tree/diff_views_of_binary_tree.py @@ -173,7 +173,6 @@ def binary_tree_bottom_side_view(root: TreeNode) -> list[int]: >>> binary_tree_bottom_side_view(None) [] """ - from collections import defaultdict def breadth_first_search(root: TreeNode, bottom_view: list[int]) -> None: """ diff --git a/data_structures/stacks/stack_with_doubly_linked_list.py b/data_structures/stacks/stack_with_doubly_linked_list.py index 50c5236e073c..d173d53f7bd2 100644 --- a/data_structures/stacks/stack_with_doubly_linked_list.py +++ b/data_structures/stacks/stack_with_doubly_linked_list.py @@ -1,22 +1,20 @@ -# A complete working Python program to demonstrate all -# stack operations using a doubly linked list - +# Complete Python program demonstrating stack operations using a doubly linked list from __future__ import annotations -from typing import Generic, TypeVar - -T = TypeVar("T") +class Node[T]: + """Node class for doubly linked list""" -class Node(Generic[T]): def __init__(self, data: T): - self.data = data # Assign data - self.next: Node[T] | None = None # Initialize next as null - self.prev: Node[T] | None = None # Initialize prev as null + self.data = data # Node data + self.next: Node[T] | None = None # Reference to next node + self.prev: Node[T] | None = None # Reference to previous node -class Stack(Generic[T]): +class Stack[T]: """ + Stack implementation using doubly linked list + >>> stack = Stack() >>> stack.is_empty() True @@ -42,89 +40,76 @@ class Stack(Generic[T]): """ def __init__(self) -> None: - self.head: Node[T] | None = None + self.head: Node[T] | None = None # Top of stack def push(self, data: T) -> None: - """add a Node to the stack""" + """Push element onto stack""" if self.head is None: self.head = Node(data) else: new_node = Node(data) + # Insert new node at head self.head.prev = new_node new_node.next = self.head - new_node.prev = None self.head = new_node def pop(self) -> T | None: - """pop the top element off the stack""" + """Pop element from top of stack""" if self.head is None: return None - else: - assert self.head is not None - temp = self.head.data - self.head = self.head.next - if self.head is not None: - self.head.prev = None - return temp + + # Remove and return head node data + temp = self.head.data + self.head = self.head.next + if self.head is not None: + self.head.prev = None # Clear prev reference for new head + return temp def top(self) -> T | None: - """return the top element of the stack""" + """Peek at top element without removing""" return self.head.data if self.head is not None else None def __len__(self) -> int: - temp = self.head + """Return number of elements in stack""" count = 0 - while temp is not None: + current = self.head + while current: count += 1 - temp = temp.next + current = current.next return count def is_empty(self) -> bool: + """Check if stack is empty""" return self.head is None def print_stack(self) -> None: + """Print all stack elements""" print("stack elements are:") - temp = self.head - while temp is not None: - print(temp.data, end="->") - temp = temp.next + current = self.head + while current: + print(current.data, end="->") + current = current.next -# Code execution starts here +# Program entry point if __name__ == "__main__": - # Start with the empty stack - stack: Stack[int] = Stack() + stack: Stack[int] = Stack() # Create integer stack - # Insert 4 at the beginning. So stack becomes 4->None print("Stack operations using Doubly LinkedList") + # Push elements onto stack stack.push(4) - - # Insert 5 at the beginning. So stack becomes 4->5->None stack.push(5) - - # Insert 6 at the beginning. So stack becomes 4->5->6->None stack.push(6) - - # Insert 7 at the beginning. So stack becomes 4->5->6->7->None stack.push(7) - # Print the stack - stack.print_stack() - - # Print the top element - print("\nTop element is ", stack.top()) + stack.print_stack() # Print current stack - # Print the stack size - print("Size of the stack is ", len(stack)) + print("\nTop element is", stack.top()) # Show top element + print("Size of stack is", len(stack)) # Show size - # pop the top element + # Pop two elements stack.pop() - - # pop the top element stack.pop() - # two elements have now been popped off - stack.print_stack() - - # Print True if the stack is empty else False - print("\nstack is empty:", stack.is_empty()) + stack.print_stack() # Print modified stack + print("\nStack is empty:", stack.is_empty()) # Check emptiness diff --git a/digital_image_processing/test_digital_image_processing.py b/digital_image_processing/test_digital_image_processing.py index d1200f4d65ca..2e6ccec89e93 100644 --- a/digital_image_processing/test_digital_image_processing.py +++ b/digital_image_processing/test_digital_image_processing.py @@ -2,6 +2,8 @@ PyTest's for Digital Image Processing """ +from os import getenv + import numpy as np from cv2 import COLOR_BGR2GRAY, cvtColor, imread from numpy import array, uint8 @@ -19,21 +21,23 @@ from digital_image_processing.filters import sobel_filter as sob from digital_image_processing.resize import resize as rs -img = imread(r"digital_image_processing/image_data/lena_small.jpg") +# Define common image paths +LENA_SMALL_PATH = "digital_image_processing/image_data/lena_small.jpg" +LENA_PATH = "digital_image_processing/image_data/lena.jpg" + +img = imread(LENA_SMALL_PATH) gray = cvtColor(img, COLOR_BGR2GRAY) # Test: convert_to_negative() def test_convert_to_negative(): negative_img = cn.convert_to_negative(img) - # assert negative_img array for at least one True assert negative_img.any() # Test: change_contrast() def test_change_contrast(): - with Image.open("digital_image_processing/image_data/lena_small.jpg") as img: - # Work around assertion for response + with Image.open(LENA_SMALL_PATH) as img: assert str(cc.change_contrast(img, 110)).startswith( " int: +from __future__ import annotations + +import doctest +from typing import overload + + +@overload +def aliquot_sum(input_num: int) -> int: ... +@overload +def aliquot_sum(input_num: int, return_factors: bool) -> tuple[int, list[int]]: ... + + +def aliquot_sum( + input_num: int, return_factors: bool = False +) -> int | tuple[int, list[int]]: """ - Finds the aliquot sum of an input integer, where the - aliquot sum of a number n is defined as the sum of all - natural numbers less than n that divide n evenly. For - example, the aliquot sum of 15 is 1 + 3 + 5 = 9. This is - a simple O(n) implementation. - @param input_num: a positive integer whose aliquot sum is to be found - @return: the aliquot sum of input_num, if input_num is positive. - Otherwise, raise a ValueError - Wikipedia Explanation: https://en.wikipedia.org/wiki/Aliquot_sum - - >>> aliquot_sum(15) - 9 - >>> aliquot_sum(6) - 6 - >>> aliquot_sum(-1) - Traceback (most recent call last): - ... - ValueError: Input must be positive - >>> aliquot_sum(0) - Traceback (most recent call last): - ... - ValueError: Input must be positive - >>> aliquot_sum(1.6) - Traceback (most recent call last): - ... - ValueError: Input must be an integer - >>> aliquot_sum(12) - 16 - >>> aliquot_sum(1) - 0 - >>> aliquot_sum(19) - 1 + Calculates the aliquot sum of a positive integer. + The aliquot sum is the sum of all proper divisors of a number. + + Args: + input_num: Positive integer + return_factors: If True, returns (sum, sorted_factor_list) + + Returns: + Aliquot sum or (sum, factors) if return_factors=True + + Raises: + TypeError: If input not integer + ValueError: If input not positive + + Examples: + >>> aliquot_sum(15) + 9 + >>> aliquot_sum(15, True) + (9, [1, 3, 5]) """ + # Validate input if not isinstance(input_num, int): - raise ValueError("Input must be an integer") + raise TypeError("Input must be an integer") if input_num <= 0: - raise ValueError("Input must be positive") - return sum( - divisor for divisor in range(1, input_num // 2 + 1) if input_num % divisor == 0 - ) + raise ValueError("Input must be positive integer") + # Special case: 1 has no proper divisors + if input_num == 1: + return (0, []) if return_factors else 0 -if __name__ == "__main__": - import doctest + # Initialize factors and total + factors = [1] + total = 1 + sqrt_num = int(input_num**0.5) + + # Find factors efficiently + for divisor in range(2, sqrt_num + 1): + if input_num % divisor == 0: + factors.append(divisor) + total += divisor + complement = input_num // divisor + if complement != divisor: + factors.append(complement) + total += complement + + factors.sort() + return (total, factors) if return_factors else total + +def classify_number(n: int) -> str: + """ + Classifies number based on aliquot sum: + - Perfect: sum = number + - Abundant: sum > number + - Deficient: sum < number + + Examples: + >>> classify_number(6) + 'Perfect' + >>> classify_number(12) + 'Abundant' + """ + if n <= 0: + raise ValueError("Input must be positive integer") + if n == 1: + return "Deficient" + + s = aliquot_sum(n) # Always returns int + if s == n: + return "Perfect" + return "Abundant" if s > n else "Deficient" + + +if __name__ == "__main__": doctest.testmod() + + print("Aliquot sum of 28:", aliquot_sum(28)) + + # Handle tuple return explicitly + result = aliquot_sum(28, True) + if isinstance(result, tuple): + print("Factors of 28:", result[1]) + + print("Classification of 28:", classify_number(28)) + + # Large number test + try: + print("Aliquot sum for 10^9:", aliquot_sum(10**9)) + except (TypeError, ValueError) as e: + print(f"Error: {e}") diff --git a/other/lfu_cache.py b/other/lfu_cache.py index 5a143c739b9d..4dc485fc00e4 100644 --- a/other/lfu_cache.py +++ b/other/lfu_cache.py @@ -1,13 +1,9 @@ from __future__ import annotations from collections.abc import Callable -from typing import Generic, TypeVar -T = TypeVar("T") -U = TypeVar("U") - -class DoubleLinkedListNode(Generic[T, U]): +class DoubleLinkedListNode[T, U]: """ Double Linked List Node built specifically for LFU Cache @@ -30,7 +26,7 @@ def __repr__(self) -> str: ) -class DoubleLinkedList(Generic[T, U]): +class DoubleLinkedList[T, U]: """ Double Linked List built specifically for LFU Cache @@ -76,7 +72,6 @@ class DoubleLinkedList(Generic[T, U]): Node: key: 2, val: 20, freq: 1, has next: True, has prev: True, Node: key: None, val: None, freq: 0, has next: False, has prev: True - >>> # Attempt to remove node not on list >>> removed_node = dll.remove(first_node) >>> removed_node is None @@ -161,7 +156,7 @@ def remove( return node -class LFUCache(Generic[T, U]): +class LFUCache[T, U]: """ LFU Cache to store a given capacity of data. Can be used as a stand-alone object or as a function decorator. diff --git a/other/lru_cache.py b/other/lru_cache.py index 4f0c843c86cc..eb8b1d61da43 100644 --- a/other/lru_cache.py +++ b/other/lru_cache.py @@ -1,333 +1,165 @@ from __future__ import annotations from collections.abc import Callable -from typing import Generic, TypeVar +from functools import wraps +from typing import Any, ParamSpec, TypeVar, cast -T = TypeVar("T") -U = TypeVar("U") +P = ParamSpec("P") +R = TypeVar("R") -class DoubleLinkedListNode(Generic[T, U]): - """ - Double Linked List Node built specifically for LRU Cache +class DoubleLinkedListNode: + """Node for LRU Cache""" - >>> DoubleLinkedListNode(1,1) - Node: key: 1, val: 1, has next: False, has prev: False - """ + __slots__ = ("key", "next", "prev", "val") - def __init__(self, key: T | None, val: U | None): + def __init__(self, key: Any, val: Any) -> None: self.key = key self.val = val - self.next: DoubleLinkedListNode[T, U] | None = None - self.prev: DoubleLinkedListNode[T, U] | None = None + self.next: DoubleLinkedListNode | None = None + self.prev: DoubleLinkedListNode | None = None def __repr__(self) -> str: - return ( - f"Node: key: {self.key}, val: {self.val}, " - f"has next: {bool(self.next)}, has prev: {bool(self.prev)}" - ) - - -class DoubleLinkedList(Generic[T, U]): - """ - Double Linked List built specifically for LRU Cache - - >>> dll: DoubleLinkedList = DoubleLinkedList() - >>> dll - DoubleLinkedList, - Node: key: None, val: None, has next: True, has prev: False, - Node: key: None, val: None, has next: False, has prev: True - - >>> first_node = DoubleLinkedListNode(1,10) - >>> first_node - Node: key: 1, val: 10, has next: False, has prev: False - - - >>> dll.add(first_node) - >>> dll - DoubleLinkedList, - Node: key: None, val: None, has next: True, has prev: False, - Node: key: 1, val: 10, has next: True, has prev: True, - Node: key: None, val: None, has next: False, has prev: True - - >>> # node is mutated - >>> first_node - Node: key: 1, val: 10, has next: True, has prev: True - - >>> second_node = DoubleLinkedListNode(2,20) - >>> second_node - Node: key: 2, val: 20, has next: False, has prev: False - - >>> dll.add(second_node) - >>> dll - DoubleLinkedList, - Node: key: None, val: None, has next: True, has prev: False, - Node: key: 1, val: 10, has next: True, has prev: True, - Node: key: 2, val: 20, has next: True, has prev: True, - Node: key: None, val: None, has next: False, has prev: True - - >>> removed_node = dll.remove(first_node) - >>> assert removed_node == first_node - >>> dll - DoubleLinkedList, - Node: key: None, val: None, has next: True, has prev: False, - Node: key: 2, val: 20, has next: True, has prev: True, - Node: key: None, val: None, has next: False, has prev: True + return f"Node(key={self.key}, val={self.val})" - >>> # Attempt to remove node not on list - >>> removed_node = dll.remove(first_node) - >>> removed_node is None - True - - >>> # Attempt to remove head or rear - >>> dll.head - Node: key: None, val: None, has next: True, has prev: False - >>> dll.remove(dll.head) is None - True - - >>> # Attempt to remove head or rear - >>> dll.rear - Node: key: None, val: None, has next: False, has prev: True - >>> dll.remove(dll.rear) is None - True - - - """ +class DoubleLinkedList: + """Double Linked List for LRU Cache""" def __init__(self) -> None: - self.head: DoubleLinkedListNode[T, U] = DoubleLinkedListNode(None, None) - self.rear: DoubleLinkedListNode[T, U] = DoubleLinkedListNode(None, None) - self.head.next, self.rear.prev = self.rear, self.head + # Create sentinel nodes + self.head: DoubleLinkedListNode = DoubleLinkedListNode(None, None) + self.rear: DoubleLinkedListNode = DoubleLinkedListNode(None, None) + # Link sentinel nodes together + self.head.next = self.rear + self.rear.prev = self.head def __repr__(self) -> str: - rep = ["DoubleLinkedList"] - node = self.head - while node.next is not None: - rep.append(str(node)) - node = node.next - rep.append(str(self.rear)) - return ",\n ".join(rep) - - def add(self, node: DoubleLinkedListNode[T, U]) -> None: - """ - Adds the given node to the end of the list (before rear) - """ - - previous = self.rear.prev - - # All nodes other than self.head are guaranteed to have non-None previous - assert previous is not None - - previous.next = node - node.prev = previous + nodes = [] + current = self.head + while current: + nodes.append(repr(current)) + current = current.next + return f"LinkedList({nodes})" + + def add(self, node: DoubleLinkedListNode) -> None: + """Add node before rear""" + prev = self.rear.prev + if prev is None: + return + + # Insert node between prev and rear + prev.next = node + node.prev = prev self.rear.prev = node node.next = self.rear - def remove( - self, node: DoubleLinkedListNode[T, U] - ) -> DoubleLinkedListNode[T, U] | None: - """ - Removes and returns the given node from the list - - Returns None if node.prev or node.next is None - """ - + def remove(self, node: DoubleLinkedListNode) -> DoubleLinkedListNode | None: + """Remove node from list""" if node.prev is None or node.next is None: return None + # Bypass node node.prev.next = node.next node.next.prev = node.prev + + # Clear node references node.prev = None node.next = None return node -class LRUCache(Generic[T, U]): - """ - LRU Cache to store a given capacity of data. Can be used as a stand-alone object - or as a function decorator. - - >>> cache = LRUCache(2) - - >>> cache.put(1, 1) - >>> cache.put(2, 2) - >>> cache.get(1) - 1 - - >>> cache.list - DoubleLinkedList, - Node: key: None, val: None, has next: True, has prev: False, - Node: key: 2, val: 2, has next: True, has prev: True, - Node: key: 1, val: 1, has next: True, has prev: True, - Node: key: None, val: None, has next: False, has prev: True - - >>> cache.cache # doctest: +NORMALIZE_WHITESPACE - {1: Node: key: 1, val: 1, has next: True, has prev: True, \ - 2: Node: key: 2, val: 2, has next: True, has prev: True} - - >>> cache.put(3, 3) +class LRUCache: + """LRU Cache implementation""" - >>> cache.list - DoubleLinkedList, - Node: key: None, val: None, has next: True, has prev: False, - Node: key: 1, val: 1, has next: True, has prev: True, - Node: key: 3, val: 3, has next: True, has prev: True, - Node: key: None, val: None, has next: False, has prev: True - - >>> cache.cache # doctest: +NORMALIZE_WHITESPACE - {1: Node: key: 1, val: 1, has next: True, has prev: True, \ - 3: Node: key: 3, val: 3, has next: True, has prev: True} - - >>> cache.get(2) is None - True - - >>> cache.put(4, 4) - - >>> cache.get(1) is None - True - - >>> cache.get(3) - 3 - - >>> cache.get(4) - 4 - - >>> cache - CacheInfo(hits=3, misses=2, capacity=2, current size=2) - - >>> @LRUCache.decorator(100) - ... def fib(num): - ... if num in (1, 2): - ... return 1 - ... return fib(num - 1) + fib(num - 2) - - >>> for i in range(1, 100): - ... res = fib(i) - - >>> fib.cache_info() - CacheInfo(hits=194, misses=99, capacity=100, current size=99) - """ - - def __init__(self, capacity: int): - self.list: DoubleLinkedList[T, U] = DoubleLinkedList() + def __init__(self, capacity: int) -> None: + self.list = DoubleLinkedList() self.capacity = capacity - self.num_keys = 0 + self.size = 0 self.hits = 0 - self.miss = 0 - self.cache: dict[T, DoubleLinkedListNode[T, U]] = {} + self.misses = 0 + self.cache: dict[Any, DoubleLinkedListNode] = {} def __repr__(self) -> str: - """ - Return the details for the cache instance - [hits, misses, capacity, current_size] - """ - return ( - f"CacheInfo(hits={self.hits}, misses={self.miss}, " - f"capacity={self.capacity}, current size={self.num_keys})" + f"Cache(hits={self.hits}, misses={self.misses}, " + f"cap={self.capacity}, size={self.size})" ) - def __contains__(self, key: T) -> bool: - """ - >>> cache = LRUCache(1) + def get(self, key: Any) -> Any | None: + """Get value for key""" + if key in self.cache: + self.hits += 1 + node = self.cache[key] + if self.list.remove(node): + self.list.add(node) + return node.val + self.misses += 1 + return None - >>> 1 in cache - False + def put(self, key: Any, value: Any) -> None: + """Set value for key""" + if key in self.cache: + # Update existing node + node = self.cache[key] + if self.list.remove(node): + node.val = value + self.list.add(node) + return + + # Evict LRU item if at capacity + if self.size >= self.capacity: + first_node: DoubleLinkedListNode | None = self.list.head.next + if ( + first_node is not None + and first_node.key is not None + and first_node != self.list.rear + and self.list.remove(first_node) + ): + del self.cache[first_node.key] + self.size -= 1 - >>> cache.put(1, 1) + # Add new node + new_node = DoubleLinkedListNode(key, value) + self.cache[key] = new_node + self.list.add(new_node) + self.size += 1 - >>> 1 in cache - True - """ + def cache_info(self) -> dict[str, Any]: + """Get cache statistics""" + return { + "hits": self.hits, + "misses": self.misses, + "capacity": self.capacity, + "size": self.size, + } - return key in self.cache - def get(self, key: T) -> U | None: - """ - Returns the value for the input key and updates the Double Linked List. - Returns None if key is not present in cache - """ - # Note: pythonic interface would throw KeyError rather than return None +def lru_cache(maxsize: int = 128) -> Callable[[Callable[P, R]], Callable[P, R]]: + """LRU Cache decorator""" - if key in self.cache: - self.hits += 1 - value_node: DoubleLinkedListNode[T, U] = self.cache[key] - node = self.list.remove(self.cache[key]) - assert node == value_node + def decorator(func: Callable[P, R]) -> Callable[P, R]: + cache = LRUCache(maxsize) - # node is guaranteed not None because it is in self.cache - assert node is not None - self.list.add(node) - return node.val - self.miss += 1 - return None + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + # Create normalized cache key + key = (args, tuple(sorted(kwargs.items()))) - def put(self, key: T, value: U) -> None: - """ - Sets the value for the input key and updates the Double Linked List - """ + # Try to get cached result + if (cached := cache.get(key)) is not None: + return cast(R, cached) - if key not in self.cache: - if self.num_keys >= self.capacity: - # delete first node (oldest) when over capacity - first_node = self.list.head.next + # Compute and cache result + result = func(*args, **kwargs) + cache.put(key, result) + return result - # guaranteed to have a non-None first node when num_keys > 0 - # explain to type checker via assertions - assert first_node is not None - assert first_node.key is not None - assert ( - self.list.remove(first_node) is not None - ) # node guaranteed to be in list assert node.key is not None + # Attach cache info method + wrapper.cache_info = cache.cache_info # type: ignore[attr-defined] + return wrapper - del self.cache[first_node.key] - self.num_keys -= 1 - self.cache[key] = DoubleLinkedListNode(key, value) - self.list.add(self.cache[key]) - self.num_keys += 1 - - else: - # bump node to the end of the list, update value - node = self.list.remove(self.cache[key]) - assert node is not None # node guaranteed to be in list - node.val = value - self.list.add(node) - - @classmethod - def decorator( - cls, size: int = 128 - ) -> Callable[[Callable[[T], U]], Callable[..., U]]: - """ - Decorator version of LRU Cache - - Decorated function must be function of T -> U - """ - - def cache_decorator_inner(func: Callable[[T], U]) -> Callable[..., U]: - # variable to map the decorator functions to their respective instance - decorator_function_to_instance_map: dict[ - Callable[[T], U], LRUCache[T, U] - ] = {} - - def cache_decorator_wrapper(*args: T) -> U: - if func not in decorator_function_to_instance_map: - decorator_function_to_instance_map[func] = LRUCache(size) - - result = decorator_function_to_instance_map[func].get(args[0]) - if result is None: - result = func(*args) - decorator_function_to_instance_map[func].put(args[0], result) - return result - - def cache_info() -> LRUCache[T, U]: - return decorator_function_to_instance_map[func] - - setattr(cache_decorator_wrapper, "cache_info", cache_info) # noqa: B010 - - return cache_decorator_wrapper - - return cache_decorator_inner + return decorator if __name__ == "__main__":