diff --git a/src/main/java/com/thealgorithms/machinelearning/LinearRegression.java b/src/main/java/com/thealgorithms/machinelearning/LinearRegression.java new file mode 100644 index 000000000000..0f986ebd5ad7 --- /dev/null +++ b/src/main/java/com/thealgorithms/machinelearning/LinearRegression.java @@ -0,0 +1,72 @@ + +package com.thealgorithms.machinelearning; + +import java.util.ArrayList; +import java.util.List; + +/** + * Author : Gowtham Kamalasekar + * LinkedIn : https://www.linkedin.com/in/gowtham-kamalasekar/ + * + * Wiki : https://en.wikipedia.org/wiki/Linear_regression + * Linear Regression Machine Learning Algorithm is a regression algorithm. + * This programs used for computing y = mx + c + * Where m is slope and c is intercept + * We can use this too predict for a given x. + */ + +class LinearRegression { + private ArrayList dependentX = new ArrayList(); + private ArrayList independentY = new ArrayList(); + private double m; + private double c; + + /** + * @param : X (dependent variable), Y (independent variable) as ArrayList + */ + LinearRegression(ArrayList dependentX, ArrayList independentY) { + this.dependentX = dependentX; + this.independentY = independentY; + this.equate(); + } + + private double sumation(List arr) { + double sum = 0.0; + + for (int i = 0; i < arr.size(); i++) { + sum += arr.get(i); + } + + return sum; + } + + private List multiplyNumber(List arr1, List arr2) { + List temp = new ArrayList(); + for (int i = 0; i < arr1.size(); i++) { + temp.add((arr1.get(i) * arr2.get(i))); + } + return temp; + } + + private void equate() { + int n = dependentX.size(); + this.m = (n * sumation(multiplyNumber(independentY, dependentX)) - (sumation(dependentX) * sumation(independentY))); + this.m = this.m / (n * (sumation(multiplyNumber(dependentX, dependentX))) - (sumation(dependentX) * sumation(dependentX))); + + this.c = (sumation(independentY) * sumation(multiplyNumber(dependentX, dependentX)) - (sumation(dependentX) * sumation(multiplyNumber(independentY, dependentX)))); + this.c = this.c / (n * (sumation(multiplyNumber(dependentX, dependentX))) - (sumation(dependentX) * sumation(dependentX))); + } + + public double getM() { + return this.m; + } + + public double getC() { + return this.c; + } + + public double predictForX(double x) { + return (this.m * x) + this.c; + } +} + diff --git a/src/test/java/com/thealgorithms/machinelearning/LinearRegressionTest.java b/src/test/java/com/thealgorithms/machinelearning/LinearRegressionTest.java new file mode 100644 index 000000000000..c16d367b0737 --- /dev/null +++ b/src/test/java/com/thealgorithms/machinelearning/LinearRegressionTest.java @@ -0,0 +1,37 @@ +package com.thealgorithms.machinelearning; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.ArrayList; +import org.junit.jupiter.api.Test; + +class LinearRegressionTest { + + @Test + void testLinearRegression() { + ArrayList dependentX = new ArrayList<>(); + ArrayList independentY = new ArrayList<>(); + + dependentX.add(1.0); + independentY.add(2.0); + dependentX.add(2.0); + independentY.add(3.0); + dependentX.add(3.0); + independentY.add(4.0); + dependentX.add(4.0); + independentY.add(5.0); + dependentX.add(5.0); + independentY.add(6.0); + + // Create LinearRegression object + LinearRegression lr = new LinearRegression(dependentX, independentY); + + // Check the slope (m) and intercept (c) + assertEquals(1.0, lr.getM(), 0.001); + assertEquals(1.0, lr.getC(), 0.001); + + // Check prediction for X = 6 + double predictedY = lr.predictForX(6.0); + assertEquals(7.0, predictedY, 0.001); + } +}