Skip to content

refactor: InverseOfMatrix #5446

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 24 additions & 50 deletions src/main/java/com/thealgorithms/misc/InverseOfMatrix.java
Original file line number Diff line number Diff line change
@@ -1,57 +1,29 @@
package com.thealgorithms.misc;

import java.util.Scanner;

/*
* Wikipedia link : https://en.wikipedia.org/wiki/Invertible_matrix
*
* Here we use gauss elimination method to find the inverse of a given matrix.
* To understand gauss elimination method to find inverse of a matrix:
* https://www.sangakoo.com/en/unit/inverse-matrix-method-of-gaussian-elimination
*
* We can also find the inverse of a matrix
/**
* This class provides methods to compute the inverse of a square matrix
* using Gaussian elimination. For more details, refer to:
* https://en.wikipedia.org/wiki/Invertible_matrix
*/
public final class InverseOfMatrix {
private InverseOfMatrix() {
}

public static void main(String[] argv) {
Scanner input = new Scanner(System.in);
System.out.println("Enter the matrix size (Square matrix only): ");
int n = input.nextInt();
double[][] a = new double[n][n];
System.out.println("Enter the elements of matrix: ");
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
a[i][j] = input.nextDouble();
}
}

double[][] d = invert(a);
System.out.println();
System.out.println("The inverse is: ");
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
System.out.print(d[i][j] + " ");
}
System.out.println();
}
input.close();
}

public static double[][] invert(double[][] a) {
int n = a.length;
double[][] x = new double[n][n];
double[][] b = new double[n][n];
int[] index = new int[n];

// Initialize the identity matrix
for (int i = 0; i < n; ++i) {
b[i][i] = 1;
}

// Transform the matrix into an upper triangle
// Perform Gaussian elimination
gaussian(a, index);

// Update the matrix b[i][j] with the ratios stored
// Update matrix b with the ratios stored during elimination
for (int i = 0; i < n - 1; ++i) {
for (int j = i + 1; j < n; ++j) {
for (int k = 0; k < n; ++k) {
Expand All @@ -60,7 +32,7 @@ public static double[][] invert(double[][] a) {
}
}

// Perform backward substitutions
// Perform backward substitution to find the inverse
for (int i = 0; i < n; ++i) {
x[n - 1][i] = b[index[n - 1]][i] / a[index[n - 1]][n - 1];
for (int j = n - 2; j >= 0; --j) {
Expand All @@ -73,19 +45,20 @@ public static double[][] invert(double[][] a) {
}
return x;
}

// Method to carry out the partial-pivoting Gaussian
// elimination. Here index[] stores pivoting order.
public static void gaussian(double[][] a, int[] index) {
/**
* Method to carry out the partial-pivoting Gaussian
* elimination. Here index[] stores pivoting order.
**/
private static void gaussian(double[][] a, int[] index) {
int n = index.length;
double[] c = new double[n];

// Initialize the index
// Initialize the index array
for (int i = 0; i < n; ++i) {
index[i] = i;
}

// Find the rescaling factors, one from each row
// Find the rescaling factors for each row
for (int i = 0; i < n; ++i) {
double c1 = 0;
for (int j = 0; j < n; ++j) {
Expand All @@ -97,22 +70,23 @@ public static void gaussian(double[][] a, int[] index) {
c[i] = c1;
}

// Search the pivoting element from each column
int k = 0;
// Perform pivoting
for (int j = 0; j < n - 1; ++j) {
double pi1 = 0;
int k = j;
for (int i = j; i < n; ++i) {
double pi0 = Math.abs(a[index[i]][j]);
pi0 /= c[index[i]];
double pi0 = Math.abs(a[index[i]][j]) / c[index[i]];
if (pi0 > pi1) {
pi1 = pi0;
k = i;
}
}
// Interchange rows according to the pivoting order
int itmp = index[j];

// Swap rows
int temp = index[j];
index[j] = index[k];
index[k] = itmp;
index[k] = temp;

for (int i = j + 1; i < n; ++i) {
double pj = a[index[i]][j] / a[index[j]][j];

Expand Down
28 changes: 28 additions & 0 deletions src/test/java/com/thealgorithms/misc/InverseOfMatrixTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.thealgorithms.misc;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;

import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

class InverseOfMatrixTest {

@ParameterizedTest
@MethodSource("provideTestCases")
void testInvert(double[][] matrix, double[][] expectedInverse) {
double[][] result = InverseOfMatrix.invert(matrix);
assertMatrixEquals(expectedInverse, result);
}

private static Stream<Arguments> provideTestCases() {
return Stream.of(Arguments.of(new double[][] {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}, new double[][] {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}), Arguments.of(new double[][] {{4, 7}, {2, 6}}, new double[][] {{0.6, -0.7}, {-0.2, 0.4}}));
}

private void assertMatrixEquals(double[][] expected, double[][] actual) {
for (int i = 0; i < expected.length; i++) {
assertArrayEquals(expected[i], actual[i], 1.0E-10, "Row " + i + " is not equal");
}
}
}