|
1 | 1 | import numpy as np
|
2 |
| - |
| 2 | +from typing import List |
3 | 3 | from hypercube_points import hypercube_points
|
4 | 4 | from data_structures.kd_tree.build_kdtree import build_kdtree
|
5 | 5 | from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search
|
6 | 6 |
|
| 7 | +def main() -> None: |
| 8 | + """ |
| 9 | + Demonstrates the use of KD-Tree by building it from random points |
| 10 | + in a 10-dimensional hypercube and performing a nearest neighbor search. |
| 11 | + """ |
| 12 | + num_points: int = 5000 |
| 13 | + cube_size: int = 10 |
| 14 | + num_dimensions: int = 10 |
7 | 15 |
|
8 |
| -num_points = 5000 |
9 |
| -cube_size = 10 |
10 |
| -num_dimensions = 10 |
| 16 | + # Generate random points within the hypercube |
| 17 | + points: np.ndarray = hypercube_points(num_points, cube_size, num_dimensions) |
| 18 | + hypercube_kdtree = build_kdtree(points.tolist()) |
11 | 19 |
|
12 |
| -points = hypercube_points(num_points, cube_size, num_dimensions) |
13 |
| -hypercube_kdtree = build_kdtree(points.tolist()) |
| 20 | + # Generate a random query point within the same space |
| 21 | + rng = np.random.default_rng() |
| 22 | + query_point: List[float] = rng.random(num_dimensions).tolist() |
14 | 23 |
|
15 |
| -rng = np.random.default_rng() |
16 |
| -query_point = rng.random(num_dimensions).tolist() |
| 24 | + # Perform nearest neighbor search |
| 25 | + nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search( |
| 26 | + hypercube_kdtree, query_point |
| 27 | + ) |
17 | 28 |
|
18 |
| -nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search( |
19 |
| - hypercube_kdtree, query_point |
20 |
| -) |
| 29 | + # Print the results |
| 30 | + print(f"Query point: {query_point}") |
| 31 | + print(f"Nearest point: {nearest_point}") |
| 32 | + print(f"Distance: {nearest_dist:.4f}") |
| 33 | + print(f"Nodes visited: {nodes_visited}") |
21 | 34 |
|
22 |
| -print(f"Query point: {query_point}") |
23 |
| -print(f"Nearest point: {nearest_point}") |
24 |
| -print(f"Distance: {nearest_dist:.4f}") |
25 |
| -print(f"Nodes visited: {nodes_visited}") |
| 35 | +if __name__ == "__main__": |
| 36 | + main() |
0 commit comments