|
1 |
| -import unittest |
2 | 1 | import numpy as np
|
3 | 2 | from data_structures.kd_tree.build_kdtree import build_kdtree
|
4 | 3 | from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search
|
5 | 4 | from data_structures.kd_tree.kd_node import KDNode
|
6 | 5 | from data_structures.kd_tree.example.hypercube_points import hypercube_points
|
7 | 6 |
|
| 7 | +def test_build_kdtree(): |
| 8 | + """ |
| 9 | + Test that KD-Tree is built correctly. |
| 10 | + """ |
| 11 | + num_points = 10 |
| 12 | + cube_size = 10.0 |
| 13 | + num_dimensions = 2 |
| 14 | + points = hypercube_points(num_points, cube_size, num_dimensions) |
| 15 | + kdtree = build_kdtree(points.tolist()) |
8 | 16 |
|
9 |
| -class TestKDTree(unittest.TestCase): |
10 |
| - def setUp(self): |
11 |
| - """ |
12 |
| - Set up test data. |
13 |
| - """ |
14 |
| - self.num_points = 10 |
15 |
| - self.cube_size = 10.0 |
16 |
| - self.num_dimensions = 2 |
17 |
| - self.points = hypercube_points( |
18 |
| - self.num_points, self.cube_size, self.num_dimensions |
19 |
| - ) |
20 |
| - self.kdtree = build_kdtree(self.points.tolist()) |
| 17 | + # Check if root is not None |
| 18 | + assert kdtree is not None |
21 | 19 |
|
22 |
| - def test_build_kdtree(self): |
23 |
| - """ |
24 |
| - Test that KD-Tree is built correctly. |
25 |
| - """ |
26 |
| - # Check if root is not None |
27 |
| - self.assertIsNotNone(self.kdtree) |
| 20 | + # Check if root has correct dimensions |
| 21 | + assert len(kdtree.point) == num_dimensions |
28 | 22 |
|
29 |
| - # Check if root has correct dimensions |
30 |
| - self.assertEqual(len(self.kdtree.point), self.num_dimensions) |
| 23 | + # Check that the tree is balanced to some extent (simplistic check) |
| 24 | + assert isinstance(kdtree, KDNode) |
31 | 25 |
|
32 |
| - # Check that the tree is balanced to some extent (simplistic check) |
33 |
| - self.assertIsInstance(self.kdtree, KDNode) |
| 26 | +def test_nearest_neighbour_search(): |
| 27 | + """ |
| 28 | + Test the nearest neighbor search function. |
| 29 | + """ |
| 30 | + num_points = 10 |
| 31 | + cube_size = 10.0 |
| 32 | + num_dimensions = 2 |
| 33 | + points = hypercube_points(num_points, cube_size, num_dimensions) |
| 34 | + kdtree = build_kdtree(points.tolist()) |
34 | 35 |
|
35 |
| - def test_nearest_neighbour_search(self): |
36 |
| - """ |
37 |
| - Test the nearest neighbor search function. |
38 |
| - """ |
39 |
| - rng = np.random.default_rng() |
40 |
| - query_point = rng.random(self.num_dimensions).tolist() |
| 36 | + rng = np.random.default_rng() |
| 37 | + query_point = rng.random(num_dimensions).tolist() |
41 | 38 |
|
42 |
| - nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search( |
43 |
| - self.kdtree, query_point |
44 |
| - ) |
| 39 | + nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search( |
| 40 | + kdtree, query_point |
| 41 | + ) |
45 | 42 |
|
46 |
| - # Check that nearest point is not None |
47 |
| - self.assertIsNotNone(nearest_point) |
| 43 | + # Check that nearest point is not None |
| 44 | + assert nearest_point is not None |
48 | 45 |
|
49 |
| - # Check that distance is a non-negative number |
50 |
| - self.assertGreaterEqual(nearest_dist, 0) |
| 46 | + # Check that distance is a non-negative number |
| 47 | + assert nearest_dist >= 0 |
51 | 48 |
|
52 |
| - # Check that nodes visited is a non-negative integer |
53 |
| - self.assertGreaterEqual(nodes_visited, 0) |
| 49 | + # Check that nodes visited is a non-negative integer |
| 50 | + assert nodes_visited >= 0 |
54 | 51 |
|
55 |
| - def test_edge_cases(self): |
56 |
| - """ |
57 |
| - Test edge cases such as an empty KD-Tree. |
58 |
| - """ |
59 |
| - empty_kdtree = build_kdtree([]) |
60 |
| - query_point = [0.0] * self.num_dimensions |
| 52 | +def test_edge_cases(): |
| 53 | + """ |
| 54 | + Test edge cases such as an empty KD-Tree. |
| 55 | + """ |
| 56 | + empty_kdtree = build_kdtree([]) |
| 57 | + query_point = [0.0] * 2 # Using a default 2D query point |
61 | 58 |
|
62 |
| - nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search( |
63 |
| - empty_kdtree, query_point |
64 |
| - ) |
65 |
| - |
66 |
| - # With an empty KD-Tree, nearest_point should be None |
67 |
| - self.assertIsNone(nearest_point) |
68 |
| - self.assertEqual(nearest_dist, float("inf")) |
69 |
| - self.assertEqual(nodes_visited, 0) |
| 59 | + nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search( |
| 60 | + empty_kdtree, query_point |
| 61 | + ) |
70 | 62 |
|
| 63 | + # With an empty KD-Tree, nearest_point should be None |
| 64 | + assert nearest_point is None |
| 65 | + assert nearest_dist == float("inf") |
| 66 | + assert nodes_visited == 0 |
71 | 67 |
|
72 | 68 | if __name__ == "__main__":
|
73 |
| - unittest.main() |
| 69 | + import pytest |
| 70 | + pytest.main() |
0 commit comments