Skip to content

Commit e6aae1c

Browse files
Shailaputrigithub-actionsPooja Sharmacclausspre-commit-ci[bot]
authored
Dynamic programming/matrix chain multiplication (#10562)
* updating DIRECTORY.md * spell changes * updating DIRECTORY.md * real world applications * updating DIRECTORY.md * Update matrix_chain_multiplication.py Add a non-dp solution with benchmarks. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update matrix_chain_multiplication.py * Update matrix_chain_multiplication.py * Update matrix_chain_multiplication.py --------- Co-authored-by: github-actions <${GITHUB_ACTOR}@users.noreply.github.com> Co-authored-by: Pooja Sharma <[email protected]> Co-authored-by: Christian Clauss <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d00888d commit e6aae1c

File tree

2 files changed

+147
-1
lines changed

2 files changed

+147
-1
lines changed

Diff for: DIRECTORY.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@
182182
* [Permutations](data_structures/arrays/permutations.py)
183183
* [Prefix Sum](data_structures/arrays/prefix_sum.py)
184184
* [Product Sum](data_structures/arrays/product_sum.py)
185+
* [Sparse Table](data_structures/arrays/sparse_table.py)
185186
* Binary Tree
186187
* [Avl Tree](data_structures/binary_tree/avl_tree.py)
187188
* [Basic Binary Tree](data_structures/binary_tree/basic_binary_tree.py)
@@ -340,6 +341,7 @@
340341
* [Longest Increasing Subsequence O(Nlogn)](dynamic_programming/longest_increasing_subsequence_o(nlogn).py)
341342
* [Longest Palindromic Subsequence](dynamic_programming/longest_palindromic_subsequence.py)
342343
* [Longest Sub Array](dynamic_programming/longest_sub_array.py)
344+
* [Matrix Chain Multiplication](dynamic_programming/matrix_chain_multiplication.py)
343345
* [Matrix Chain Order](dynamic_programming/matrix_chain_order.py)
344346
* [Max Non Adjacent Sum](dynamic_programming/max_non_adjacent_sum.py)
345347
* [Max Product Subarray](dynamic_programming/max_product_subarray.py)
@@ -370,6 +372,7 @@
370372
* [Builtin Voltage](electronics/builtin_voltage.py)
371373
* [Carrier Concentration](electronics/carrier_concentration.py)
372374
* [Charging Capacitor](electronics/charging_capacitor.py)
375+
* [Charging Inductor](electronics/charging_inductor.py)
373376
* [Circular Convolution](electronics/circular_convolution.py)
374377
* [Coulombs Law](electronics/coulombs_law.py)
375378
* [Electric Conductivity](electronics/electric_conductivity.py)
@@ -524,6 +527,7 @@
524527
* [Simplex](linear_programming/simplex.py)
525528

526529
## Machine Learning
530+
* [Apriori Algorithm](machine_learning/apriori_algorithm.py)
527531
* [Astar](machine_learning/astar.py)
528532
* [Data Transformations](machine_learning/data_transformations.py)
529533
* [Decision Tree](machine_learning/decision_tree.py)
@@ -554,7 +558,6 @@
554558
* [Word Frequency Functions](machine_learning/word_frequency_functions.py)
555559
* [Xgboost Classifier](machine_learning/xgboost_classifier.py)
556560
* [Xgboost Regressor](machine_learning/xgboost_regressor.py)
557-
* [Apriori Algorithm](machine_learning/apriori_algorithm.py)
558561

559562
## Maths
560563
* [Abs](maths/abs.py)

Diff for: dynamic_programming/matrix_chain_multiplication.py

+143
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
"""
2+
Find the minimum number of multiplications needed to multiply chain of matrices.
3+
Reference: https://www.geeksforgeeks.org/matrix-chain-multiplication-dp-8/
4+
5+
The algorithm has interesting real-world applications. Example:
6+
1. Image transformations in Computer Graphics as images are composed of matrix.
7+
2. Solve complex polynomial equations in the field of algebra using least processing
8+
power.
9+
3. Calculate overall impact of macroeconomic decisions as economic equations involve a
10+
number of variables.
11+
4. Self-driving car navigation can be made more accurate as matrix multiplication can
12+
accurately determine position and orientation of obstacles in short time.
13+
14+
Python doctests can be run with the following command:
15+
python -m doctest -v matrix_chain_multiply.py
16+
17+
Given a sequence arr[] that represents chain of 2D matrices such that the dimension of
18+
the ith matrix is arr[i-1]*arr[i].
19+
So suppose arr = [40, 20, 30, 10, 30] means we have 4 matrices of dimensions
20+
40*20, 20*30, 30*10 and 10*30.
21+
22+
matrix_chain_multiply() returns an integer denoting minimum number of multiplications to
23+
multiply the chain.
24+
25+
We do not need to perform actual multiplication here.
26+
We only need to decide the order in which to perform the multiplication.
27+
28+
Hints:
29+
1. Number of multiplications (ie cost) to multiply 2 matrices
30+
of size m*p and p*n is m*p*n.
31+
2. Cost of matrix multiplication is associative ie (M1*M2)*M3 != M1*(M2*M3)
32+
3. Matrix multiplication is not commutative. So, M1*M2 does not mean M2*M1 can be done.
33+
4. To determine the required order, we can try different combinations.
34+
So, this problem has overlapping sub-problems and can be solved using recursion.
35+
We use Dynamic Programming for optimal time complexity.
36+
37+
Example input:
38+
arr = [40, 20, 30, 10, 30]
39+
output: 26000
40+
"""
41+
from collections.abc import Iterator
42+
from contextlib import contextmanager
43+
from functools import cache
44+
from sys import maxsize
45+
46+
47+
def matrix_chain_multiply(arr: list[int]) -> int:
48+
"""
49+
Find the minimum number of multiplcations required to multiply the chain of matrices
50+
51+
Args:
52+
arr: The input array of integers.
53+
54+
Returns:
55+
Minimum number of multiplications needed to multiply the chain
56+
57+
Examples:
58+
>>> matrix_chain_multiply([1, 2, 3, 4, 3])
59+
30
60+
>>> matrix_chain_multiply([10])
61+
0
62+
>>> matrix_chain_multiply([10, 20])
63+
0
64+
>>> matrix_chain_multiply([19, 2, 19])
65+
722
66+
>>> matrix_chain_multiply(list(range(1, 100)))
67+
323398
68+
69+
# >>> matrix_chain_multiply(list(range(1, 251)))
70+
# 2626798
71+
"""
72+
if len(arr) < 2:
73+
return 0
74+
# initialising 2D dp matrix
75+
n = len(arr)
76+
dp = [[maxsize for j in range(n)] for i in range(n)]
77+
# we want minimum cost of multiplication of matrices
78+
# of dimension (i*k) and (k*j). This cost is arr[i-1]*arr[k]*arr[j].
79+
for i in range(n - 1, 0, -1):
80+
for j in range(i, n):
81+
if i == j:
82+
dp[i][j] = 0
83+
continue
84+
for k in range(i, j):
85+
dp[i][j] = min(
86+
dp[i][j], dp[i][k] + dp[k + 1][j] + arr[i - 1] * arr[k] * arr[j]
87+
)
88+
89+
return dp[1][n - 1]
90+
91+
92+
def matrix_chain_order(dims: list[int]) -> int:
93+
"""
94+
Source: https://en.wikipedia.org/wiki/Matrix_chain_multiplication
95+
The dynamic programming solution is faster than cached the recursive solution and
96+
can handle larger inputs.
97+
>>> matrix_chain_order([1, 2, 3, 4, 3])
98+
30
99+
>>> matrix_chain_order([10])
100+
0
101+
>>> matrix_chain_order([10, 20])
102+
0
103+
>>> matrix_chain_order([19, 2, 19])
104+
722
105+
>>> matrix_chain_order(list(range(1, 100)))
106+
323398
107+
108+
# >>> matrix_chain_order(list(range(1, 251))) # Max before RecursionError is raised
109+
# 2626798
110+
"""
111+
112+
@cache
113+
def a(i: int, j: int) -> int:
114+
return min(
115+
(a(i, k) + dims[i] * dims[k] * dims[j] + a(k, j) for k in range(i + 1, j)),
116+
default=0,
117+
)
118+
119+
return a(0, len(dims) - 1)
120+
121+
122+
@contextmanager
123+
def elapsed_time(msg: str) -> Iterator:
124+
# print(f"Starting: {msg}")
125+
from time import perf_counter_ns
126+
127+
start = perf_counter_ns()
128+
yield
129+
print(f"Finished: {msg} in {(perf_counter_ns() - start) / 10 ** 9} seconds.")
130+
131+
132+
if __name__ == "__main__":
133+
import doctest
134+
135+
doctest.testmod()
136+
with elapsed_time("matrix_chain_order"):
137+
print(f"{matrix_chain_order(list(range(1, 251))) = }")
138+
with elapsed_time("matrix_chain_multiply"):
139+
print(f"{matrix_chain_multiply(list(range(1, 251))) = }")
140+
with elapsed_time("matrix_chain_order"):
141+
print(f"{matrix_chain_order(list(range(1, 251))) = }")
142+
with elapsed_time("matrix_chain_multiply"):
143+
print(f"{matrix_chain_multiply(list(range(1, 251))) = }")

0 commit comments

Comments
 (0)