diff --git a/dynamic_programming/optimal_binary_search_tree.py b/dynamic_programming/optimal_binary_search_tree.py index b4f1181ac11c..cc2eac3c53b5 100644 --- a/dynamic_programming/optimal_binary_search_tree.py +++ b/dynamic_programming/optimal_binary_search_tree.py @@ -102,7 +102,7 @@ def find_optimal_binary_search_tree(nodes): # This 2D array stores the overall tree cost (which's as minimized as possible); # for a single key, cost is equal to frequency of the key. dp = [[freqs[i] if i == j else 0 for j in range(n)] for i in range(n)] - # sum[i][j] stores the sum of key frequencies between i and j inclusive in nodes + # total[i][j] stores the sum of key frequencies between i and j inclusive in nodes # array total = [[freqs[i] if i == j else 0 for j in range(n)] for i in range(n)] # stores tree roots that will be used later for constructing binary search tree @@ -115,11 +115,17 @@ def find_optimal_binary_search_tree(nodes): dp[i][j] = sys.maxsize # set the value to "infinity" total[i][j] = total[i][j - 1] + freqs[j] - # Apply Knuth's optimization - # Loop without optimization: for r in range(i, j + 1): - for r in range(root[i][j - 1], root[i + 1][j] + 1): # r is a temporal root - left = dp[i][r - 1] if r != i else 0 # optimal cost for left subtree - right = dp[r + 1][j] if r != j else 0 # optimal cost for right subtree + # Apply Knuth's optimization with safe boundary handling + r_start = root[i][j - 1] if j > i else i + r_end = root[i + 1][j] if i < j else j + + # Ensure r_start and r_end are within valid bounds + r_start = max(i, min(r_start, j)) + r_end = min(j, max(r_end, i)) + + for r in range(r_start, r_end + 1): + left = dp[i][r - 1] if r > i else 0 # optimal cost for left subtree + right = dp[r + 1][j] if r < j else 0 # optimal cost for right subtree cost = left + total[i][j] + right if dp[i][j] > cost: