Skip to content

Commit 9a8ddb2

Browse files
maxwell-aladagostokhos
authored andcommitted
Fully refactored the rod cutting module. (TheAlgorithms#1169)
* changing typo * fully refactored the rod-cutting module * more documentations * rewording
1 parent 491e8ea commit 9a8ddb2

File tree

1 file changed

+183
-47
lines changed

1 file changed

+183
-47
lines changed

dynamic_programming/rod_cutting.py

+183-47
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,193 @@
1-
from typing import List
1+
"""
2+
This module provides two implementations for the rod-cutting problem:
3+
1. A naive recursive implementation which has an exponential runtime
4+
2. Two dynamic programming implementations which have quadratic runtime
25
3-
def rod_cutting(prices: List[int],length: int) -> int:
6+
The rod-cutting problem is the problem of finding the maximum possible revenue
7+
obtainable from a rod of length ``n`` given a list of prices for each integral piece
8+
of the rod. The maximum revenue can thus be obtained by cutting the rod and selling the
9+
pieces separately or not cutting it at all if the price of it is the maximum obtainable.
10+
11+
"""
12+
13+
14+
def naive_cut_rod_recursive(n: int, prices: list):
15+
"""
16+
Solves the rod-cutting problem via naively without using the benefit of dynamic programming.
17+
The results is the same sub-problems are solved several times leading to an exponential runtime
18+
19+
Runtime: O(2^n)
20+
21+
Arguments
22+
-------
23+
n: int, the length of the rod
24+
prices: list, the prices for each piece of rod. ``p[i-i]`` is the
25+
price for a rod of length ``i``
26+
27+
Returns
28+
-------
29+
The maximum revenue obtainable for a rod of length n given the list of prices for each piece.
30+
31+
Examples
32+
--------
33+
>>> naive_cut_rod_recursive(4, [1, 5, 8, 9])
34+
10
35+
>>> naive_cut_rod_recursive(10, [1, 5, 8, 9, 10, 17, 17, 20, 24, 30])
36+
30
37+
"""
38+
39+
_enforce_args(n, prices)
40+
if n == 0:
41+
return 0
42+
max_revue = float("-inf")
43+
for i in range(1, n + 1):
44+
max_revue = max(max_revue, prices[i - 1] + naive_cut_rod_recursive(n - i, prices))
45+
46+
return max_revue
47+
48+
49+
def top_down_cut_rod(n: int, prices: list):
450
"""
5-
Given a rod of length n and array of prices that indicate price at each length.
6-
Determine the maximum value obtainable by cutting up the rod and selling the pieces
7-
8-
>>> rod_cutting([1,5,8,9],4)
51+
Constructs a top-down dynamic programming solution for the rod-cutting problem
52+
via memoization. This function serves as a wrapper for _top_down_cut_rod_recursive
53+
54+
Runtime: O(n^2)
55+
56+
Arguments
57+
--------
58+
n: int, the length of the rod
59+
prices: list, the prices for each piece of rod. ``p[i-i]`` is the
60+
price for a rod of length ``i``
61+
62+
Note
63+
----
64+
For convenience and because Python's lists using 0-indexing, length(max_rev) = n + 1,
65+
to accommodate for the revenue obtainable from a rod of length 0.
66+
67+
Returns
68+
-------
69+
The maximum revenue obtainable for a rod of length n given the list of prices for each piece.
70+
71+
Examples
72+
-------
73+
>>> top_down_cut_rod(4, [1, 5, 8, 9])
974
10
10-
>>> rod_cutting([1,1,1],3)
11-
3
12-
>>> rod_cutting([1,2,3], -1)
13-
Traceback (most recent call last):
14-
ValueError: Given integer must be greater than 1, not -1
15-
>>> rod_cutting([1,2,3], 3.2)
16-
Traceback (most recent call last):
17-
TypeError: Must be int, not float
18-
>>> rod_cutting([], 3)
19-
Traceback (most recent call last):
20-
AssertionError: prices list is shorted than length: 3
21-
22-
23-
24-
Args:
25-
prices: list indicating price at each length, where prices[0] = 0 indicating rod of zero length has no value
26-
length: length of rod
27-
28-
Returns:
29-
Maximum revenue attainable by cutting up the rod in any way.
30-
"""
31-
32-
prices.insert(0, 0)
33-
if not isinstance(length, int):
34-
raise TypeError('Must be int, not {0}'.format(type(length).__name__))
35-
if length < 0:
36-
raise ValueError('Given integer must be greater than 1, not {0}'.format(length))
37-
assert len(prices) - 1 >= length, "prices list is shorted than length: {0}".format(length)
38-
39-
return rod_cutting_recursive(prices, length)
40-
41-
def rod_cutting_recursive(prices: List[int],length: int) -> int:
42-
#base case
43-
if length == 0:
75+
>>> top_down_cut_rod(10, [1, 5, 8, 9, 10, 17, 17, 20, 24, 30])
76+
30
77+
"""
78+
_enforce_args(n, prices)
79+
max_rev = [float("-inf") for _ in range(n + 1)]
80+
return _top_down_cut_rod_recursive(n, prices, max_rev)
81+
82+
83+
def _top_down_cut_rod_recursive(n: int, prices: list, max_rev: list):
84+
"""
85+
Constructs a top-down dynamic programming solution for the rod-cutting problem
86+
via memoization.
87+
88+
Runtime: O(n^2)
89+
90+
Arguments
91+
--------
92+
n: int, the length of the rod
93+
prices: list, the prices for each piece of rod. ``p[i-i]`` is the
94+
price for a rod of length ``i``
95+
max_rev: list, the computed maximum revenue for a piece of rod.
96+
``max_rev[i]`` is the maximum revenue obtainable for a rod of length ``i``
97+
98+
Returns
99+
-------
100+
The maximum revenue obtainable for a rod of length n given the list of prices for each piece.
101+
"""
102+
if max_rev[n] >= 0:
103+
return max_rev[n]
104+
elif n == 0:
44105
return 0
45-
value = float('-inf')
46-
for firstCutLocation in range(1,length+1):
47-
value = max(value, prices[firstCutLocation]+rod_cutting_recursive(prices,length - firstCutLocation))
48-
return value
106+
else:
107+
max_revenue = float("-inf")
108+
for i in range(1, n + 1):
109+
max_revenue = max(max_revenue, prices[i - 1] + _top_down_cut_rod_recursive(n - i, prices, max_rev))
110+
111+
max_rev[n] = max_revenue
112+
113+
return max_rev[n]
114+
115+
116+
def bottom_up_cut_rod(n: int, prices: list):
117+
"""
118+
Constructs a bottom-up dynamic programming solution for the rod-cutting problem
119+
120+
Runtime: O(n^2)
121+
122+
Arguments
123+
----------
124+
n: int, the maximum length of the rod.
125+
prices: list, the prices for each piece of rod. ``p[i-i]`` is the
126+
price for a rod of length ``i``
127+
128+
Returns
129+
-------
130+
The maximum revenue obtainable from cutting a rod of length n given
131+
the prices for each piece of rod p.
132+
133+
Examples
134+
-------
135+
>>> bottom_up_cut_rod(4, [1, 5, 8, 9])
136+
10
137+
>>> bottom_up_cut_rod(10, [1, 5, 8, 9, 10, 17, 17, 20, 24, 30])
138+
30
139+
"""
140+
_enforce_args(n, prices)
141+
142+
# length(max_rev) = n + 1, to accommodate for the revenue obtainable from a rod of length 0.
143+
max_rev = [float("-inf") for _ in range(n + 1)]
144+
max_rev[0] = 0
145+
146+
for i in range(1, n + 1):
147+
max_revenue_i = max_rev[i]
148+
for j in range(1, i + 1):
149+
max_revenue_i = max(max_revenue_i, prices[j - 1] + max_rev[i - j])
150+
151+
max_rev[i] = max_revenue_i
152+
153+
return max_rev[n]
154+
155+
156+
def _enforce_args(n: int, prices: list):
157+
"""
158+
Basic checks on the arguments to the rod-cutting algorithms
159+
160+
n: int, the length of the rod
161+
prices: list, the price list for each piece of rod.
162+
163+
Throws ValueError:
164+
165+
if n is negative or there are fewer items in the price list than the length of the rod
166+
"""
167+
if n < 0:
168+
raise ValueError(f"n must be greater than or equal to 0. Got n = {n}")
169+
170+
if n > len(prices):
171+
raise ValueError(f"Each integral piece of rod must have a corresponding "
172+
f"price. Got n = {n} but length of prices = {len(prices)}")
49173

50174

51175
def main():
52-
assert rod_cutting([1,5,8,9,10,17,17,20,24,30],10) == 30
53-
# print(rod_cutting([],0))
176+
prices = [6, 10, 12, 15, 20, 23]
177+
n = len(prices)
178+
179+
# the best revenue comes from cutting the rod into 6 pieces, each
180+
# of length 1 resulting in a revenue of 6 * 6 = 36.
181+
expected_max_revenue = 36
182+
183+
max_rev_top_down = top_down_cut_rod(n, prices)
184+
max_rev_bottom_up = bottom_up_cut_rod(n, prices)
185+
max_rev_naive = naive_cut_rod_recursive(n, prices)
186+
187+
assert expected_max_revenue == max_rev_top_down
188+
assert max_rev_top_down == max_rev_bottom_up
189+
assert max_rev_bottom_up == max_rev_naive
190+
54191

55192
if __name__ == '__main__':
56193
main()
57-

0 commit comments

Comments
 (0)