Skip to content

Commit d51a07e

Browse files
committed
Add additional checks to prevent IndexError
1 parent e12456d commit d51a07e

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

dynamic_programming/optimal_binary_search_tree.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def find_optimal_binary_search_tree(nodes):
102102
# This 2D array stores the overall tree cost (which's as minimized as possible);
103103
# for a single key, cost is equal to frequency of the key.
104104
dp = [[freqs[i] if i == j else 0 for j in range(n)] for i in range(n)]
105-
# sum[i][j] stores the sum of key frequencies between i and j inclusive in nodes
105+
# total[i][j] stores the sum of key frequencies between i and j inclusive in nodes
106106
# array
107107
total = [[freqs[i] if i == j else 0 for j in range(n)] for i in range(n)]
108108
# stores tree roots that will be used later for constructing binary search tree
@@ -115,16 +115,17 @@ def find_optimal_binary_search_tree(nodes):
115115
dp[i][j] = sys.maxsize # set the value to "infinity"
116116
total[i][j] = total[i][j - 1] + freqs[j]
117117

118-
# Apply Knuth's optimization with boundary checking
119-
r_start = max(i, root[i][j - 1] if j > i else i)
120-
r_end = min(j, root[i + 1][j] if i < j else j)
118+
# Apply Knuth's optimization with safe boundary handling
119+
r_start = root[i][j - 1] if j > i else i
120+
r_end = root[i + 1][j] if i < j else j
121121

122-
if r_start > r_end:
123-
r_start, r_end = i, j # fall back to the full range
122+
# Ensure r_start and r_end are within valid bounds
123+
r_start = max(i, min(r_start, j))
124+
r_end = min(j, max(r_end, i))
124125

125126
for r in range(r_start, r_end + 1):
126-
left = dp[i][r - 1] if r != i else 0 # optimal cost for left subtree
127-
right = dp[r + 1][j] if r != j else 0 # optimal cost for right subtree
127+
left = dp[i][r - 1] if r > i else 0 # optimal cost for left subtree
128+
right = dp[r + 1][j] if r < j else 0 # optimal cost for right subtree
128129
cost = left + total[i][j] + right
129130

130131
if dp[i][j] > cost:

0 commit comments

Comments
 (0)