diff --git a/data_structures/linked_list/doubly_linked_list_two.py b/data_structures/linked_list/doubly_linked_list_two.py index e993cc5a20af..3d3bfb0cde30 100644 --- a/data_structures/linked_list/doubly_linked_list_two.py +++ b/data_structures/linked_list/doubly_linked_list_two.py @@ -9,25 +9,19 @@ Delete operation is more efficient """ +from dataclasses import dataclass +from typing import Self + +@dataclass class Node: - def __init__(self, data: int, previous=None, next_node=None): - self.data = data - self.previous = previous - self.next = next_node + data: int + previous: Self | None = None + next: Self | None = None def __str__(self) -> str: return f"{self.data}" - def get_data(self) -> int: - return self.data - - def get_next(self): - return self.next - - def get_previous(self): - return self.previous - class LinkedListIterator: def __init__(self, head): @@ -40,30 +34,30 @@ def __next__(self): if not self.current: raise StopIteration else: - value = self.current.get_data() - self.current = self.current.get_next() + value = self.current.data + self.current = self.current.next return value +@dataclass class LinkedList: - def __init__(self): - self.head = None # First node in list - self.tail = None # Last node in list + head: Node | None = None # First node in list + tail: Node | None = None # Last node in list def __str__(self): current = self.head nodes = [] while current is not None: - nodes.append(current.get_data()) - current = current.get_next() + nodes.append(current.data) + current = current.next return " ".join(str(node) for node in nodes) def __contains__(self, value: int): current = self.head while current: - if current.get_data() == value: + if current.data == value: return True - current = current.get_next() + current = current.next return False def __iter__(self): @@ -71,12 +65,12 @@ def __iter__(self): def get_head_data(self): if self.head: - return self.head.get_data() + return self.head.data return None def get_tail_data(self): if self.tail: - return self.tail.get_data() + return self.tail.data return None def set_head(self, node: Node) -> None: @@ -103,18 +97,20 @@ def insert_before_node(self, node: Node, node_to_insert: Node) -> None: node_to_insert.next = node node_to_insert.previous = node.previous - if node.get_previous() is None: + if node.previous is None: self.head = node_to_insert else: node.previous.next = node_to_insert node.previous = node_to_insert - def insert_after_node(self, node: Node, node_to_insert: Node) -> None: + def insert_after_node(self, node: Node | None, node_to_insert: Node) -> None: + assert node is not None + node_to_insert.previous = node node_to_insert.next = node.next - if node.get_next() is None: + if node.next is None: self.tail = node_to_insert else: node.next.previous = node_to_insert @@ -136,27 +132,27 @@ def insert_at_position(self, position: int, value: int) -> None: def get_node(self, item: int) -> Node: node = self.head while node: - if node.get_data() == item: + if node.data == item: return node - node = node.get_next() + node = node.next raise Exception("Node not found") def delete_value(self, value): if (node := self.get_node(value)) is not None: if node == self.head: - self.head = self.head.get_next() + self.head = self.head.next if node == self.tail: - self.tail = self.tail.get_previous() + self.tail = self.tail.previous self.remove_node_pointers(node) @staticmethod def remove_node_pointers(node: Node) -> None: - if node.get_next(): + if node.next: node.next.previous = node.previous - if node.get_previous(): + if node.previous: node.previous.next = node.next node.next = None