Skip to content

Added DP Solution for Optimal BST Problem #1740

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Feb 11, 2020
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions dynamic_programming/optimal_bst.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#!/usr/bin/python3

# This Python program provides an O(n^2) dynamic programming solution
# to the optimal BST problem.
#
# The goal of the optimal BST problem is to build a low-cost BST for a
# given set of nodes, each with its own key and frequency. The frequency
# of a node is defined as how many times the node is searched for.
# Low-cost BSTs have a faster overall search time than other BSTs,
# because nodes with high frequencies are placed near the root of the
# tree while nodes with low frequencies are placed near the leaves,
# which reduces the total search time.

import sys

from random import randint


class Node:
    """A BST node described by its search key and its search frequency."""

    def __init__(self, key, freq):
        # key:  value the tree is ordered by
        # freq: how often this key is searched for
        self.key, self.freq = key, freq


def print_BST(root, key, i, j, parent, is_left):
    """
    Print the BST stored in a root table, one line per node, in pre-order.

    root:    2D table where root[a][b] is the index of the subtree root
             chosen for the key interval [a, b]
    key:     list of keys, indexed like the root table
    i, j:    inclusive interval of key indices to print
    parent:  key of the parent node, or -1 when the interval is the whole tree
    is_left: True when the interval is the left subtree of ``parent``
    """
    # Explicit stack instead of recursion; the left interval is pushed last
    # so that it is printed before the right one.
    pending = [(i, j, parent, is_left)]

    while pending:
        low, high, above, on_left = pending.pop()

        # Empty or out-of-table intervals describe no node.
        if low > high or low < 0 or high > len(root) - 1:
            continue

        split = root[low][high]
        label = key[split]

        if above == -1:
            # the root does not have a parent
            print(f"{label} is the root of the BST.")
        elif on_left:
            print(f"{label} is the left child of key {above}.")
        else:
            print(f"{label} is the right child of key {above}.")

        pending.append((split + 1, high, label, False))  # right subtree
        pending.append((low, split - 1, label, True))  # left subtree


def find_optimal_BST(nodes):
    """
    Calculate and print the optimal BST for ``nodes``.

    Precondition: ``nodes`` is sorted by key in increasing order, and every
    node exposes ``key`` and ``freq`` attributes.

    The interval dynamic programming algorithm (implemented from the CLRS
    book) is combined with Knuth's optimization and runs in O(n^2) time.
    The cost and the tree are printed; nothing is returned. An empty
    ``nodes`` sequence prints a cost of 0 and no tree.

    >>> nodes = [Node(12, 8), Node(10, 34), Node(20, 50),
    ...          Node(42, 3), Node(25, 40), Node(37, 30)]
    >>> nodes.sort(key=lambda node: node.key)
    >>> find_optimal_BST(nodes)
    The cost of optimal BST is 324.
    20 is the root of the BST.
    10 is the left child of key 20.
    12 is the right child of key 10.
    25 is the right child of key 20.
    37 is the right child of key 25.
    42 is the right child of key 37.
    """
    n = len(nodes)

    if n == 0:
        # There is no dp[0][n - 1] entry to report for an empty tree.
        print("The cost of optimal BST is 0.")
        return

    key = [node.key for node in nodes]
    freq = [node.freq for node in nodes]

    # dp[i][j] is the minimal cost of a BST built from keys i..j inclusive;
    # for a single key the cost equals the frequency of that key.
    dp = [[freq[i] if i == j else 0 for j in range(n)] for i in range(n)]
    # freq_sum[i][j] is the sum of the frequencies of keys i..j inclusive.
    freq_sum = [[freq[i] if i == j else 0 for j in range(n)] for i in range(n)]
    # root[i][j] is the index of the root chosen for keys i..j; it is used
    # to reconstruct the tree afterwards.
    root = [[i if i == j else 0 for j in range(n)] for i in range(n)]

    for length in range(2, n + 1):  # length of the key interval
        for i in range(n - length + 1):
            j = i + length - 1

            # A float infinity compares greater than any int cost, however
            # large the frequencies are.
            dp[i][j] = float("inf")
            # (sum over [i..j]) = (sum over [i..j - 1]) + freq[j]
            freq_sum[i][j] = freq_sum[i][j - 1] + freq[j]

            # Knuth's optimization: only roots between the roots of the two
            # sub-intervals are tried.
            # Loop without optimization: for r in range(i, j + 1):
            for r in range(root[i][j - 1], root[i + 1][j] + 1):
                left = dp[i][r - 1] if r != i else 0  # optimal left subtree cost
                right = dp[r + 1][j] if r != j else 0  # optimal right subtree cost
                cost = left + freq_sum[i][j] + right

                if dp[i][j] > cost:
                    dp[i][j] = cost
                    root[i][j] = r

    print(f"The cost of optimal BST is {dp[0][n - 1]}.")
    print_BST(root, key, 0, n - 1, -1, False)


def main():
    """Build a sample set of ten nodes with random frequencies and solve it."""
    # Keys run from 10 down to 1, each with a random frequency in [1, 50].
    nodes = []
    for key in range(10, 0, -1):
        nodes.append(Node(key, randint(1, 50)))

    # find_optimal_BST requires the nodes in increasing key order; sorting
    # the Node objects keeps every frequency attached to its key.
    nodes.sort(key=lambda node: node.key)

    find_optimal_BST(nodes)


if __name__ == "__main__":
    # Script entry point: run the random sample when executed directly.
    # The doctest in find_optimal_BST is currently disabled; uncomment the
    # two lines below to run it before the sample.
    # import doctest
    # doctest.testmod()
    main()