Skip to content

Commit b3d0297

Browse files
HardvanChiefpatwal
authored andcommitted
Add tests, remove main in MonteCarloTreeSearch (TheAlgorithms#5673)
1 parent cd89641 commit b3d0297

File tree

3 files changed

+127
-6
lines changed

3 files changed

+127
-6
lines changed

DIRECTORY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,6 +1010,7 @@
10101010
* [JumpSearchTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/searches/JumpSearchTest.java)
10111011
* [KMPSearchTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/searches/KMPSearchTest.java)
10121012
* [LinearSearchTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/searches/LinearSearchTest.java)
1013+
* [MonteCarloTreeSearchTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/searches/MonteCarloTreeSearchTest.java)
10131014
* [OrderAgnosticBinarySearchTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/searches/OrderAgnosticBinarySearchTest.java)
10141015
* [PerfectBinarySearchTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/searches/PerfectBinarySearchTest.java)
10151016
* [QuickSelectTest](https://github.com/TheAlgorithms/Java/blob/master/src/test/java/com/thealgorithms/searches/QuickSelectTest.java)

src/main/java/com/thealgorithms/searches/MonteCarloTreeSearch.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,6 @@ public Node(Node parent, boolean isPlayersTurn) {
3939
static final int WIN_SCORE = 10;
4040
static final int TIME_LIMIT = 500; // Time the algorithm will be running for (in milliseconds).
4141

42-
public static void main(String[] args) {
43-
MonteCarloTreeSearch mcts = new MonteCarloTreeSearch();
44-
45-
mcts.monteCarloTreeSearch(mcts.new Node(null, true));
46-
}
47-
4842
/**
4943
* Explores a game tree using Monte Carlo Tree Search (MCTS) and returns the
5044
* most promising node.
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
package com.thealgorithms.searches;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertFalse;
5+
import static org.junit.jupiter.api.Assertions.assertNotNull;
6+
import static org.junit.jupiter.api.Assertions.assertTrue;
7+
8+
import org.junit.jupiter.api.Test;
9+
10+
class MonteCarloTreeSearchTest {
11+
12+
/**
13+
* Test the creation of a node and its initial state.
14+
*/
15+
@Test
16+
void testNodeCreation() {
17+
MonteCarloTreeSearch.Node node = new MonteCarloTreeSearch().new Node(null, true);
18+
assertNotNull(node, "Node should be created");
19+
assertTrue(node.childNodes.isEmpty(), "Child nodes should be empty upon creation");
20+
assertTrue(node.isPlayersTurn, "Initial turn should be player's turn");
21+
assertEquals(0, node.score, "Initial score should be zero");
22+
assertEquals(0, node.visitCount, "Initial visit count should be zero");
23+
}
24+
25+
/**
26+
* Test adding child nodes to a parent node.
27+
*/
28+
@Test
29+
void testAddChildNodes() {
30+
MonteCarloTreeSearch mcts = new MonteCarloTreeSearch();
31+
MonteCarloTreeSearch.Node parentNode = mcts.new Node(null, true);
32+
33+
mcts.addChildNodes(parentNode, 5);
34+
35+
assertEquals(5, parentNode.childNodes.size(), "Parent should have 5 child nodes");
36+
for (MonteCarloTreeSearch.Node child : parentNode.childNodes) {
37+
assertFalse(child.isPlayersTurn, "Child node should not be player's turn");
38+
assertEquals(0, child.visitCount, "Child node visit count should be zero");
39+
}
40+
}
41+
42+
/**
43+
* Test the UCT selection of a promising node.
44+
*/
45+
@Test
46+
void testGetPromisingNode() {
47+
MonteCarloTreeSearch mcts = new MonteCarloTreeSearch();
48+
MonteCarloTreeSearch.Node parentNode = mcts.new Node(null, true);
49+
50+
// Create child nodes with different visit counts and scores
51+
for (int i = 0; i < 3; i++) {
52+
MonteCarloTreeSearch.Node child = mcts.new Node(parentNode, false);
53+
child.visitCount = i + 1;
54+
child.score = i * 2;
55+
parentNode.childNodes.add(child);
56+
}
57+
58+
// Get promising node
59+
MonteCarloTreeSearch.Node promisingNode = mcts.getPromisingNode(parentNode);
60+
61+
// The child with the highest UCT value should be chosen.
62+
assertNotNull(promisingNode, "Promising node should not be null");
63+
assertEquals(0, parentNode.childNodes.indexOf(promisingNode), "The first child should be the most promising");
64+
}
65+
66+
/**
67+
* Test simulation of random play and backpropagation.
68+
*/
69+
@Test
70+
void testSimulateRandomPlay() {
71+
MonteCarloTreeSearch mcts = new MonteCarloTreeSearch();
72+
MonteCarloTreeSearch.Node node = mcts.new Node(null, true);
73+
node.visitCount = 10; // Simulating existing visits
74+
75+
// Simulate random play
76+
mcts.simulateRandomPlay(node);
77+
78+
// Check visit count after simulation
79+
assertEquals(11, node.visitCount, "Visit count should increase after simulation");
80+
81+
// Check if score is updated correctly
82+
assertTrue(node.score >= 0 && node.score <= MonteCarloTreeSearch.WIN_SCORE, "Score should be between 0 and WIN_SCORE");
83+
}
84+
85+
/**
86+
* Test retrieving the winning node based on scores.
87+
*/
88+
@Test
89+
void testGetWinnerNode() {
90+
MonteCarloTreeSearch mcts = new MonteCarloTreeSearch();
91+
MonteCarloTreeSearch.Node parentNode = mcts.new Node(null, true);
92+
93+
// Create child nodes with varying scores
94+
MonteCarloTreeSearch.Node winningNode = mcts.new Node(parentNode, false);
95+
winningNode.score = 10; // Highest score
96+
parentNode.childNodes.add(winningNode);
97+
98+
MonteCarloTreeSearch.Node losingNode = mcts.new Node(parentNode, false);
99+
losingNode.score = 5;
100+
parentNode.childNodes.add(losingNode);
101+
102+
MonteCarloTreeSearch.Node anotherLosingNode = mcts.new Node(parentNode, false);
103+
anotherLosingNode.score = 3;
104+
parentNode.childNodes.add(anotherLosingNode);
105+
106+
// Get the winning node
107+
MonteCarloTreeSearch.Node winnerNode = mcts.getWinnerNode(parentNode);
108+
109+
assertEquals(winningNode, winnerNode, "Winning node should have the highest score");
110+
}
111+
112+
/**
113+
* Test the full Monte Carlo Tree Search process.
114+
*/
115+
@Test
116+
void testMonteCarloTreeSearch() {
117+
MonteCarloTreeSearch mcts = new MonteCarloTreeSearch();
118+
MonteCarloTreeSearch.Node rootNode = mcts.new Node(null, true);
119+
120+
// Execute MCTS and check the resulting node
121+
MonteCarloTreeSearch.Node optimalNode = mcts.monteCarloTreeSearch(rootNode);
122+
123+
assertNotNull(optimalNode, "MCTS should return a non-null optimal node");
124+
assertTrue(rootNode.childNodes.contains(optimalNode), "Optimal node should be a child of the root");
125+
}
126+
}

0 commit comments

Comments
 (0)