Skip to content

Add tests, remove main, enhance docs in MatrixChainMultiplication #5658

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DIRECTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,7 @@
* [LongestIncreasingSubsequenceTests](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/LongestIncreasingSubsequenceTests.java)
* [LongestPalindromicSubstringTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/LongestPalindromicSubstringTest.java)
* [LongestValidParenthesesTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/LongestValidParenthesesTest.java)
* [MatrixChainMultiplicationTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/MatrixChainMultiplicationTest.java)
* [MinimumPathSumTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/MinimumPathSumTest.java)
* [MinimumSumPartitionTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/MinimumSumPartitionTest.java)
* [OptimalJobSchedulingTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/dynamicprogramming/OptimalJobSchedulingTest.java)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,32 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Scanner;

/**
* The MatrixChainMultiplication class provides functionality to compute the
* optimal way to multiply a sequence of matrices. The optimal multiplication
* order is determined using dynamic programming, which minimizes the total
* number of scalar multiplications required.
*/
public final class MatrixChainMultiplication {
private MatrixChainMultiplication() {
}

private static final Scanner SCANNER = new Scanner(System.in);
private static final ArrayList<Matrix> MATRICES = new ArrayList<>();
private static int size;
// Matrices to store minimum multiplication costs and split points
private static int[][] m;
private static int[][] s;
private static int[] p;

public static void main(String[] args) {
int count = 1;
while (true) {
String[] mSize = input("input size of matrix A(" + count + ") ( ex. 10 20 ) : ");
int col = Integer.parseInt(mSize[0]);
if (col == 0) {
break;
}
int row = Integer.parseInt(mSize[1]);

Matrix matrix = new Matrix(count, col, row);
MATRICES.add(matrix);
count++;
}
for (Matrix m : MATRICES) {
System.out.format("A(%d) = %2d x %2d%n", m.count(), m.col(), m.row());
}

size = MATRICES.size();
/**
* Calculates the optimal order for multiplying a given list of matrices.
*
* @param matrices an ArrayList of Matrix objects representing the matrices
* to be multiplied.
* @return a Result object containing the matrices of minimum costs and
* optimal splits.
*/
public static Result calculateMatrixChainOrder(ArrayList<Matrix> matrices) {
int size = matrices.size();
m = new int[size + 1][size + 1];
s = new int[size + 1][size + 1];
p = new int[size + 1];
Expand All @@ -44,51 +38,20 @@ public static void main(String[] args) {
}

for (int i = 0; i < p.length; i++) {
p[i] = i == 0 ? MATRICES.get(i).col() : MATRICES.get(i - 1).row();
p[i] = i == 0 ? matrices.get(i).col() : matrices.get(i - 1).row();
}

matrixChainOrder();
for (int i = 0; i < size; i++) {
System.out.print("-------");
}
System.out.println();
printArray(m);
for (int i = 0; i < size; i++) {
System.out.print("-------");
}
System.out.println();
printArray(s);
for (int i = 0; i < size; i++) {
System.out.print("-------");
}
System.out.println();

System.out.println("Optimal solution : " + m[1][size]);
System.out.print("Optimal parens : ");
printOptimalParens(1, size);
}

private static void printOptimalParens(int i, int j) {
if (i == j) {
System.out.print("A" + i);
} else {
System.out.print("(");
printOptimalParens(i, s[i][j]);
printOptimalParens(s[i][j] + 1, j);
System.out.print(")");
}
}

private static void printArray(int[][] array) {
for (int i = 1; i < size + 1; i++) {
for (int j = 1; j < size + 1; j++) {
System.out.printf("%7d", array[i][j]);
}
System.out.println();
}
matrixChainOrder(size);
return new Result(m, s);
}

private static void matrixChainOrder() {
/**
* A helper method that computes the minimum cost of multiplying
* the matrices using dynamic programming.
*
* @param size the number of matrices in the multiplication sequence.
*/
private static void matrixChainOrder(int size) {
for (int i = 1; i < size + 1; i++) {
m[i][i] = 0;
}
Expand All @@ -109,33 +72,92 @@ private static void matrixChainOrder() {
}
}

private static String[] input(String string) {
System.out.print(string);
return (SCANNER.nextLine().split(" "));
}
}

class Matrix {
/**
* The Result class holds the results of the matrix chain multiplication
* calculation, including the matrix of minimum costs and split points.
*/
public static class Result {
private final int[][] m;
private final int[][] s;

/**
* Constructs a Result object with the specified matrices of minimum
* costs and split points.
*
* @param m the matrix of minimum multiplication costs.
* @param s the matrix of optimal split points.
*/
public Result(int[][] m, int[][] s) {
this.m = m;
this.s = s;
}

private final int count;
private final int col;
private final int row;
/**
* Returns the matrix of minimum multiplication costs.
*
* @return the matrix of minimum multiplication costs.
*/
public int[][] getM() {
return m;
}

Matrix(int count, int col, int row) {
this.count = count;
this.col = col;
this.row = row;
/**
* Returns the matrix of optimal split points.
*
* @return the matrix of optimal split points.
*/
public int[][] getS() {
return s;
}
}

int count() {
return count;
}
/**
* The Matrix class represents a matrix with its dimensions and count.
*/
public static class Matrix {
private final int count;
private final int col;
private final int row;

/**
* Constructs a Matrix object with the specified count, number of columns,
* and number of rows.
*
* @param count the identifier for the matrix.
* @param col the number of columns in the matrix.
* @param row the number of rows in the matrix.
*/
public Matrix(int count, int col, int row) {
this.count = count;
this.col = col;
this.row = row;
}

int col() {
return col;
}
/**
* Returns the identifier of the matrix.
*
* @return the identifier of the matrix.
*/
public int count() {
return count;
}

int row() {
return row;
/**
* Returns the number of columns in the matrix.
*
* @return the number of columns in the matrix.
*/
public int col() {
return col;
}

/**
* Returns the number of rows in the matrix.
*
* @return the number of rows in the matrix.
*/
public int row() {
return row;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.thealgorithms.dynamicprogramming;

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.util.ArrayList;
import org.junit.jupiter.api.Test;

class MatrixChainMultiplicationTest {

@Test
void testMatrixCreation() {
MatrixChainMultiplication.Matrix matrix1 = new MatrixChainMultiplication.Matrix(1, 10, 20);
MatrixChainMultiplication.Matrix matrix2 = new MatrixChainMultiplication.Matrix(2, 20, 30);

assertEquals(1, matrix1.count());
assertEquals(10, matrix1.col());
assertEquals(20, matrix1.row());

assertEquals(2, matrix2.count());
assertEquals(20, matrix2.col());
assertEquals(30, matrix2.row());
}

@Test
void testMatrixChainOrder() {
// Create a list of matrices to be multiplied
ArrayList<MatrixChainMultiplication.Matrix> matrices = new ArrayList<>();
matrices.add(new MatrixChainMultiplication.Matrix(1, 10, 20)); // A(1) = 10 x 20
matrices.add(new MatrixChainMultiplication.Matrix(2, 20, 30)); // A(2) = 20 x 30

// Calculate matrix chain order
MatrixChainMultiplication.Result result = MatrixChainMultiplication.calculateMatrixChainOrder(matrices);

// Expected cost of multiplying A(1) and A(2)
int expectedCost = 6000; // The expected optimal cost of multiplying A(1)(10x20) and A(2)(20x30)
int actualCost = result.getM()[1][2];

assertEquals(expectedCost, actualCost);
}

@Test
void testOptimalParentheses() {
// Create a list of matrices to be multiplied
ArrayList<MatrixChainMultiplication.Matrix> matrices = new ArrayList<>();
matrices.add(new MatrixChainMultiplication.Matrix(1, 10, 20)); // A(1) = 10 x 20
matrices.add(new MatrixChainMultiplication.Matrix(2, 20, 30)); // A(2) = 20 x 30

// Calculate matrix chain order
MatrixChainMultiplication.Result result = MatrixChainMultiplication.calculateMatrixChainOrder(matrices);

// Check the optimal split for parentheses
assertEquals(1, result.getS()[1][2]); // s[1][2] should point to the optimal split
}
}