Skip to content

Commit 65e3264

Browse files
authored
refactor: InverseOfMatrix (#5446)
refactor: InverseOfMatrix
1 parent bded78f commit 65e3264

File tree

2 files changed

+52
-50
lines changed

2 files changed

+52
-50
lines changed

src/main/java/com/thealgorithms/misc/InverseOfMatrix.java

+24-50
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,29 @@
11
package com.thealgorithms.misc;
22

3-
import java.util.Scanner;
4-
5-
/*
6-
* Wikipedia link : https://en.wikipedia.org/wiki/Invertible_matrix
7-
*
8-
* Here we use gauss elimination method to find the inverse of a given matrix.
9-
* To understand gauss elimination method to find inverse of a matrix:
10-
* https://www.sangakoo.com/en/unit/inverse-matrix-method-of-gaussian-elimination
11-
*
12-
* We can also find the inverse of a matrix
3+
/**
4+
* This class provides methods to compute the inverse of a square matrix
5+
* using Gaussian elimination. For more details, refer to:
6+
* https://en.wikipedia.org/wiki/Invertible_matrix
137
*/
148
public final class InverseOfMatrix {
159
private InverseOfMatrix() {
1610
}
1711

18-
public static void main(String[] argv) {
19-
Scanner input = new Scanner(System.in);
20-
System.out.println("Enter the matrix size (Square matrix only): ");
21-
int n = input.nextInt();
22-
double[][] a = new double[n][n];
23-
System.out.println("Enter the elements of matrix: ");
24-
for (int i = 0; i < n; i++) {
25-
for (int j = 0; j < n; j++) {
26-
a[i][j] = input.nextDouble();
27-
}
28-
}
29-
30-
double[][] d = invert(a);
31-
System.out.println();
32-
System.out.println("The inverse is: ");
33-
for (int i = 0; i < n; ++i) {
34-
for (int j = 0; j < n; ++j) {
35-
System.out.print(d[i][j] + " ");
36-
}
37-
System.out.println();
38-
}
39-
input.close();
40-
}
41-
4212
public static double[][] invert(double[][] a) {
4313
int n = a.length;
4414
double[][] x = new double[n][n];
4515
double[][] b = new double[n][n];
4616
int[] index = new int[n];
17+
18+
// Initialize the identity matrix
4719
for (int i = 0; i < n; ++i) {
4820
b[i][i] = 1;
4921
}
5022

51-
// Transform the matrix into an upper triangle
23+
// Perform Gaussian elimination
5224
gaussian(a, index);
5325

54-
// Update the matrix b[i][j] with the ratios stored
26+
// Update matrix b with the ratios stored during elimination
5527
for (int i = 0; i < n - 1; ++i) {
5628
for (int j = i + 1; j < n; ++j) {
5729
for (int k = 0; k < n; ++k) {
@@ -60,7 +32,7 @@ public static double[][] invert(double[][] a) {
6032
}
6133
}
6234

63-
// Perform backward substitutions
35+
// Perform backward substitution to find the inverse
6436
for (int i = 0; i < n; ++i) {
6537
x[n - 1][i] = b[index[n - 1]][i] / a[index[n - 1]][n - 1];
6638
for (int j = n - 2; j >= 0; --j) {
@@ -73,19 +45,20 @@ public static double[][] invert(double[][] a) {
7345
}
7446
return x;
7547
}
76-
77-
// Method to carry out the partial-pivoting Gaussian
78-
// elimination. Here index[] stores pivoting order.
79-
public static void gaussian(double[][] a, int[] index) {
48+
/**
49+
* Method to carry out the partial-pivoting Gaussian
50+
* elimination. Here index[] stores pivoting order.
51+
**/
52+
private static void gaussian(double[][] a, int[] index) {
8053
int n = index.length;
8154
double[] c = new double[n];
8255

83-
// Initialize the index
56+
// Initialize the index array
8457
for (int i = 0; i < n; ++i) {
8558
index[i] = i;
8659
}
8760

88-
// Find the rescaling factors, one from each row
61+
// Find the rescaling factors for each row
8962
for (int i = 0; i < n; ++i) {
9063
double c1 = 0;
9164
for (int j = 0; j < n; ++j) {
@@ -97,22 +70,23 @@ public static void gaussian(double[][] a, int[] index) {
9770
c[i] = c1;
9871
}
9972

100-
// Search the pivoting element from each column
101-
int k = 0;
73+
// Perform pivoting
10274
for (int j = 0; j < n - 1; ++j) {
10375
double pi1 = 0;
76+
int k = j;
10477
for (int i = j; i < n; ++i) {
105-
double pi0 = Math.abs(a[index[i]][j]);
106-
pi0 /= c[index[i]];
78+
double pi0 = Math.abs(a[index[i]][j]) / c[index[i]];
10779
if (pi0 > pi1) {
10880
pi1 = pi0;
10981
k = i;
11082
}
11183
}
112-
// Interchange rows according to the pivoting order
113-
int itmp = index[j];
84+
85+
// Swap rows
86+
int temp = index[j];
11487
index[j] = index[k];
115-
index[k] = itmp;
88+
index[k] = temp;
89+
11690
for (int i = j + 1; i < n; ++i) {
11791
double pj = a[index[i]][j] / a[index[j]][j];
11892

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.thealgorithms.misc;
2+
3+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
4+
5+
import java.util.stream.Stream;
6+
import org.junit.jupiter.params.ParameterizedTest;
7+
import org.junit.jupiter.params.provider.Arguments;
8+
import org.junit.jupiter.params.provider.MethodSource;
9+
10+
class InverseOfMatrixTest {
11+
12+
@ParameterizedTest
13+
@MethodSource("provideTestCases")
14+
void testInvert(double[][] matrix, double[][] expectedInverse) {
15+
double[][] result = InverseOfMatrix.invert(matrix);
16+
assertMatrixEquals(expectedInverse, result);
17+
}
18+
19+
private static Stream<Arguments> provideTestCases() {
20+
return Stream.of(Arguments.of(new double[][] {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}, new double[][] {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}), Arguments.of(new double[][] {{4, 7}, {2, 6}}, new double[][] {{0.6, -0.7}, {-0.2, 0.4}}));
21+
}
22+
23+
private void assertMatrixEquals(double[][] expected, double[][] actual) {
24+
for (int i = 0; i < expected.length; i++) {
25+
assertArrayEquals(expected[i], actual[i], 1.0E-10, "Row " + i + " is not equal");
26+
}
27+
}
28+
}

0 commit comments

Comments
 (0)