Skip to content

Commit 1322921

Browse files
updated tests and used | for type annotations
1 parent ad31f83 commit 1322921

File tree

6 files changed

+78
-77
lines changed

6 files changed

+78
-77
lines changed
+10-7
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from typing import Optional
2-
from .kd_node import KDNode
1+
from data_structures.kd_tree.kd_node import KDNode
32

4-
5-
def build_kdtree(points: list[list[float]], depth: int = 0) -> Optional[KDNode]:
3+
def build_kdtree(
4+
points: list[list[float]], depth: int = 0
5+
) -> KDNode | None:
66
"""
77
Builds a KD-Tree from a list of points.
88
@@ -11,7 +11,7 @@ def build_kdtree(points: list[list[float]], depth: int = 0) -> Optional[KDNode]:
1111
depth (int): The current depth in the tree (used to determine axis for splitting).
1212
1313
Returns:
14-
Optional[KDNode]: The root node of the KD-Tree.
14+
KDNode | None: The root node of the KD-Tree, or None if no points are provided.
1515
"""
1616
if not points:
1717
return None
@@ -24,8 +24,11 @@ def build_kdtree(points: list[list[float]], depth: int = 0) -> Optional[KDNode]:
2424
median_idx = len(points) // 2
2525

2626
# Create node and construct subtrees
27+
left_points = points[:median_idx]
28+
right_points = points[median_idx + 1:]
29+
2730
return KDNode(
2831
point=points[median_idx],
29-
left=build_kdtree(points[:median_idx], depth + 1),
30-
right=build_kdtree(points[median_idx + 1 :], depth + 1),
32+
left=build_kdtree(left_points, depth + 1),
33+
right=build_kdtree(right_points, depth + 1),
3134
)

data_structures/kd_tree/example/example_usage.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from hypercube_points import hypercube_points
2+
from data_structures.kd_tree.example.hypercube_points import hypercube_points
33
from data_structures.kd_tree.build_kdtree import build_kdtree
44
from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search
55

data_structures/kd_tree/example/hypercube_points.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import numpy as np
22

3-
43
def hypercube_points(
5-
num_points: int, hypercube_size: float, num_dimensions: int
4+
num_points: int, hypercube_size: float, num_dimensions: int
65
) -> np.ndarray:
76
"""
87
Generates random points uniformly distributed within an n-dimensional hypercube.
@@ -16,4 +15,5 @@ def hypercube_points(
1615
np.ndarray: An array of shape (num_points, num_dimensions) with generated points.
1716
"""
1817
rng = np.random.default_rng()
19-
return hypercube_size * rng.random((num_points, num_dimensions))
18+
shape = (num_points, num_dimensions)
19+
return hypercube_size * rng.random(shape)

data_structures/kd_tree/kd_node.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import Optional
22

3-
43
class KDNode:
54
"""
65
Represents a node in a KD-Tree.

data_structures/kd_tree/nearest_neighbour_search.py

+14-12
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,36 @@
1-
from typing import Optional
21
from data_structures.kd_tree.kd_node import KDNode
32

4-
53
def nearest_neighbour_search(
6-
root: Optional[KDNode], query_point: list[float]
7-
) -> tuple[Optional[list[float]], float, int]:
4+
root: KDNode | None,
5+
query_point: list[float]
6+
) -> tuple[list[float] | None, float, int]:
87
"""
98
Performs a nearest neighbor search in a KD-Tree for a given query point.
109
1110
Args:
12-
root (Optional[KDNode]): The root node of the KD-Tree.
11+
root (KDNode | None): The root node of the KD-Tree.
1312
query_point (list[float]): The point for which the nearest neighbor is being searched.
1413
1514
Returns:
16-
tuple[Optional[list[float]], float, int]:
17-
- The nearest point found in the KD-Tree to the query point.
15+
tuple[list[float] | None, float, int]:
16+
- The nearest point found in the KD-Tree to the query point, or None if no point is found.
1817
- The squared distance to the nearest point.
1918
- The number of nodes visited during the search.
2019
"""
21-
nearest_point: Optional[list[float]] = None
20+
nearest_point: list[float] | None = None
2221
nearest_dist: float = float("inf")
2322
nodes_visited: int = 0
2423

25-
def search(node: Optional[KDNode], depth: int = 0) -> None:
24+
def search(
25+
node: KDNode | None,
26+
depth: int = 0
27+
) -> None:
2628
"""
27-
Recursively searches the KD-Tree for the nearest neighbor.
29+
Recursively searches for the nearest neighbor in the KD-Tree.
2830
2931
Args:
30-
node (Optional[KDNode]): The current node in the KD-Tree.
31-
depth (int): The current depth in the tree.
32+
node (KDNode | None): The current node in the KD-Tree.
33+
depth (int): The current depth in the KD-Tree.
3234
"""
3335
nonlocal nearest_point, nearest_dist, nodes_visited
3436
if node is None:
+50-53
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,70 @@
1-
import unittest
21
import numpy as np
32
from data_structures.kd_tree.build_kdtree import build_kdtree
43
from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search
54
from data_structures.kd_tree.kd_node import KDNode
65
from data_structures.kd_tree.example.hypercube_points import hypercube_points
76

7+
def test_build_kdtree():
8+
"""
9+
Test that KD-Tree is built correctly.
10+
"""
11+
num_points = 10
12+
cube_size = 10.0
13+
num_dimensions = 2
14+
points = hypercube_points(num_points, cube_size, num_dimensions)
15+
kdtree = build_kdtree(points.tolist())
816

9-
class TestKDTree(unittest.TestCase):
10-
def setUp(self):
11-
"""
12-
Set up test data.
13-
"""
14-
self.num_points = 10
15-
self.cube_size = 10.0
16-
self.num_dimensions = 2
17-
self.points = hypercube_points(
18-
self.num_points, self.cube_size, self.num_dimensions
19-
)
20-
self.kdtree = build_kdtree(self.points.tolist())
17+
# Check if root is not None
18+
assert kdtree is not None
2119

22-
def test_build_kdtree(self):
23-
"""
24-
Test that KD-Tree is built correctly.
25-
"""
26-
# Check if root is not None
27-
self.assertIsNotNone(self.kdtree)
20+
# Check if root has correct dimensions
21+
assert len(kdtree.point) == num_dimensions
2822

29-
# Check if root has correct dimensions
30-
self.assertEqual(len(self.kdtree.point), self.num_dimensions)
23+
# Check that the tree is balanced to some extent (simplistic check)
24+
assert isinstance(kdtree, KDNode)
3125

32-
# Check that the tree is balanced to some extent (simplistic check)
33-
self.assertIsInstance(self.kdtree, KDNode)
26+
def test_nearest_neighbour_search():
27+
"""
28+
Test the nearest neighbor search function.
29+
"""
30+
num_points = 10
31+
cube_size = 10.0
32+
num_dimensions = 2
33+
points = hypercube_points(num_points, cube_size, num_dimensions)
34+
kdtree = build_kdtree(points.tolist())
3435

35-
def test_nearest_neighbour_search(self):
36-
"""
37-
Test the nearest neighbor search function.
38-
"""
39-
rng = np.random.default_rng()
40-
query_point = rng.random(self.num_dimensions).tolist()
36+
rng = np.random.default_rng()
37+
query_point = rng.random(num_dimensions).tolist()
4138

42-
nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search(
43-
self.kdtree, query_point
44-
)
39+
nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search(
40+
kdtree, query_point
41+
)
4542

46-
# Check that nearest point is not None
47-
self.assertIsNotNone(nearest_point)
43+
# Check that nearest point is not None
44+
assert nearest_point is not None
4845

49-
# Check that distance is a non-negative number
50-
self.assertGreaterEqual(nearest_dist, 0)
46+
# Check that distance is a non-negative number
47+
assert nearest_dist >= 0
5148

52-
# Check that nodes visited is a non-negative integer
53-
self.assertGreaterEqual(nodes_visited, 0)
49+
# Check that nodes visited is a non-negative integer
50+
assert nodes_visited >= 0
5451

55-
def test_edge_cases(self):
56-
"""
57-
Test edge cases such as an empty KD-Tree.
58-
"""
59-
empty_kdtree = build_kdtree([])
60-
query_point = [0.0] * self.num_dimensions
52+
def test_edge_cases():
53+
"""
54+
Test edge cases such as an empty KD-Tree.
55+
"""
56+
empty_kdtree = build_kdtree([])
57+
query_point = [0.0] * 2 # Using a default 2D query point
6158

62-
nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search(
63-
empty_kdtree, query_point
64-
)
65-
66-
# With an empty KD-Tree, nearest_point should be None
67-
self.assertIsNone(nearest_point)
68-
self.assertEqual(nearest_dist, float("inf"))
69-
self.assertEqual(nodes_visited, 0)
59+
nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search(
60+
empty_kdtree, query_point
61+
)
7062

63+
# With an empty KD-Tree, nearest_point should be None
64+
assert nearest_point is None
65+
assert nearest_dist == float("inf")
66+
assert nodes_visited == 0
7167

7268
if __name__ == "__main__":
73-
unittest.main()
69+
import pytest
70+
pytest.main()

0 commit comments

Comments
 (0)