Skip to content

Commit 0d6985c

Browse files
Implemented KD-Tree Data Structure
1 parent c8e131b commit 0d6985c

File tree

6 files changed

+91
-0
lines changed

6 files changed

+91
-0
lines changed

data_structures/kd_tree/__init__.py

Whitespace-only changes.
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from .kd_node import KDNode
2+
3+
def build_kdtree(points, depth=0):
4+
if not points:
5+
return None
6+
7+
k = len(points[0]) # dimensionality of the points
8+
axis = depth % k
9+
10+
# Sort point list and choose median as pivot element
11+
points.sort(key=lambda x: x[axis])
12+
median_idx = len(points) // 2
13+
14+
# Create node and construct subtrees
15+
return KDNode(
16+
point = points[median_idx],
17+
left = build_kdtree(points[:median_idx], depth + 1),
18+
right = build_kdtree(points[median_idx + 1:], depth + 1)
19+
)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import numpy as np
2+
3+
from hypercube_points import hypercube_points
4+
from data_structures.kd_tree.build_kdtree import build_kdtree
5+
from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search
6+
7+
8+
num_points = 5000
9+
cube_size = 10
10+
num_dimensions = 10
11+
12+
points = hypercube_points(num_points, cube_size, num_dimensions)
13+
hypercube_kdtree = build_kdtree(points.tolist())
14+
15+
query_point = np.random.rand(num_dimensions).tolist()
16+
17+
nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search(hypercube_kdtree, query_point)
18+
19+
print(f"Query point: {query_point}")
20+
print(f"Nearest point: {nearest_point}")
21+
print(f"Distance: {nearest_dist:.4f}")
22+
print(f"Nodes visited: {nodes_visited}")
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import numpy as np
2+
3+
def hypercube_points(num_points, hypercube_size, num_dimensions):
4+
return hypercube_size * np.random.rand(num_points, num_dimensions)

data_structures/kd_tree/kd_node.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
class KDNode:
2+
def __init__(self, point, left = None, right = None):
3+
self.point = point
4+
self.left = left
5+
self.right = right
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
def nearest_neighbour_search(root, query_point):
2+
nearest_point = None
3+
nearest_dist = float('inf')
4+
nodes_visited = 0
5+
6+
def search(node, depth=0):
7+
nonlocal nearest_point, nearest_dist, nodes_visited
8+
if node is None:
9+
return
10+
11+
nodes_visited += 1
12+
13+
# Calculate the current distance (squared distance)
14+
current_point = node.point
15+
current_dist = sum((qp - cp) ** 2 for qp, cp in zip(query_point, current_point))
16+
17+
# Update nearest point if the current node is closer
18+
if nearest_point is None or current_dist < nearest_dist:
19+
nearest_point = current_point
20+
nearest_dist = current_dist
21+
22+
# Determine which subtree to search first (based on axis and query point)
23+
k = len(query_point) # dimensionality of points
24+
axis = depth % k
25+
26+
if query_point[axis] <= current_point[axis]:
27+
nearer_subtree = node.left
28+
further_subtree = node.right
29+
else:
30+
nearer_subtree = node.right
31+
further_subtree = node.left
32+
33+
# Search the nearer subtree first
34+
search(nearer_subtree, depth + 1)
35+
36+
# If the further subtree has a closer point
37+
if (query_point[axis] - current_point[axis]) ** 2 < nearest_dist:
38+
search(further_subtree, depth + 1)
39+
40+
search(root, 0)
41+
return nearest_point, nearest_dist, nodes_visited

0 commit comments

Comments
 (0)