Implemented KD Tree Data Structure #11532
Merged: cclauss merged 32 commits into TheAlgorithms:master from Ramy-Badr-Ahmed:feature/kd-tree-implementation on Sep 3, 2024
Changes from 19 commits
0d6985c Implemented KD-Tree Data Structure (Ramy-Badr-Ahmed)
6665d23 Implemented KD-Tree Data Structure. updated DIRECTORY.md. (Ramy-Badr-Ahmed)
6b3d47e [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
4203cda Create __init__.py (Ramy-Badr-Ahmed)
3222bd3 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
a41ae5b Replaced legacy `np.random.rand` call with `np.random.Generator` in k… (Ramy-Badr-Ahmed)
1668d73 Replaced legacy `np.random.rand` call with `np.random.Generator` in k… (Ramy-Badr-Ahmed)
81d6917 added typehints and docstrings (Ramy-Badr-Ahmed)
6cddcbd [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
8b238d1 docstring for search() (Ramy-Badr-Ahmed)
cd1dd9f Merge remote-tracking branch 'origin/feature/kd-tree-implementation' … (Ramy-Badr-Ahmed)
ead2838 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
543584c Added tests. Updated docstrings/typehints (Ramy-Badr-Ahmed)
ad31f83 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
1322921 updated tests and used | for type annotations (Ramy-Badr-Ahmed)
4608a9f [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
7c1aa7e E501 for build_kdtree.py, hypercube_points.py, nearest_neighbour_sear… (Ramy-Badr-Ahmed)
ba24e75 I001 for example_usage.py and test_kdtree.py (Ramy-Badr-Ahmed)
05975a3 I001 for example_usage.py and test_kdtree.py (Ramy-Badr-Ahmed)
31782d1 Update data_structures/kd_tree/build_kdtree.py (Ramy-Badr-Ahmed)
6a9b3e1 Update data_structures/kd_tree/example/hypercube_points.py (Ramy-Badr-Ahmed)
2fd24d4 Update data_structures/kd_tree/example/hypercube_points.py (Ramy-Badr-Ahmed)
2cf9d92 Added new test cases requested in Review. Refactored the test_build_k… (Ramy-Badr-Ahmed)
a3803ee [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
f1f5862 Considered ruff errors (Ramy-Badr-Ahmed)
ec6559d Merge remote-tracking branch 'origin/feature/kd-tree-implementation' … (Ramy-Badr-Ahmed)
5c07a1a Considered ruff errors (Ramy-Badr-Ahmed)
3c09ac1 Apply suggestions from code review (cclauss)
bab43e7 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
a10ff15 Update kd_node.py (cclauss)
d77a285 imported annotations from __future__ (Ramy-Badr-Ahmed)
0426806 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
@@ -0,0 +1,35 @@
from data_structures.kd_tree.kd_node import KDNode


def build_kdtree(points: list[list[float]], depth: int = 0) -> KDNode | None:
    """
    Builds a KD-Tree from a list of points.

    Args:
        points (list[list[float]]): The list of points to build the KD-Tree from.
        depth (int): The current depth in the tree
                     (used to determine axis for splitting).

    Returns:
        KDNode | None: The root node of the KD-Tree,
                       or None if no points are provided.
    """
    if not points:
        return None

    k = len(points[0])  # Dimensionality of the points
    axis = depth % k

    # Sort point list and choose median as pivot element
    points.sort(key=lambda point: point[axis])
    median_idx = len(points) // 2

    # Create node and construct subtrees
    left_points = points[:median_idx]
    right_points = points[median_idx + 1 :]

    return KDNode(
        point=points[median_idx],
        left=build_kdtree(left_points, depth + 1),
        right=build_kdtree(right_points, depth + 1),
    )
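As a worked illustration of the median split above, here is a minimal standalone sketch on six 2D points. `Node`, `build`, and `height` are hypothetical stand-ins introduced only so the example runs without the repository package; unlike the diff, the sketch uses `sorted()` rather than `list.sort()`, so the caller's list is not reordered.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    """Hypothetical stand-in for KDNode so the sketch runs on its own."""

    point: list
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def build(points, depth=0):
    """Median-split construction, cycling through the axes by depth."""
    if not points:
        return None
    axis = depth % len(points[0])
    # sorted() returns a copy, so the caller's list is left untouched
    ordered = sorted(points, key=lambda point: point[axis])
    median_idx = len(ordered) // 2
    return Node(
        point=ordered[median_idx],
        left=build(ordered[:median_idx], depth + 1),
        right=build(ordered[median_idx + 1 :], depth + 1),
    )


def height(node):
    """Number of levels in the tree (0 for an empty tree)."""
    if node is None:
        return 0
    return 1 + max(height(node.left), height(node.right))


sample = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
root = build(sample)
print(root.point)        # [7, 2]: median of the x-coordinates
print(root.left.point)   # [5, 4]: median of the left half by y
print(root.right.point)  # [9, 6]: median of the right half by y
print(height(root))      # 3
```

Because every level picks the median, six points fit in three levels. Sorting at every level costs O(n log n) per level, which is simple but not the fastest possible construction.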
@@ -0,0 +1,38 @@
import numpy as np

from data_structures.kd_tree.build_kdtree import build_kdtree
from data_structures.kd_tree.example.hypercube_points import hypercube_points
from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search


def main() -> None:
    """
    Demonstrates the use of KD-Tree by building it from random points
    in a 10-dimensional hypercube and performing a nearest neighbor search.
    """
    num_points: int = 5000
    cube_size: float = 10.0  # Size of the hypercube (edge length)
    num_dimensions: int = 10

    # Generate random points within the hypercube
    points: np.ndarray = hypercube_points(num_points, cube_size, num_dimensions)
    hypercube_kdtree = build_kdtree(points.tolist())

    # Generate a random query point in the unit cube, which lies inside the hypercube
    rng = np.random.default_rng()
    query_point: list[float] = rng.random(num_dimensions).tolist()

    # Perform nearest neighbor search
    nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search(
        hypercube_kdtree, query_point
    )

    # Print the results (nearest_dist is the squared Euclidean distance)
    print(f"Query point: {query_point}")
    print(f"Nearest point: {nearest_point}")
    print(f"Squared distance: {nearest_dist:.4f}")
    print(f"Nodes visited: {nodes_visited}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,21 @@
import numpy as np


def hypercube_points(
    num_points: int, hypercube_size: float, num_dimensions: int
) -> np.ndarray:
    """
    Generates random points uniformly distributed within an n-dimensional hypercube.

    Args:
        num_points (int): Number of points to generate.
        hypercube_size (float): Edge length of the hypercube.
        num_dimensions (int): Number of dimensions of the hypercube.

    Returns:
        np.ndarray: An array of shape (num_points, num_dimensions)
                    with generated points.
    """
    rng = np.random.default_rng()
    shape = (num_points, num_dimensions)
    return hypercube_size * rng.random(shape)
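The generator above is unseeded, so every run produces different points. One possible variant for reproducible experiments is sketched below. The function name `seeded_hypercube_points` and its `seed` parameter are assumptions for illustration and are not part of this pull request.

```python
import numpy as np


def seeded_hypercube_points(num_points, hypercube_size, num_dimensions, seed=None):
    """Like hypercube_points, but with an optional seed (hypothetical parameter)."""
    rng = np.random.default_rng(seed)
    # rng.random samples from [0, 1); scaling stretches that to [0, hypercube_size)
    return hypercube_size * rng.random((num_points, num_dimensions))


first = seeded_hypercube_points(5, 10.0, 3, seed=42)
second = seeded_hypercube_points(5, 10.0, 3, seed=42)
print(first.shape)                    # (5, 3)
print(np.array_equal(first, second))  # True: same seed, same points
```

The same scaling applies to a query point: `rng.random(num_dimensions)` alone stays inside the unit cube, so multiplying by the edge length is needed to cover the whole hypercube.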
@@ -0,0 +1,30 @@
from typing import Optional


class KDNode:
    """
    Represents a node in a KD-Tree.

    Attributes:
        point (list[float]): The point stored in this node.
        left (Optional[KDNode]): The left child node.
        right (Optional[KDNode]): The right child node.
    """

    def __init__(
        self,
        point: list[float],
        left: Optional["KDNode"] = None,
        right: Optional["KDNode"] = None,
    ) -> None:
        """
        Initializes a KDNode with the given point and child nodes.

        Args:
            point (list[float]): The point stored in this node.
            left (Optional[KDNode]): The left child node.
            right (Optional[KDNode]): The right child node.
        """
        self.point = point
        self.left = left
        self.right = right
@@ -0,0 +1,71 @@
from data_structures.kd_tree.kd_node import KDNode


def nearest_neighbour_search(
    root: KDNode | None, query_point: list[float]
) -> tuple[list[float] | None, float, int]:
    """
    Performs a nearest neighbor search in a KD-Tree for a given query point.

    Args:
        root (KDNode | None): The root node of the KD-Tree.
        query_point (list[float]): The point for which the nearest neighbor
                                    is being searched.

    Returns:
        tuple[list[float] | None, float, int]:
            - The nearest point found in the KD-Tree to the query point,
              or None if no point is found.
            - The squared distance to the nearest point.
            - The number of nodes visited during the search.
    """
    nearest_point: list[float] | None = None
    nearest_dist: float = float("inf")
    nodes_visited: int = 0

    def search(node: KDNode | None, depth: int = 0) -> None:
        """
        Recursively searches for the nearest neighbor in the KD-Tree.

        Args:
            node (KDNode | None): The current node in the KD-Tree.
            depth (int): The current depth in the KD-Tree.
        """
        nonlocal nearest_point, nearest_dist, nodes_visited
        if node is None:
            return

        nodes_visited += 1

        # Calculate the current distance (squared distance)
        current_point = node.point
        current_dist = sum(
            (query_coord - point_coord) ** 2
            for query_coord, point_coord in zip(query_point, current_point)
        )

        # Update nearest point if the current node is closer
        if nearest_point is None or current_dist < nearest_dist:
            nearest_point = current_point
            nearest_dist = current_dist

        # Determine which subtree to search first (based on axis and query point)
        k = len(query_point)  # Dimensionality of points
        axis = depth % k

        if query_point[axis] <= current_point[axis]:
            nearer_subtree = node.left
            further_subtree = node.right
        else:
            nearer_subtree = node.right
            further_subtree = node.left

        # Search the nearer subtree first
        search(nearer_subtree, depth + 1)

        # Visit the further subtree only if the splitting plane is within nearest_dist
        if (query_point[axis] - current_point[axis]) ** 2 < nearest_dist:
            search(further_subtree, depth + 1)

    search(root, 0)
    return nearest_point, nearest_dist, nodes_visited
@@ -0,0 +1,76 @@
import numpy as np

from data_structures.kd_tree.build_kdtree import build_kdtree
from data_structures.kd_tree.example.hypercube_points import hypercube_points
from data_structures.kd_tree.kd_node import KDNode
from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search


def test_build_kdtree():
    """
    Test that KD-Tree is built correctly.
    """
    num_points = 10
    cube_size = 10.0
    num_dimensions = 2
    points = hypercube_points(num_points, cube_size, num_dimensions)
    kdtree = build_kdtree(points.tolist())

    # Check if root is not None
    assert kdtree is not None

    # Check if root has correct dimensions
    assert len(kdtree.point) == num_dimensions

    # Check that the root is a KDNode instance (this does not verify balance)
    assert isinstance(kdtree, KDNode)


def test_nearest_neighbour_search():
    """
    Test the nearest neighbor search function.
    """
    num_points = 10
    cube_size = 10.0
    num_dimensions = 2
    points = hypercube_points(num_points, cube_size, num_dimensions)
    kdtree = build_kdtree(points.tolist())

    rng = np.random.default_rng()
    query_point = rng.random(num_dimensions).tolist()

    nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search(
        kdtree, query_point
    )

    # Check that nearest point is not None
    assert nearest_point is not None

    # Check that distance is a non-negative number
    assert nearest_dist >= 0

    # Check that nodes visited is a non-negative integer
    assert nodes_visited >= 0


def test_edge_cases():
    """
    Test edge cases such as an empty KD-Tree.
    """
    empty_kdtree = build_kdtree([])
    query_point = [0.0] * 2  # Using a default 2D query point

    nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search(
        empty_kdtree, query_point
    )

    # With an empty KD-Tree, nearest_point should be None
    assert nearest_point is None
    assert nearest_dist == float("inf")
    assert nodes_visited == 0


if __name__ == "__main__":
    import pytest

    pytest.main()
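The tests above confirm types and signs but not that the returned neighbour is actually the closest one. A stronger check, sketched here under the assumption that a linear scan is an acceptable reference, compares the tree search with brute force on random inputs. All names (`build`, `sq_dist`, `tree_nearest`, `check_against_brute_force`) are hypothetical, and the tree is stored as nested tuples so the example runs on its own.

```python
import random


def build(points, depth=0):
    """Median-split build returning nested (point, left, right) tuples."""
    if not points:
        return None
    axis = depth % len(points[0])
    ordered = sorted(points, key=lambda point: point[axis])
    mid = len(ordered) // 2
    return (
        ordered[mid],
        build(ordered[:mid], depth + 1),
        build(ordered[mid + 1 :], depth + 1),
    )


def sq_dist(a, b):
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def tree_nearest(node, query, depth=0, best=None):
    """Return (squared distance, point) for the nearest stored point."""
    if node is None:
        return best
    point, left, right = node
    candidate = (sq_dist(query, point), point)
    if best is None or candidate[0] < best[0]:
        best = candidate
    diff = query[depth % len(query)] - point[depth % len(query)]
    near, far = (left, right) if diff <= 0 else (right, left)
    best = tree_nearest(near, query, depth + 1, best)
    # Descend into the far side only if the splitting plane is close enough
    if diff**2 < best[0]:
        best = tree_nearest(far, query, depth + 1, best)
    return best


def check_against_brute_force(trials=200, seed=0):
    """True if the tree search matches a linear scan on every random trial."""
    rng = random.Random(seed)
    for _ in range(trials):
        dims = rng.randint(1, 4)
        count = rng.randint(1, 30)
        points = [[rng.uniform(0, 10) for _ in range(dims)] for _ in range(count)]
        query = [rng.uniform(0, 10) for _ in range(dims)]
        expected = min(sq_dist(query, point) for point in points)
        found, _ = tree_nearest(build(points), query)
        if found != expected:
            return False
    return True


print(check_against_brute_force())  # True
```

The check compares distances rather than points, so ties between equally distant neighbours do not cause false failures.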