Skip to content

Commit feac90b

Browse files
committed
Add tests, remove main in MonteCarloTreeSearch
1 parent b54cc21 commit feac90b

File tree

2 files changed

+127
-6
lines changed

2 files changed

+127
-6
lines changed

src/main/java/com/thealgorithms/searches/MonteCarloTreeSearch.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,6 @@ public Node(Node parent, boolean isPlayersTurn) {
3939
static final int WIN_SCORE = 10;
4040
static final int TIME_LIMIT = 500; // Time the algorithm will be running for (in milliseconds).
4141

42-
public static void main(String[] args) {
43-
MonteCarloTreeSearch mcts = new MonteCarloTreeSearch();
44-
45-
mcts.monteCarloTreeSearch(mcts.new Node(null, true));
46-
}
47-
4842
/**
4943
* Explores a game tree using Monte Carlo Tree Search (MCTS) and returns the
5044
* most promising node.
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
package com.thealgorithms.searches;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertNotNull;
5+
import static org.junit.jupiter.api.Assertions.assertNull;
6+
import static org.junit.jupiter.api.Assertions.assertTrue;
7+
8+
import org.junit.jupiter.api.Test;
9+
10+
class MonteCarloTreeSearchTest {
11+
12+
/**
13+
* Test the creation of a node and its initial state.
14+
*/
15+
@Test
16+
void testNodeCreation() {
17+
MonteCarloTreeSearch.Node node = new MonteCarloTreeSearch().new Node(null, true);
18+
assertNotNull(node, "Node should be created");
19+
assertNull(node.parent, "Parent should be null for root node");
20+
assertTrue(node.childNodes.isEmpty(), "Child nodes should be empty upon creation");
21+
assertTrue(node.isPlayersTurn, "Initial turn should be player's turn");
22+
assertEquals(0, node.score, "Initial score should be zero");
23+
assertEquals(0, node.visitCount, "Initial visit count should be zero");
24+
}
25+
26+
/**
27+
* Test adding child nodes to a parent node.
28+
*/
29+
@Test
30+
void testAddChildNodes() {
31+
MonteCarloTreeSearch mcts = new MonteCarloTreeSearch();
32+
MonteCarloTreeSearch.Node parentNode = mcts.new Node(null, true);
33+
34+
mcts.addChildNodes(parentNode, 5);
35+
36+
assertEquals(5, parentNode.childNodes.size(), "Parent should have 5 child nodes");
37+
for (MonteCarloTreeSearch.Node child : parentNode.childNodes) {
38+
assertEquals(false, child.isPlayersTurn, "Child node should not be player's turn");
39+
assertEquals(0, child.visitCount, "Child node visit count should be zero");
40+
}
41+
}
42+
43+
/**
44+
* Test the UCT selection of a promising node.
45+
*/
46+
@Test
47+
void testGetPromisingNode() {
48+
MonteCarloTreeSearch mcts = new MonteCarloTreeSearch();
49+
MonteCarloTreeSearch.Node parentNode = mcts.new Node(null, true);
50+
51+
// Create child nodes with different visit counts and scores
52+
for (int i = 0; i < 3; i++) {
53+
MonteCarloTreeSearch.Node child = mcts.new Node(parentNode, false);
54+
child.visitCount = i + 1; // Visits 1, 2, 3
55+
child.score = i * 2; // Scores 0, 2, 4
56+
parentNode.childNodes.add(child);
57+
}
58+
59+
// Get promising node
60+
MonteCarloTreeSearch.Node promisingNode = mcts.getPromisingNode(parentNode);
61+
62+
// The child with the highest UCT value should be chosen.
63+
assertNotNull(promisingNode, "Promising node should not be null");
64+
assertEquals(0, parentNode.childNodes.indexOf(promisingNode), "The first child should be the most promising");
65+
}
66+
67+
/**
68+
* Test simulation of random play and backpropagation.
69+
*/
70+
@Test
71+
void testSimulateRandomPlay() {
72+
MonteCarloTreeSearch mcts = new MonteCarloTreeSearch();
73+
MonteCarloTreeSearch.Node node = mcts.new Node(null, true);
74+
node.visitCount = 10; // Simulating existing visits
75+
76+
// Simulate random play
77+
mcts.simulateRandomPlay(node);
78+
79+
// Check visit count after simulation
80+
assertEquals(11, node.visitCount, "Visit count should increase after simulation");
81+
82+
// Check if score is updated correctly
83+
assertTrue(node.score >= 0 && node.score <= MonteCarloTreeSearch.WIN_SCORE, "Score should be between 0 and WIN_SCORE");
84+
}
85+
86+
/**
87+
* Test retrieving the winning node based on scores.
88+
*/
89+
@Test
90+
void testGetWinnerNode() {
91+
MonteCarloTreeSearch mcts = new MonteCarloTreeSearch();
92+
MonteCarloTreeSearch.Node parentNode = mcts.new Node(null, true);
93+
94+
// Create child nodes with varying scores
95+
MonteCarloTreeSearch.Node winningNode = mcts.new Node(parentNode, false);
96+
winningNode.score = 10; // Highest score
97+
parentNode.childNodes.add(winningNode);
98+
99+
MonteCarloTreeSearch.Node losingNode = mcts.new Node(parentNode, false);
100+
losingNode.score = 5;
101+
parentNode.childNodes.add(losingNode);
102+
103+
MonteCarloTreeSearch.Node anotherLosingNode = mcts.new Node(parentNode, false);
104+
anotherLosingNode.score = 3;
105+
parentNode.childNodes.add(anotherLosingNode);
106+
107+
// Get the winning node
108+
MonteCarloTreeSearch.Node winnerNode = mcts.getWinnerNode(parentNode);
109+
110+
assertEquals(winningNode, winnerNode, "Winning node should have the highest score");
111+
}
112+
113+
/**
114+
* Test the full Monte Carlo Tree Search process.
115+
*/
116+
@Test
117+
void testMonteCarloTreeSearch() {
118+
MonteCarloTreeSearch mcts = new MonteCarloTreeSearch();
119+
MonteCarloTreeSearch.Node rootNode = mcts.new Node(null, true);
120+
121+
// Execute MCTS and check the resulting node
122+
MonteCarloTreeSearch.Node optimalNode = mcts.monteCarloTreeSearch(rootNode);
123+
124+
assertNotNull(optimalNode, "MCTS should return a non-null optimal node");
125+
assertTrue(rootNode.childNodes.contains(optimalNode), "Optimal node should be a child of the root");
126+
}
127+
}

0 commit comments

Comments
 (0)