diff --git a/data_structures/binary_tree/lazy_segment_tree.py b/data_structures/binary_tree/lazy_segment_tree.py index 461996b87c26..e247249bf232 100644 --- a/data_structures/binary_tree/lazy_segment_tree.py +++ b/data_structures/binary_tree/lazy_segment_tree.py @@ -1,84 +1,93 @@ import math +from typing import List class SegmentTree: - def __init__(self, N): + def __init__(self, N: int) -> None: self.N = N - self.st = [ + self.st: List[int] = [ 0 for i in range(0, 4 * N) ] # approximate the overall size of segment tree with array N - self.lazy = [0 for i in range(0, 4 * N)] # create array to store lazy update - self.flag = [0 for i in range(0, 4 * N)] # flag for lazy update + self.lazy: List[int] = [ + 0 for i in range(0, 4 * N) + ] # create array to store lazy update + self.flag: List[int] = [0 for i in range(0, 4 * N)] # flag for lazy update - def left(self, idx): + def left(self, idx: int) -> int: return idx * 2 - def right(self, idx): + def right(self, idx: int) -> int: return idx * 2 + 1 - def build(self, idx, l, r, A): # noqa: E741 - if l == r: # noqa: E741 - self.st[idx] = A[l - 1] + def build( + self, idx: int, left_element: int, right_element: int, A: List[int] + ) -> None: + if left_element == right_element: + self.st[idx] = A[left_element - 1] else: - mid = (l + r) // 2 - self.build(self.left(idx), l, mid, A) - self.build(self.right(idx), mid + 1, r, A) + mid = (left_element + right_element) // 2 + self.build(self.left(idx), left_element, mid, A) + self.build(self.right(idx), mid + 1, right_element, A) self.st[idx] = max(self.st[self.left(idx)], self.st[self.right(idx)]) # update with O(lg N) (Normal segment tree without lazy update will take O(Nlg N) # for each update) - def update(self, idx, l, r, a, b, val): # noqa: E741 + def update( + self, idx: int, left_element: int, right_element: int, a: int, b: int, val: int + ) -> bool: """ update(1, 1, N, a, b, v) for update val v to [a,b] """ if self.flag[idx] is True: self.st[idx] = self.lazy[idx] self.flag[idx] = False - if l != r: # noqa: E741 + if left_element != right_element: self.lazy[self.left(idx)] = self.lazy[idx] self.lazy[self.right(idx)] = self.lazy[idx] self.flag[self.left(idx)] = True self.flag[self.right(idx)] = True - if r < a or l > b: + if right_element < a or left_element > b: return True - if l >= a and r <= b: # noqa: E741 + if left_element >= a and right_element <= b: self.st[idx] = val - if l != r: # noqa: E741 + if left_element != right_element: self.lazy[self.left(idx)] = val self.lazy[self.right(idx)] = val self.flag[self.left(idx)] = True self.flag[self.right(idx)] = True return True - mid = (l + r) // 2 - self.update(self.left(idx), l, mid, a, b, val) - self.update(self.right(idx), mid + 1, r, a, b, val) + mid = (left_element + right_element) // 2 + self.update(self.left(idx), left_element, mid, a, b, val) + self.update(self.right(idx), mid + 1, right_element, a, b, val) self.st[idx] = max(self.st[self.left(idx)], self.st[self.right(idx)]) return True # query with O(lg N) - def query(self, idx, l, r, a, b): # noqa: E741 + def query( + self, idx: int, left_element: int, right_element: int, a: int, b: int + ) -> int: """ query(1, 1, N, a, b) for query max of [a,b] """ if self.flag[idx] is True: self.st[idx] = self.lazy[idx] self.flag[idx] = False - if l != r: # noqa: E741 + if left_element != right_element: self.lazy[self.left(idx)] = self.lazy[idx] self.lazy[self.right(idx)] = self.lazy[idx] self.flag[self.left(idx)] = True self.flag[self.right(idx)] = True - if r < a or l > b: + if right_element < a or left_element > b: return -math.inf - if l >= a and r <= b: # noqa: E741 + if left_element >= a and right_element <= b: return self.st[idx] - mid = (l + r) // 2 - q1 = self.query(self.left(idx), l, mid, a, b) - q2 = self.query(self.right(idx), mid + 1, r, a, b) + mid = (left_element + right_element) // 2 + q1 = self.query(self.left(idx), left_element, mid, a, b) + q2 = self.query(self.right(idx), mid + 1, right_element, a, b) return max(q1, q2) - def showData(self): + def showData(self) -> None: showList = [] for i in range(1, N + 1): showList += [self.query(1, 1, self.N, i, i)]