
Commit 543584c

Added tests. Updated docstrings/typehints
1 parent ead2838 commit 543584c

7 files changed: +106 -60 lines changed

data_structures/kd_tree/build_kdtree.py

+8-9

@@ -1,17 +1,16 @@
-from typing import List, Optional
+from typing import Optional
 from .kd_node import KDNode
 
-
-def build_kdtree(points: List[List[float]], depth: int = 0) -> Optional[KDNode]:
+def build_kdtree(points: list[list[float]], depth: int = 0) -> Optional[KDNode]:
     """
-    Builds a KD-Tree from a set of k-dimensional points.
+    Builds a KD-Tree from a list of points.
 
     Args:
-        points (List[List[float]]): A list of k-dimensional points (each point is a list of floats).
-        depth (int): The current depth in the tree. Used to determine the splitting axis. Defaults to 0.
+        points (list[list[float]]): The list of points to build the KD-Tree from.
+        depth (int): The current depth in the tree (used to determine axis for splitting).
 
     Returns:
-        Optional[KDNode]: The root of the KD-Tree or None if the input list is empty.
+        Optional[KDNode]: The root node of the KD-Tree.
     """
     if not points:
         return None
@@ -27,5 +26,5 @@ def build_kdtree(points: List[List[float]], depth: int = 0) -> Optional[KDNode]:
     return KDNode(
         point=points[median_idx],
         left=build_kdtree(points[:median_idx], depth + 1),
-        right=build_kdtree(points[median_idx + 1 :], depth + 1),
-    )
+        right=build_kdtree(points[median_idx + 1:], depth + 1),
+    )
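
For reference, a minimal sketch of how the updated build_kdtree signature is used; the sample points below are illustrative and not part of the commit:

from data_structures.kd_tree.build_kdtree import build_kdtree

# Five 2-D points; the tree alternates split axes with depth (x at depth 0, y at depth 1, ...).
points = [[2.0, 3.0], [5.0, 4.0], [9.0, 6.0], [4.0, 7.0], [8.0, 1.0]]
root = build_kdtree(points)
print(root.point)  # Median point along the first axis; [5.0, 4.0] with the usual median split
print(root.left is None, root.right is None)  # False False: both subtrees hold the remaining points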

data_structures/kd_tree/example/example_usage.py

+2-4
@@ -1,5 +1,4 @@
 import numpy as np
-from typing import List
 from hypercube_points import hypercube_points
 from data_structures.kd_tree.build_kdtree import build_kdtree
 from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search
@@ -11,7 +10,7 @@ def main() -> None:
     in a 10-dimensional hypercube and performing a nearest neighbor search.
     """
     num_points: int = 5000
-    cube_size: int = 10
+    cube_size: float = 10.0  # Size of the hypercube (edge length)
     num_dimensions: int = 10
 
     # Generate random points within the hypercube
@@ -20,7 +19,7 @@ def main() -> None:
 
     # Generate a random query point within the same space
     rng = np.random.default_rng()
-    query_point: List[float] = rng.random(num_dimensions).tolist()
+    query_point: list[float] = rng.random(num_dimensions).tolist()
 
     # Perform nearest neighbor search
     nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search(
@@ -33,6 +32,5 @@ def main() -> None:
     print(f"Distance: {nearest_dist:.4f}")
     print(f"Nodes visited: {nodes_visited}")
 
-
 if __name__ == "__main__":
     main()
data_structures/kd_tree/example/hypercube_points.py

+5-9

@@ -1,20 +1,16 @@
 import numpy as np
-from typing import Union
 
-
-def hypercube_points(
-    num_points: int, hypercube_size: Union[int, float], num_dimensions: int
-) -> np.ndarray:
+def hypercube_points(num_points: int, hypercube_size: float, num_dimensions: int) -> np.ndarray:
     """
     Generates random points uniformly distributed within an n-dimensional hypercube.
 
     Args:
-        num_points (int): The number of random points to generate.
-        hypercube_size (Union[int, float]): The size of the hypercube (side length).
-        num_dimensions (int): The number of dimensions of the hypercube.
+        num_points (int): Number of points to generate.
+        hypercube_size (float): Size of the hypercube.
+        num_dimensions (int): Number of dimensions of the hypercube.
 
     Returns:
-        np.ndarray: An array of shape (num_points, num_dimensions) with the generated points.
+        np.ndarray: An array of shape (num_points, num_dimensions) with generated points.
     """
     rng = np.random.default_rng()
     return hypercube_size * rng.random((num_points, num_dimensions))
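
A quick illustration of the simplified hypercube_points signature; the argument values are only for demonstration:

from data_structures.kd_tree.example.hypercube_points import hypercube_points

pts = hypercube_points(num_points=3, hypercube_size=10.0, num_dimensions=2)
print(pts.shape)  # (3, 2): one row per point, one column per dimension
print(pts.min() >= 0.0 and pts.max() < 10.0)  # True: coordinates are uniform in [0, hypercube_size)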

data_structures/kd_tree/kd_node.py

+9-15
@@ -1,29 +1,23 @@
-from typing import List, Optional
-
+from typing import Optional
 
 class KDNode:
     """
     Represents a node in a KD-Tree.
 
     Attributes:
-        point (List[float]): The k-dimensional point stored in this node.
-        left (Optional[KDNode]): The left subtree of this node.
-        right (Optional[KDNode]): The right subtree of this node.
+        point (list[float]): The point stored in this node.
+        left (Optional[KDNode]): The left child node.
+        right (Optional[KDNode]): The right child node.
     """
 
-    def __init__(
-        self,
-        point: List[float],
-        left: Optional["KDNode"] = None,
-        right: Optional["KDNode"] = None,
-    ) -> None:
+    def __init__(self, point: list[float], left: Optional['KDNode'] = None, right: Optional['KDNode'] = None) -> None:
         """
-        Initializes a KDNode with a point and optional left and right children.
+        Initializes a KDNode with the given point and child nodes.
 
         Args:
-            point (List[float]): The k-dimensional point to be stored in this node.
-            left (Optional[KDNode]): The left subtree of this node. Defaults to None.
-            right (Optional[KDNode]): The right subtree of this node. Defaults to None.
+            point (list[float]): The point stored in this node.
+            left (Optional[KDNode]): The left child node.
+            right (Optional[KDNode]): The right child node.
         """
         self.point = point
         self.left = left
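
A small sketch of composing KDNode objects with the new single-line constructor; the point values are illustrative:

from data_structures.kd_tree.kd_node import KDNode

leaf = KDNode(point=[2.0, 3.0])              # left and right default to None
root = KDNode(point=[5.0, 4.0], left=leaf)   # children are themselves KDNode instances
print(root.point, root.left.point, root.right)  # [5.0, 4.0] [2.0, 3.0] None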

data_structures/kd_tree/nearest_neighbour_search.py

+12-23
@@ -1,39 +1,31 @@
-from typing import Optional, List, Tuple
-from .kd_node import KDNode
+from typing import Optional
+from data_structures.kd_tree.kd_node import KDNode
 
-
-def nearest_neighbour_search(
-    root: Optional[KDNode], query_point: List[float]
-) -> Tuple[Optional[List[float]], float, int]:
+def nearest_neighbour_search(root: Optional[KDNode], query_point: list[float]) -> tuple[Optional[list[float]], float, int]:
     """
     Performs a nearest neighbor search in a KD-Tree for a given query point.
 
     Args:
         root (Optional[KDNode]): The root node of the KD-Tree.
-        query_point (List[float]): The point for which the nearest neighbor is being searched.
+        query_point (list[float]): The point for which the nearest neighbor is being searched.
 
     Returns:
-        Tuple[Optional[List[float]], float, int]:
+        tuple[Optional[list[float]], float, int]:
             - The nearest point found in the KD-Tree to the query point.
            - The squared distance to the nearest point.
            - The number of nodes visited during the search.
    """
-    nearest_point: Optional[List[float]] = None
+    nearest_point: Optional[list[float]] = None
     nearest_dist: float = float("inf")
     nodes_visited: int = 0
 
     def search(node: Optional[KDNode], depth: int = 0) -> None:
         """
-        Recursively searches the KD-Tree to find the nearest point to the query point.
+        Recursively searches the KD-Tree for the nearest neighbor.
 
         Args:
-            node (Optional[KDNode]): The current node being examined.
-            depth (int): The current depth of the tree, which determines the axis to split on.
-
-        Updates:
-            nearest_point: The closest point found so far in the KD-Tree.
-            nearest_dist: The squared distance from the query point to the nearest point found.
-            nodes_visited: The number of nodes visited during the search.
+            node (Optional[KDNode]): The current node in the KD-Tree.
+            depth (int): The current depth in the tree.
         """
         nonlocal nearest_point, nearest_dist, nodes_visited
         if node is None:
@@ -43,18 +35,15 @@ def search(node: Optional[KDNode], depth: int = 0) -> None:
 
         # Calculate the current distance (squared distance)
         current_point = node.point
-        current_dist = sum(
-            (query_coord - point_coord) ** 2
-            for query_coord, point_coord in zip(query_point, current_point)
-        )
+        current_dist = sum((query_coord - point_coord) ** 2 for query_coord, point_coord in zip(query_point, current_point))
 
         # Update nearest point if the current node is closer
         if nearest_point is None or current_dist < nearest_dist:
             nearest_point = current_point
             nearest_dist = current_dist
 
         # Determine which subtree to search first (based on axis and query point)
-        k = len(query_point)  # dimensionality of points
+        k = len(query_point)  # Dimensionality of points
         axis = depth % k
 
         if query_point[axis] <= current_point[axis]:
@@ -67,7 +56,7 @@ def search(node: Optional[KDNode], depth: int = 0) -> None:
         # Search the nearer subtree first
         search(nearer_subtree, depth + 1)
 
-        # If the further subtree has a closer point, search it
+        # If the further subtree has a closer point
        if (query_point[axis] - current_point[axis]) ** 2 < nearest_dist:
            search(further_subtree, depth + 1)
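
A sanity-check sketch of the updated search API on a hand-built tree; the points and query are illustrative only:

from data_structures.kd_tree.build_kdtree import build_kdtree
from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search

tree = build_kdtree([[1.0, 1.0], [2.0, 5.0], [9.0, 9.0]])
nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search(tree, [2.1, 4.8])
print(nearest_point)   # [2.0, 5.0]
print(nearest_dist)    # Squared distance, roughly 0.05 (0.1**2 + 0.2**2)
print(nodes_visited)   # At most 3 here; the axis-distance check prunes subtrees that cannot be closer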

data_structures/kd_tree/tests/__init__.py

Whitespace-only changes.
@@ -0,0 +1,70 @@
+import unittest
+import numpy as np
+from data_structures.kd_tree.build_kdtree import build_kdtree
+from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search
+from data_structures.kd_tree.kd_node import KDNode
+from data_structures.kd_tree.example.hypercube_points import hypercube_points
+
+class TestKDTree(unittest.TestCase):
+
+    def setUp(self):
+        """
+        Set up test data.
+        """
+        self.num_points = 10
+        self.cube_size = 10.0
+        self.num_dimensions = 2
+        self.points = hypercube_points(self.num_points, self.cube_size, self.num_dimensions)
+        self.kdtree = build_kdtree(self.points.tolist())
+
+    def test_build_kdtree(self):
+        """
+        Test that KD-Tree is built correctly.
+        """
+        # Check if root is not None
+        self.assertIsNotNone(self.kdtree)
+
+        # Check if root has correct dimensions
+        self.assertEqual(len(self.kdtree.point), self.num_dimensions)
+
+        # Check that the tree is balanced to some extent (simplistic check)
+        self.assertIsInstance(self.kdtree, KDNode)
+
+    def test_nearest_neighbour_search(self):
+        """
+        Test the nearest neighbor search function.
+        """
+        rng = np.random.default_rng()
+        query_point = rng.random(self.num_dimensions).tolist()
+
+        nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search(
+            self.kdtree, query_point
+        )
+
+        # Check that nearest point is not None
+        self.assertIsNotNone(nearest_point)
+
+        # Check that distance is a non-negative number
+        self.assertGreaterEqual(nearest_dist, 0)
+
+        # Check that nodes visited is a non-negative integer
+        self.assertGreaterEqual(nodes_visited, 0)
+
+    def test_edge_cases(self):
+        """
+        Test edge cases such as an empty KD-Tree.
+        """
+        empty_kdtree = build_kdtree([])
+        query_point = [0.0] * self.num_dimensions
+
+        nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search(
+            empty_kdtree, query_point
+        )
+
+        # With an empty KD-Tree, nearest_point should be None
+        self.assertIsNone(nearest_point)
+        self.assertEqual(nearest_dist, float("inf"))
+        self.assertEqual(nodes_visited, 0)
+
+if __name__ == '__main__':
+    unittest.main()
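
Assuming the new test module follows the usual test_*.py naming, the suite should be runnable from the repository root with python -m unittest discover (or with pytest, if that is the project's preferred runner).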
