1
1
from typing import Optional , List , Tuple
2
2
from .kd_node import KDNode
3
3
4
- def nearest_neighbour_search (root : Optional [KDNode ], query_point : List [float ]) -> Tuple [Optional [List [float ]], float , int ]:
4
+ def nearest_neighbour_search (
5
+ root : Optional [KDNode ],
6
+ query_point : List [float ]
7
+ ) -> Tuple [Optional [List [float ]], float , int ]:
5
8
"""
6
9
Performs a nearest neighbor search in a KD-Tree for a given query point.
7
10
@@ -20,6 +23,18 @@ def nearest_neighbour_search(root: Optional[KDNode], query_point: List[float]) -
20
23
nodes_visited : int = 0
21
24
22
25
def search (node : Optional [KDNode ], depth : int = 0 ) -> None :
26
+ """
27
+ Recursively searches the KD-Tree to find the nearest point to the query point.
28
+
29
+ Args:
30
+ node (Optional[KDNode]): The current node being examined.
31
+ depth (int): The current depth of the tree, which determines the axis to split on.
32
+
33
+ Updates:
34
+ nearest_point: The closest point found so far in the KD-Tree.
35
+ nearest_dist: The squared distance from the query point to the nearest point found.
36
+ nodes_visited: The number of nodes visited during the search.
37
+ """
23
38
nonlocal nearest_point , nearest_dist , nodes_visited
24
39
if node is None :
25
40
return
@@ -28,7 +43,9 @@ def search(node: Optional[KDNode], depth: int = 0) -> None:
28
43
29
44
# Calculate the current distance (squared distance)
30
45
current_point = node .point
31
- current_dist = sum ((query_coord - point_coord ) ** 2 for query_coord , point_coord in zip (query_point , current_point ))
46
+ current_dist = sum (
47
+ (query_coord - point_coord ) ** 2 for query_coord , point_coord in zip (query_point , current_point )
48
+ )
32
49
33
50
# Update nearest point if the current node is closer
34
51
if nearest_point is None or current_dist < nearest_dist :
@@ -49,7 +66,7 @@ def search(node: Optional[KDNode], depth: int = 0) -> None:
49
66
# Search the nearer subtree first
50
67
search (nearer_subtree , depth + 1 )
51
68
52
- # If the further subtree has a closer point
69
+ # If the further subtree has a closer point, search it
53
70
if (query_point [axis ] - current_point [axis ]) ** 2 < nearest_dist :
54
71
search (further_subtree , depth + 1 )
55
72
0 commit comments