Skip to content

Commit 3e77e92

Browse files
committed
ref: refactor SplayTree implementation
- Implement 3 traversal class of the `TreeTraversal` interface - Addjust tests to better cover code.
1 parent 6b3816b commit 3e77e92

File tree

2 files changed

+81
-103
lines changed

2 files changed

+81
-103
lines changed

src/main/java/com/thealgorithms/datastructures/trees/SplayTree.java

Lines changed: 60 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
public class SplayTree {
2929

3030
private static class Node {
31-
int key;
31+
final int key;
3232
Node left;
3333
Node right;
3434

@@ -41,14 +41,11 @@ private static class Node {
4141

4242
private Node root;
4343

44-
/**
45-
* Constructs an empty SplayTree.
46-
*/
4744
public SplayTree() {
4845
root = null;
4946
}
5047

51-
/**
48+
/**
5249
* Checks if the tree is empty.
5350
*
5451
* @return True if the tree is empty, otherwise false.
@@ -135,29 +132,26 @@ private Node splay(Node root, int key) {
135132
root = rotateRight(root);
136133
} // Zig-Zag case
137134
else if (root.left.key < key) {
138-
// Recursive call to splay on grandchild
139135
root.left.right = splay(root.left.right, key);
140136
// Perform zag operation on parent
141-
if (root.left.right != null) root.left = rotateLeft(root.left);
137+
if (root.left.right != null) {
138+
root.left = rotateLeft(root.left);
139+
}
142140
}
143-
144141
return (root.left == null) ? root : rotateRight(root);
145142
} else {
146143
if (root.right == null) return root;
147144
// Zag-Zag case
148145
if (root.right.key > key) {
149-
// Recursive call to splay on grandchild
150146
root.right.left = splay(root.right.left, key);
151147
// Perform zig operation on parent
152148
if (root.right.left != null) root.right = rotateRight(root.right);
153149
} // Zag-Zig case
154150
else if (root.right.key < key) {
155-
// Recursive call to splay on grandchild
156151
root.right.right = splay(root.right.right, key);
157152
// Perform zag operation on parent
158153
root = rotateLeft(root);
159154
}
160-
161155
return (root.right == null) ? root : rotateLeft(root);
162156
}
163157
}
@@ -228,8 +222,8 @@ public void delete(int key) {
228222
if (root.left == null) {
229223
root = root.right;
230224
} else {
231-
Node temp = root;
232225
// Splay to bring the largest key in left subtree to root
226+
Node temp = root;
233227
root = splay(root.left, key);
234228
root.right = temp.right;
235229
}
@@ -238,71 +232,73 @@ public void delete(int key) {
238232
/**
239233
* Perform a traversal of the SplayTree.
240234
*
241-
* @param traverseOrder The order of traversal (IN_ORDER, PRE_ORDER, or POST_ORDER).
235+
* @param traversal The type of traversal method.
242236
* @return A list containing the keys in the specified traversal order.
243237
*/
244-
public List<Integer> traverse(TraverseOrder traverseOrder) {
238+
public List<Integer> traverse(TreeTraversal traversal) {
245239
List<Integer> result = new LinkedList<>();
246-
switch (traverseOrder) {
247-
case IN_ORDER:
248-
inOrderRec(root, result);
249-
break;
250-
case PRE_ORDER:
251-
preOrderRec(root, result);
252-
break;
253-
case POST_ORDER:
254-
postOrderRec(root, result);
255-
break;
256-
default:
257-
throw new IllegalArgumentException("Invalid traversal order: " + traverseOrder);
258-
}
240+
traversal.traverse(root, result);
259241
return result;
260242
}
261243

262-
/**
263-
* Recursive function for in-order traversal.
264-
*
265-
* @param root The root of the subtree to traverse.
266-
* @param result The list to store the traversal result.
267-
*/
268-
private void inOrderRec(Node root, List<Integer> result) {
269-
if (root != null) {
270-
inOrderRec(root.left, result);
271-
result.add(root.key);
272-
inOrderRec(root.right, result);
244+
public interface TreeTraversal {
245+
/**
246+
* Recursive function for a specific order traversal.
247+
*
248+
* @param root The root of the subtree to traverse.
249+
* @param result The list to store the traversal result.
250+
*/
251+
void traverse(Node root, List<Integer> result);
252+
}
253+
254+
private static final class InOrderTraversal implements TreeTraversal {
255+
private InOrderTraversal() {
256+
}
257+
258+
public void traverse(Node root, List<Integer> result) {
259+
if (root != null) {
260+
traverse(root.left, result);
261+
result.add(root.key);
262+
traverse(root.right, result);
263+
}
273264
}
274265
}
275266

276-
/**
277-
* Recursive function for pre-order traversal.
278-
*
279-
* @param root The root of the subtree to traverse.
280-
* @param result The list to store the traversal result.
281-
*/
282-
private void preOrderRec(Node root, List<Integer> result) {
283-
if (root != null) {
284-
result.add(root.key);
285-
preOrderRec(root.left, result);
286-
preOrderRec(root.right, result);
267+
private static final class PreOrderTraversal implements TreeTraversal {
268+
private PreOrderTraversal() {
269+
}
270+
271+
public void traverse(Node root, List<Integer> result) {
272+
if (root != null) {
273+
result.add(root.key);
274+
traverse(root.left, result);
275+
traverse(root.right, result);
276+
}
287277
}
288278
}
289279

290-
/**
291-
* Recursive function for post-order traversal.
292-
*
293-
* @param root The root of the subtree to traverse.
294-
* @param result The list to store the traversal result.
295-
*/
296-
private void postOrderRec(Node root, List<Integer> result) {
297-
if (root != null) {
298-
postOrderRec(root.left, result);
299-
postOrderRec(root.right, result);
300-
result.add(root.key);
280+
private static final class PostOrderTraversal implements TreeTraversal {
281+
private PostOrderTraversal() {
282+
}
283+
284+
public void traverse(Node root, List<Integer> result) {
285+
if (root != null) {
286+
traverse(root.left, result);
287+
traverse(root.right, result);
288+
result.add(root.key);
289+
}
301290
}
302291
}
303292

304-
/**
305-
* Enum to specify the order of traversal.
306-
*/
307-
public enum TraverseOrder { IN_ORDER, PRE_ORDER, POST_ORDER, INVALID }
293+
public static TreeTraversal getInOrderTraversal() {
294+
return new InOrderTraversal();
295+
}
296+
297+
public static TreeTraversal getPreOrderTraversal() {
298+
return new PreOrderTraversal();
299+
}
300+
301+
public static TreeTraversal getPostOrderTraversal() {
302+
return new PostOrderTraversal();
303+
}
308304
}

src/test/java/com/thealgorithms/datastructures/trees/SplayTreeTest.java

Lines changed: 21 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,18 @@
66
import static org.junit.jupiter.api.Assertions.assertTrue;
77

88
import java.util.Arrays;
9-
import java.util.LinkedList;
109
import java.util.List;
1110
import java.util.stream.Stream;
12-
import org.junit.jupiter.api.Test;
1311
import org.junit.jupiter.params.ParameterizedTest;
1412
import org.junit.jupiter.params.provider.MethodSource;
1513

1614
public class SplayTreeTest {
15+
1716
@ParameterizedTest
18-
@MethodSource("traversalOrders")
19-
public void testTraversal(SplayTree.TraverseOrder traverseOrder) {
17+
@MethodSource("traversalStrategies")
18+
public void testTraversal(SplayTree.TreeTraversal traversal, List<Integer> expected) {
2019
SplayTree tree = createComplexTree();
21-
List<Integer> expected = getExpectedTraversalResult(traverseOrder);
22-
List<Integer> result = tree.traverse(traverseOrder);
23-
20+
List<Integer> result = tree.traverse(traversal);
2421
assertEquals(expected, result);
2522
}
2623

@@ -35,7 +32,6 @@ public void testSearch(int value) {
3532
@MethodSource("valuesToTest")
3633
public void testDelete(int value) {
3734
SplayTree tree = createComplexTree();
38-
3935
assertTrue(tree.search(value));
4036
tree.delete(value);
4137
assertFalse(tree.search(value));
@@ -52,7 +48,6 @@ public void testSearchNonExistent(int value) {
5248
@MethodSource("nonExistentValues")
5349
public void testDeleteNonExistent(int value) {
5450
SplayTree tree = createComplexTree();
55-
5651
tree.delete(value);
5752
assertFalse(tree.search(value));
5853
}
@@ -71,55 +66,42 @@ public void testInsertThrowsExceptionForDuplicateKeys(int value) {
7166
assertThrows(IllegalArgumentException.class, () -> tree.insert(value));
7267
}
7368

74-
@Test
75-
public void testInvalidTraversalOrderExceptionMessage() {
76-
SplayTree tree = createComplexTree();
77-
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> tree.traverse(SplayTree.TraverseOrder.INVALID));
78-
assertEquals("Invalid traversal order: INVALID", exception.getMessage());
69+
@ParameterizedTest
70+
@MethodSource("valuesToTest")
71+
public void testSearchInEmptyTree(int value) {
72+
SplayTree tree = new SplayTree();
73+
assertFalse(tree.search(value));
7974
}
8075

81-
private static Stream<SplayTree.TraverseOrder> traversalOrders() {
82-
return Stream.of(SplayTree.TraverseOrder.IN_ORDER, SplayTree.TraverseOrder.PRE_ORDER, SplayTree.TraverseOrder.POST_ORDER);
76+
private static Stream<Object[]> traversalStrategies() {
77+
return Stream.of(new Object[] {SplayTree.getInOrderTraversal(), Arrays.asList(1, 2, 5, 6, 7, 8, 10, 11, 12, 13, 15, 16, 17, 18)}, new Object[] {SplayTree.getPreOrderTraversal(), Arrays.asList(18, 17, 16, 15, 13, 11, 10, 8, 7, 6, 2, 1, 5, 12)},
78+
new Object[] {SplayTree.getPostOrderTraversal(), Arrays.asList(1, 5, 2, 6, 7, 8, 10, 12, 11, 13, 15, 16, 17, 18)});
8379
}
8480

8581
private static Stream<Integer> valuesToTest() {
8682
return Stream.of(1, 5, 10);
8783
}
8884

8985
private static Stream<Integer> nonExistentValues() {
90-
return Stream.of(0, 11, 15);
91-
}
92-
93-
private List<Integer> getExpectedTraversalResult(SplayTree.TraverseOrder traverseOrder) {
94-
List<Integer> expected = new LinkedList<>();
95-
switch (traverseOrder) {
96-
case IN_ORDER:
97-
expected.addAll(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
98-
break;
99-
case PRE_ORDER:
100-
expected.addAll(Arrays.asList(10, 9, 8, 7, 3, 1, 2, 5, 4, 6));
101-
break;
102-
case POST_ORDER:
103-
expected.addAll(Arrays.asList(2, 1, 4, 6, 5, 3, 7, 8, 9, 10));
104-
break;
105-
default:
106-
throw new IllegalArgumentException("Invalid traversal order: " + traverseOrder);
107-
}
108-
return expected;
86+
return Stream.of(0, 21, 20);
10987
}
11088

11189
private SplayTree createComplexTree() {
11290
SplayTree tree = new SplayTree();
91+
tree.insert(10);
11392
tree.insert(5);
93+
tree.insert(15);
11494
tree.insert(2);
11595
tree.insert(7);
11696
tree.insert(1);
117-
tree.insert(4);
11897
tree.insert(6);
119-
tree.insert(9);
120-
tree.insert(3);
12198
tree.insert(8);
122-
tree.insert(10);
99+
tree.insert(12);
100+
tree.insert(17);
101+
tree.insert(11);
102+
tree.insert(13);
103+
tree.insert(16);
104+
tree.insert(18);
123105
return tree;
124106
}
125107
}

0 commit comments

Comments
 (0)