Skip to content

Commit 81d6917

Browse files
added typehints and docstrings
1 parent 1668d73 commit 81d6917

File tree

5 files changed

+95
-28
lines changed

5 files changed

+95
-28
lines changed
+14-4
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,30 @@
1+
from typing import List, Optional
12
from .kd_node import KDNode
23

4+
def build_kdtree(points: List[List[float]], depth: int = 0) -> Optional[KDNode]:
5+
"""
6+
Builds a KD-Tree from a set of k-dimensional points.
37
4-
def build_kdtree(points, depth=0):
8+
Args:
9+
points (List[List[float]]): A list of k-dimensional points (each point is a list of floats).
10+
depth (int): The current depth in the tree. Used to determine the splitting axis. Defaults to 0.
11+
12+
Returns:
13+
Optional[KDNode]: The root of the KD-Tree or None if the input list is empty.
14+
"""
515
if not points:
616
return None
717

8-
k = len(points[0]) # dimensionality of the points
18+
k = len(points[0]) # Dimensionality of the points
919
axis = depth % k
1020

1121
# Sort the point list in place along the current axis (this reorders the list passed in) and choose the median as the pivot element
12-
points.sort(key=lambda x: x[axis])
22+
points.sort(key=lambda point: point[axis])
1323
median_idx = len(points) // 2
1424

1525
# Create node and construct subtrees
1626
return KDNode(
1727
point=points[median_idx],
1828
left=build_kdtree(points[:median_idx], depth + 1),
19-
right=build_kdtree(points[median_idx + 1 :], depth + 1),
29+
right=build_kdtree(points[median_idx + 1:], depth + 1),
2030
)
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,36 @@
11
import numpy as np
2-
2+
from typing import List
33
from hypercube_points import hypercube_points
44
from data_structures.kd_tree.build_kdtree import build_kdtree
55
from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search
66

7+
def main() -> None:
8+
"""
9+
Demonstrates the use of a KD-Tree by building one from random points
10+
in a 10-dimensional hypercube and performing a nearest neighbor search.
11+
"""
12+
num_points: int = 5000
13+
cube_size: int = 10
14+
num_dimensions: int = 10
715

8-
num_points = 5000
9-
cube_size = 10
10-
num_dimensions = 10
16+
# Generate random points within the hypercube
17+
points: np.ndarray = hypercube_points(num_points, cube_size, num_dimensions)
18+
hypercube_kdtree = build_kdtree(points.tolist())
1119

12-
points = hypercube_points(num_points, cube_size, num_dimensions)
13-
hypercube_kdtree = build_kdtree(points.tolist())
20+
# Generate a random query point in the unit cube [0, 1)^d (note: not scaled by cube_size, so it lies in one corner of the hypercube)
21+
rng = np.random.default_rng()
22+
query_point: List[float] = rng.random(num_dimensions).tolist()
1423

15-
rng = np.random.default_rng()
16-
query_point = rng.random(num_dimensions).tolist()
24+
# Perform nearest neighbor search
25+
nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search(
26+
hypercube_kdtree, query_point
27+
)
1728

18-
nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search(
19-
hypercube_kdtree, query_point
20-
)
29+
# Print the results
30+
print(f"Query point: {query_point}")
31+
print(f"Nearest point: {nearest_point}")
32+
print(f"Distance: {nearest_dist:.4f}")
33+
print(f"Nodes visited: {nodes_visited}")
2134

22-
print(f"Query point: {query_point}")
23-
print(f"Nearest point: {nearest_point}")
24-
print(f"Distance: {nearest_dist:.4f}")
25-
print(f"Nodes visited: {nodes_visited}")
35+
if __name__ == "__main__":
36+
main()
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
11
import numpy as np
2+
from typing import Union
23

4+
def hypercube_points(num_points: int, hypercube_size: Union[int, float], num_dimensions: int) -> np.ndarray:
5+
"""
6+
Generates random points uniformly distributed within an n-dimensional hypercube.
37
4-
def hypercube_points(num_points, hypercube_size, num_dimensions):
8+
Args:
9+
num_points (int): The number of random points to generate.
10+
hypercube_size (Union[int, float]): The size of the hypercube (side length).
11+
num_dimensions (int): The number of dimensions of the hypercube.
12+
13+
Returns:
14+
np.ndarray: An array of shape (num_points, num_dimensions) with the generated points.
15+
"""
516
rng = np.random.default_rng()
617
return hypercube_size * rng.random((num_points, num_dimensions))

data_structures/kd_tree/kd_node.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,24 @@
1+
from typing import List, Optional
2+
13
class KDNode:
2-
def __init__(self, point, left=None, right=None):
4+
"""
5+
Represents a node in a KD-Tree.
6+
7+
Attributes:
8+
point (List[float]): The k-dimensional point stored in this node.
9+
left (Optional[KDNode]): The left subtree of this node.
10+
right (Optional[KDNode]): The right subtree of this node.
11+
"""
12+
13+
def __init__(self, point: List[float], left: Optional['KDNode'] = None, right: Optional['KDNode'] = None) -> None:
14+
"""
15+
Initializes a KDNode with a point and optional left and right children.
16+
17+
Args:
18+
point (List[float]): The k-dimensional point to be stored in this node.
19+
left (Optional[KDNode]): The left subtree of this node. Defaults to None.
20+
right (Optional[KDNode]): The right subtree of this node. Defaults to None.
21+
"""
322
self.point = point
423
self.left = left
524
self.right = right

data_structures/kd_tree/nearest_neighbour_search.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,25 @@
1-
def nearest_neighbour_search(root, query_point):
2-
nearest_point = None
3-
nearest_dist = float("inf")
4-
nodes_visited = 0
5-
6-
def search(node, depth=0):
1+
from typing import Optional, List, Tuple
2+
from .kd_node import KDNode
3+
4+
def nearest_neighbour_search(root: Optional[KDNode], query_point: List[float]) -> Tuple[Optional[List[float]], float, int]:
5+
"""
6+
Performs a nearest neighbor search in a KD-Tree for a given query point.
7+
8+
Args:
9+
root (Optional[KDNode]): The root node of the KD-Tree.
10+
query_point (List[float]): The point for which the nearest neighbor is being searched.
11+
12+
Returns:
13+
Tuple[Optional[List[float]], float, int]:
14+
- The nearest point found in the KD-Tree to the query point.
15+
- The squared distance to the nearest point.
16+
- The number of nodes visited during the search.
17+
"""
18+
nearest_point: Optional[List[float]] = None
19+
nearest_dist: float = float("inf")
20+
nodes_visited: int = 0
21+
22+
def search(node: Optional[KDNode], depth: int = 0) -> None:
723
nonlocal nearest_point, nearest_dist, nodes_visited
824
if node is None:
925
return
@@ -12,7 +28,7 @@ def search(node, depth=0):
1228

1329
# Calculate the squared Euclidean distance between the query point and the current node's point
1430
current_point = node.point
15-
current_dist = sum((qp - cp) ** 2 for qp, cp in zip(query_point, current_point))
31+
current_dist = sum((query_coord - point_coord) ** 2 for query_coord, point_coord in zip(query_point, current_point))
1632

1733
# Update nearest point if the current node is closer
1834
if nearest_point is None or current_dist < nearest_dist:

0 commit comments

Comments
 (0)