-
-
Notifications
You must be signed in to change notification settings - Fork 46.8k
Implemented KD Tree Data Structure #11532
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
0d6985c
6665d23
6b3d47e
4203cda
3222bd3
a41ae5b
1668d73
81d6917
6cddcbd
8b238d1
cd1dd9f
ead2838
543584c
ad31f83
1322921
4608a9f
7c1aa7e
ba24e75
05975a3
31782d1
6a9b3e1
2fd24d4
2cf9d92
a3803ee
f1f5862
ec6559d
5c07a1a
3c09ac1
bab43e7
a10ff15
d77a285
0426806
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from typing import List, Optional | ||
from .kd_node import KDNode | ||
|
||
def build_kdtree(points: List[List[float]], depth: int = 0) -> Optional["KDNode"]:
    """
    Builds a KD-Tree from a set of k-dimensional points.

    The splitting axis cycles through the dimensions with the depth
    (``axis = depth % k``). At every level the median point along that axis
    becomes the node; the points sorted before it form the left subtree and
    the points sorted after it form the right subtree.

    Args:
        points (List[List[float]]): A list of k-dimensional points (each
            point is a list of floats). The list is not modified.
        depth (int): The current depth in the tree. Used to determine the
            splitting axis. Defaults to 0.

    Returns:
        Optional[KDNode]: The root of the KD-Tree or None if the input list
            is empty.

    Raises:
        ValueError: If the points have zero dimensions.

    Examples:
        >>> build_kdtree([]) is None
        True
        >>> tree = build_kdtree([[2.0, 3.0], [5.0, 4.0], [9.0, 6.0]])
        >>> tree.point
        [5.0, 4.0]
        >>> tree.left.point, tree.right.point
        ([2.0, 3.0], [9.0, 6.0])
    """
    if not points:
        return None

    k = len(points[0])  # Dimensionality of the points
    if k == 0:
        # Without this guard ``depth % k`` fails with an opaque
        # ZeroDivisionError.
        raise ValueError("Points must have at least one dimension.")
    axis = depth % k

    # Sort a copy (sorted() rather than list.sort()) so the caller's list
    # keeps its original order, then choose the median as pivot element.
    ordered = sorted(points, key=lambda point: point[axis])
    median_idx = len(ordered) // 2

    # Create node and construct subtrees
    return KDNode(
        point=ordered[median_idx],
        left=build_kdtree(ordered[:median_idx], depth + 1),
        right=build_kdtree(ordered[median_idx + 1 :], depth + 1),
    )
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import numpy as np | ||
from typing import List | ||
from hypercube_points import hypercube_points | ||
from data_structures.kd_tree.build_kdtree import build_kdtree | ||
from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search | ||
|
||
def main() -> None:
    """
    Demonstrates the use of KD-Tree by building it from random points
    in a 10-dimensional hypercube and performing a nearest neighbor search.

    Prints the query point, the nearest stored point, the Euclidean distance
    between them and the number of tree nodes visited by the search.
    """
    num_points: int = 5000
    cube_size: int = 10
    num_dimensions: int = 10

    # Generate random points within the hypercube
    points: np.ndarray = hypercube_points(num_points, cube_size, num_dimensions)
    hypercube_kdtree = build_kdtree(points.tolist())

    # Generate a random query point within the same space
    rng = np.random.default_rng()
    query_point: List[float] = rng.random(num_dimensions).tolist()

    # Perform nearest neighbor search
    nearest_point, nearest_sq_dist, nodes_visited = nearest_neighbour_search(
        hypercube_kdtree, query_point
    )

    # Print the results
    print(f"Query point: {query_point}")
    print(f"Nearest point: {nearest_point}")
    # The search reports the *squared* distance; take the square root so the
    # value printed under "Distance" really is the Euclidean distance.
    print(f"Distance: {nearest_sq_dist ** 0.5:.4f}")
    print(f"Nodes visited: {nodes_visited}")
|
||
# Run the demonstration only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import numpy as np | ||
from typing import Union | ||
|
||
def hypercube_points(
    num_points: int,
    hypercube_size: Union[int, float],
    num_dimensions: int,
    seed: Union[int, None] = None,
) -> np.ndarray:
    """
    Generates random points uniformly distributed within an n-dimensional
    hypercube.

    Each coordinate is drawn from ``[0, 1)`` and scaled by ``hypercube_size``.

    Args:
        num_points (int): The number of random points to generate.
        hypercube_size (Union[int, float]): The size of the hypercube
            (side length).
        num_dimensions (int): The number of dimensions of the hypercube.
        seed (Union[int, None]): Optional seed for the random generator.
            Passing the same seed reproduces the same points. Defaults to
            None, which seeds the generator from fresh entropy.

    Returns:
        np.ndarray: An array of shape (num_points, num_dimensions) with the
            generated points.

    Examples:
        >>> hypercube_points(3, 10, 2, seed=0).shape
        (3, 2)
        >>> first = hypercube_points(3, 10, 2, seed=0)
        >>> second = hypercube_points(3, 10, 2, seed=0)
        >>> bool((first == second).all())
        True
    """
    rng = np.random.default_rng(seed)
    return hypercube_size * rng.random((num_points, num_dimensions))
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from typing import List, Optional | ||
|
||
class KDNode:
    """
    A single node of a KD-Tree.

    Attributes:
        point (List[float]): The k-dimensional point held by the node.
        left (Optional[KDNode]): Root of the left subtree, or None.
        right (Optional[KDNode]): Root of the right subtree, or None.
    """

    def __init__(
        self,
        point: List[float],
        left: Optional["KDNode"] = None,
        right: Optional["KDNode"] = None,
    ) -> None:
        """
        Stores the point together with its (optional) child subtrees.

        Args:
            point (List[float]): The k-dimensional point to keep in the node.
            left (Optional[KDNode]): The left subtree. Defaults to None.
            right (Optional[KDNode]): The right subtree. Defaults to None.
        """
        self.point, self.left, self.right = point, left, right
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from typing import Optional, List, Tuple | ||
from .kd_node import KDNode | ||
|
||
def nearest_neighbour_search(
    root: Optional["KDNode"], query_point: List[float]
) -> Tuple[Optional[List[float]], float, int]:
    """
    Performs a nearest neighbor search in a KD-Tree for a given query point.

    The tree is walked depth-first. At every node the subtree on the query's
    side of the splitting plane is searched first; the other subtree is only
    searched when the splitting plane is closer to the query than the best
    point found so far.

    Args:
        root (Optional[KDNode]): The root node of the KD-Tree.
        query_point (List[float]): The point for which the nearest neighbor
            is being searched.

    Returns:
        Tuple[Optional[List[float]], float, int]:
            - The nearest point found in the KD-Tree to the query point
              (None for an empty tree).
            - The squared distance to the nearest point (infinity for an
              empty tree).
            - The number of nodes visited during the search.

    Raises:
        ValueError: If the tree is not empty and ``query_point`` has no
            coordinates.
    """
    nearest_point: Optional[List[float]] = None
    nearest_dist: float = float("inf")
    nodes_visited: int = 0

    # An empty tree has no neighbor to offer.
    if root is None:
        return nearest_point, nearest_dist, nodes_visited

    # Dimensionality of the points; computed once instead of on every node.
    k = len(query_point)
    if k == 0:
        # Without this guard ``depth % k`` fails with an opaque
        # ZeroDivisionError.
        raise ValueError("query_point must have at least one dimension.")

    def search(node: Optional["KDNode"], depth: int = 0) -> None:
        """Visit ``node`` and its promising subtrees, updating the best match."""
        nonlocal nearest_point, nearest_dist, nodes_visited
        if node is None:
            return

        nodes_visited += 1

        # Calculate the current distance (squared distance)
        current_point = node.point
        current_dist = sum(
            (query_coord - point_coord) ** 2
            for query_coord, point_coord in zip(query_point, current_point)
        )

        # Update nearest point if the current node is closer
        if nearest_point is None or current_dist < nearest_dist:
            nearest_point = current_point
            nearest_dist = current_dist

        # Determine which subtree to search first (based on axis and query point)
        axis = depth % k

        if query_point[axis] <= current_point[axis]:
            nearer_subtree = node.left
            further_subtree = node.right
        else:
            nearer_subtree = node.right
            further_subtree = node.left

        # Search the nearer subtree first
        search(nearer_subtree, depth + 1)

        # The further subtree can only hold a closer point if the splitting
        # plane itself is closer than the best squared distance so far.
        if (query_point[axis] - current_point[axis]) ** 2 < nearest_dist:
            search(further_subtree, depth + 1)

    search(root, 0)
    return nearest_point, nearest_dist, nodes_visited
Uh oh!
There was an error while loading. Please reload this page.