Skip to content

Commit de4768c

Browse files
authored
Create LinearRegressionTest.java
1 parent 73ad49a commit de4768c

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package com.thealgorithms.machinelearning;
2+
3+
import org.junit.jupiter.api.Test;
4+
import static org.junit.jupiter.api.Assertions.assertEquals;
5+
6+
import java.util.ArrayList;
7+
8+
class LinearRegressionTest {
9+
10+
@Test
11+
void testLinearRegression() {
12+
ArrayList<Double> dependentX = new ArrayList<>();
13+
ArrayList<Double> independentY = new ArrayList<>();
14+
15+
dependentX.add(1.0);
16+
independentY.add(2.0);
17+
dependentX.add(2.0);
18+
independentY.add(3.0);
19+
dependentX.add(3.0);
20+
independentY.add(4.0);
21+
dependentX.add(4.0);
22+
independentY.add(5.0);
23+
dependentX.add(5.0);
24+
independentY.add(6.0);
25+
26+
// Create LinearRegression object
27+
LinearRegression lr = new LinearRegression(dependentX, independentY);
28+
29+
// Check the slope (m) and intercept (c)
30+
assertEquals(1.0, lr.getM(), 0.001);
31+
assertEquals(1.0, lr.getC(), 0.001);
32+
33+
// Check prediction for X = 6
34+
double predictedY = lr.PredictForX(6.0);
35+
assertEquals(7.0, predictedY, 0.001);
36+
}
37+
}

0 commit comments

Comments
 (0)