In [2]:
# Guide to tackling a Dynamic Programming problem with the help of code

# Step 1: generate some test cases
import random
def problem_instance(seed=None):
    """Build a random optimal-BST test case over the lowercase alphabet.

    Args:
        seed: Optional seed. When given, a private ``random.Random(seed)``
            generator is used so the same instance can be regenerated; when
            ``None`` (the default) the module-level ``random`` generator is
            used, exactly as before.

    Returns:
        A pair ``(vals, freq)``: ``vals`` is the string of 26 keys in sorted
        order and ``freq`` is a list with one non-negative weight per key,
        normalized so the weights sum to 1 (up to rounding).
    """
    vals = 'abcdefghijklmnopqrstuvwxyz'
    # Fall back to the shared module-level generator when no seed is requested.
    rng = random if seed is None else random.Random(seed)
    freq = [rng.random() for _ in vals]
    tot = sum(freq)
    freq = [f / tot for f in freq]  # normalize so the weights form a distribution
    return vals, freq
# Module-level test instance shared by the cells below:
# vals holds the keys, freq the normalized weight of each key.
vals, freq = problem_instance()
n = len(vals)  # number of keys
In [3]:
# A binary search tree is either:
# - None (the empty tree)
# - (val, left, right) a value, and left/right children

def balanced_tree(arr):
    """Build a tree from the sequence ``arr`` by rooting it at the middle element.

    The element at index ``len(arr) // 2`` becomes the root; everything before
    it forms the left subtree and everything after it the right subtree,
    recursively. Returns ``None`` for an empty sequence, otherwise a
    ``(val, left, right)`` tuple.
    """
    if not len(arr):
        return None
    mid = len(arr) // 2
    before, after = arr[:mid], arr[mid + 1:]
    return (arr[mid], balanced_tree(before), balanced_tree(after))

def bst_search(q, T):
    """Report whether the key ``q`` occurs in the binary search tree ``T``.

    ``T`` is either ``None`` (empty) or a ``(val, left, right)`` tuple.
    Walks down from the root, going left when ``q`` is smaller than the
    current value and right when it is larger. Returns ``True`` on a match
    and ``False`` once an empty subtree is reached.
    """
    node = T
    while node is not None:
        key, smaller, larger = node
        if q == key:
            return True
        if q < key:
            node = smaller
        elif q > key:
            node = larger
        else:
            # q compared neither equal, less nor greater: no branch applies.
            return None
    return False
    
def preorder(T):
    """Return the values of tree ``T`` in preorder (value, then left, then right).

    ``T`` is either ``None`` (giving an empty list) or a ``(val, left, right)``
    tuple.
    """
    visited = []
    pending = [T]
    while pending:
        node = pending.pop()
        if node is None:
            continue
        key, left, right = node
        visited.append(key)
        # Push right before left so the left subtree is popped (visited) first.
        pending.append(right)
        pending.append(left)
    return visited

def check_bst(T):
    """Assert that every key of the module-level ``vals`` is found in ``T``.

    Raises ``AssertionError`` at the first key that ``bst_search`` does not
    report as present.
    """
    for key in vals:
        found = bst_search(key, T)
        assert found == True
        
# Baseline: build the balanced tree over all keys, show its preorder,
# and verify that every key can be found in it.
T = balanced_tree(vals)
print('Balanced tree:', preorder(T))
check_bst(T)
Balanced tree: ['n', 'g', 'd', 'b', 'a', 'c', 'f', 'e', 'k', 'i', 'h', 'j', 'm', 'l', 'u', 'r', 'p', 'o', 'q', 't', 's', 'x', 'w', 'v', 'z', 'y']
In [5]:
def bst_search_length(q, T):
    """Count the nodes visited while searching tree ``T`` for the key ``q``.

    An empty tree contributes 0. A match at the current node contributes 1.
    Otherwise the current node counts as 1 and the search continues in the
    left subtree when ``q`` is smaller, or the right subtree when it is larger.
    """
    if T is None:
        return 0
    key, smaller, larger = T
    if q == key:
        return 1
    if q < key:
        subtree = smaller
    elif q > key:
        subtree = larger
    else:
        # q compared neither equal, less nor greater: no branch applies.
        return None
    return 1 + bst_search_length(q, subtree)

def average_cost(T):
    """Frequency-weighted search length of tree ``T``.

    For every key in the module-level ``vals`` the number of nodes visited by
    ``bst_search_length`` is multiplied by that key's entry in ``freq``; the
    products are added up in key order and the total is returned.
    """
    total = 0
    for idx in range(len(vals)):
        total += freq[idx] * bst_search_length(vals[idx], T)
    return total

# Baseline figure: frequency-weighted search cost of the balanced tree T.
print('Avg cost of balanced tree:', average_cost(T))
Avg cost of balanced tree: 4.020400409220588
In [6]:
# Step 2: Write the recursive (and implicitly memoized) solution

# Optional: precompute a table of range frequencies.
# F[lo][hi] = freq[lo] + ... + freq[hi] (both ends inclusive); entries with hi < lo stay 0.
F = [[0] * (n+1) for _ in range(n)]
for i in range(n):
    for j in range(i, n):
        # The diagonal entry starts the row; each later entry extends the
        # previous partial sum by one more key.
        F[i][j] = freq[j] if j == i else F[i][j-1] + freq[j]
# The full range must add up to (approximately) 1 since freq is normalized.
assert abs(F[0][n-1] - 1.0) < 0.0001
# print('freq:', freq)
# print('F:', F)

class Memoize:
    """Callable wrapper that caches the results of ``func`` by positional arguments.

    Attributes are plain instance fields:
      tbl  -- dict mapping each argument tuple to the value ``func`` returned for it
      func -- the wrapped callable
    """

    def __init__(self, func):
        self.func = func
        self.tbl = {}

    def __call__(self, *args):
        # Compute and store on the first call with these arguments only.
        if args not in self.tbl:
            self.tbl[args] = self.func(*args)
        return self.tbl[args]

@Memoize
def optbst_recursive(lo, hi):
    """Best search tree over the keys vals[lo] .. vals[hi] (both ends inclusive).

    Every key in the range is tried as the root; the cost of a candidate is
    ``F[lo][hi] + cost(left range) + cost(right range)``, where ``F`` is the
    module-level range-frequency table. Results are cached by the ``Memoize``
    decorator, so each ``(lo, hi)`` pair is solved once.

    Args:
        lo: index of the first key in the range.
        hi: index of the last key in the range; ``lo > hi`` denotes an empty range.

    Returns:
        A pair ``(cost, tree)``. ``tree`` is ``None`` for an empty range and a
        ``(val, left, right)`` tuple otherwise. On ties the smallest root
        index wins, because only a strictly smaller cost replaces the best.
    """
    if lo > hi:
        return 0, None # Empty tree
    # Start from +infinity rather than a fixed large number so the search stays
    # correct even when the weights are not normalized and costs are large.
    bestavg = float('inf')
    bestT = None
    for k in range(lo, hi + 1):
        l,ltree = optbst_recursive(lo,k-1)
        r,rtree = optbst_recursive(k+1,hi)
        s = F[lo][hi] + l + r
        if s < bestavg:
            bestavg = s
            bestT = (vals[k], ltree, rtree)
    return bestavg, bestT

# Solve the full key range with the memoized recursion, then show the
# resulting tree's preorder and its cost as measured by average_cost.
bestAvg, bestT = optbst_recursive(0, n-1)
print('bestT:', preorder(bestT))
print('Avg cost of bestT:', average_cost(bestT))
bestT: ['j', 'e', 'b', 'a', 'c', 'd', 'g', 'f', 'i', 'h', 'u', 'p', 'm', 'l', 'k', 'n', 'o', 'r', 'q', 's', 't', 'x', 'v', 'w', 'z', 'y']
Avg cost of bestT: 3.704721373772423
In [8]:
# Step 3: Work out the (explicit) dynamic programming solution
def optbst_dynprog():
    """Bottom-up (tabulated) computation of the optimal BST cost.

    Uses the module-level ``n`` (number of keys) and ``F`` (range-frequency
    table). ``OPT[lo][hi]`` holds the best cost for the keys ``lo..hi``
    inclusive. Each entry is cross-checked against ``optbst_recursive`` as it
    is filled in.

    Returns:
        The optimal cost for the full range, ``OPT[0][n-1]``, or ``0`` when
        there are no keys (the cost the recursive solution gives an empty tree).
    """
    if n == 0:
        return 0
    OPT = [[None] * (n+1) for _ in range(n)]

    def compute_cost(lo, hi):
        """Fill OPT[lo][hi] from the already-computed smaller ranges."""
        # +infinity instead of a fixed large number: correct for any weights.
        bestavg = float('inf')
        for k in range(lo, hi + 1):
            l = OPT[lo][k-1] if k > lo else 0
            r = OPT[k+1][hi] if k < hi else 0
            s = F[lo][hi] + l + r
            if s < bestavg:
                bestavg = s
        # Store once, after every candidate root has been examined.
        OPT[lo][hi] = bestavg

    # Rows are filled from the last key upwards so that OPT[k+1][hi] (a later
    # row) and OPT[lo][k-1] (earlier in the same row) are ready when needed.
    for lo in range(n - 1, -1, -1):
        for hi in range(lo, n):
            compute_cost(lo, hi)
            assert OPT[lo][hi] == optbst_recursive(lo, hi)[0]
    return OPT[0][n-1]
# The tabulated cost must equal the cost found by the memoized recursion.
optval = optbst_dynprog()
assert(optval == bestAvg)
In [9]:
# Step 4: Work out a dynamic programming solution that records the solution
def optbst_dynprog2():
    """Bottom-up optimal BST that also records which root achieved each optimum.

    Uses the module-level ``n`` and ``F``. As a side effect the table is
    published as the module-level ``OPT``; each filled entry ``OPT[lo][hi]``
    is a pair ``(cost, k)`` where ``k`` is the index of the best root for the
    keys ``lo..hi`` inclusive. Each cost is cross-checked against
    ``optbst_recursive`` as it is filled in.

    Returns:
        ``OPT[0][n-1]``, i.e. ``(optimal cost, root index)`` for the full
        range, or ``(0, -1)`` when there are no keys.
    """
    global OPT
    OPT = [[None] * (n+1) for _ in range(n)]
    if n == 0:
        return 0, -1

    def compute_cost(lo, hi):
        """Fill OPT[lo][hi] with the best (cost, root index) for keys lo..hi."""
        # +infinity instead of a fixed large number: correct for any weights.
        bestavg = float('inf')
        bestk = -1  # Keep track of the index k with the best score
        for k in range(lo, hi + 1):
            # Only the cost component of the sub-results is needed here.
            l = OPT[lo][k-1][0] if k > lo else 0
            r = OPT[k+1][hi][0] if k < hi else 0
            s = F[lo][hi] + l + r
            if s < bestavg:
                bestavg = s
                bestk = k    # Store the k that gave the best score
        OPT[lo][hi] = bestavg, bestk

    # Rows are filled from the last key upwards so that OPT[k+1][hi] (a later
    # row) and OPT[lo][k-1] (earlier in the same row) are ready when needed.
    for lo in range(n - 1, -1, -1):
        for hi in range(lo, n):
            compute_cost(lo, hi)
            assert OPT[lo][hi][0] == optbst_recursive(lo, hi)[0]
    return OPT[0][n-1]

# Check the iterative solution and tree itself
# NOTE(review): optbst_dynprog2() returns OPT[0][n-1], a (cost, root index k)
# pair, so T3 holds the root index of the optimal tree rather than a tree.
best3, T3 = optbst_dynprog2()
assert(best3 == bestAvg)
# NOTE(review): this prints the balanced tree T built in an earlier cell, not
# the optimal tree -- presumably for comparison with the preorder printed
# below; confirm whether bestT was intended.
print(preorder(T))

# Go find the solution in O(n) time
def preorder2(lo, hi, OPT):
    """Read the preorder of the optimal tree for keys lo..hi out of the table.

    Args:
        lo: index of the first key in the range.
        hi: index of the last key in the range; ``lo > hi`` denotes an empty range.
        OPT: table whose entry ``OPT[a][b]`` is a pair whose second element is
            the root index chosen for the keys ``a..b``.

    Returns:
        The list of key values (taken from the module-level ``vals``) in
        preorder: root first, then the left range, then the right range.
    """
    # Appending to one output list with an explicit stack visits every key
    # once; concatenating the lists of recursive calls would re-copy them at
    # every level and is quadratic on a skewed tree.
    order = []
    pending = [(lo, hi)]
    while pending:
        a, b = pending.pop()
        if a > b:
            continue
        _, k = OPT[a][b]
        order.append(vals[k])
        # Push the right range first so the left range is popped (emitted) first.
        pending.append((k + 1, b))
        pending.append((a, k - 1))
    return order
# Reconstruct the optimal tree's preorder from the module-level OPT table and
# check that it matches the tree returned by the recursive solution.
print(preorder2(0, n-1, OPT))
assert(preorder2(0, n-1, OPT) == preorder(bestT))
['n', 'g', 'd', 'b', 'a', 'c', 'f', 'e', 'k', 'i', 'h', 'j', 'm', 'l', 'u', 'r', 'p', 'o', 'q', 't', 's', 'x', 'w', 'v', 'z', 'y']
['j', 'e', 'b', 'a', 'c', 'd', 'g', 'f', 'i', 'h', 'u', 'p', 'm', 'l', 'k', 'n', 'o', 'r', 'q', 's', 't', 'x', 'v', 'w', 'z', 'y']