|
1 | 1 | import numpy as np
|
| 2 | +import pytest |
2 | 3 |
|
3 | 4 | from data_structures.kd_tree.build_kdtree import build_kdtree
|
4 | 5 | from data_structures.kd_tree.example.hypercube_points import hypercube_points
|
5 | 6 | from data_structures.kd_tree.kd_node import KDNode
|
6 | 7 | from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search
|
7 | 8 |
|
8 | 9 |
|
9 |
| -def test_build_kdtree(): |
| 10 | +@pytest.mark.parametrize( |
| 11 | + "num_points, cube_size, num_dimensions, depth, expected_result", |
| 12 | + [ |
| 13 | + (0, 10.0, 2, 0, None), # Empty points list |
| 14 | + (10, 10.0, 2, 2, KDNode), # Depth = 2, 2D points |
| 15 | + (10, 10.0, 3, -2, KDNode), # Depth = -2, 3D points |
| 16 | + ], |
| 17 | +) |
| 18 | +def test_build_kdtree(num_points, cube_size, num_dimensions, depth, expected_result): |
10 | 19 | """
|
11 | 20 | Test that KD-Tree is built correctly.
|
| 21 | +
|
| 22 | + Cases: |
| 23 | + - Empty points list. |
| 24 | + - Positive depth value. |
| 25 | + - Negative depth value. |
12 | 26 | """
|
13 |
| - num_points = 10 |
14 |
| - cube_size = 10.0 |
15 |
| - num_dimensions = 2 |
16 |
| - points = hypercube_points(num_points, cube_size, num_dimensions) |
17 |
| - kdtree = build_kdtree(points.tolist()) |
| 27 | + points = hypercube_points(num_points, cube_size, num_dimensions).tolist() \ |
| 28 | + if num_points > 0 \ |
| 29 | + else [] |
| 30 | + |
| 31 | + kdtree = build_kdtree(points, depth = depth) |
18 | 32 |
|
19 |
| - # Check if root is not None |
20 |
| - assert kdtree is not None |
| 33 | + if expected_result is None: |
| 34 | + # Empty points list case |
| 35 | + assert kdtree is None, f"Expected None for empty points list, got {kdtree}" |
| 36 | + else: |
| 37 | + # Check if root node is not None |
| 38 | + assert kdtree is not None, "Expected a KDNode, got None" |
21 | 39 |
|
22 |
| - # Check if root has correct dimensions |
23 |
| - assert len(kdtree.point) == num_dimensions |
| 40 | + # Check if root has correct dimensions |
| 41 | + assert len(kdtree.point) == num_dimensions, \ |
| 42 | + f"Expected point dimension {num_dimensions}, got {len(kdtree.point)}" |
24 | 43 |
|
25 |
| - # Check that the tree is balanced to some extent (simplistic check) |
26 |
| - assert isinstance(kdtree, KDNode) |
| 44 | + # Check that the tree is balanced to some extent (simplistic check) |
| 45 | + assert isinstance(kdtree, KDNode), f"Expected KDNode instance, got {type(kdtree)}" |
27 | 46 |
|
28 | 47 |
|
29 | 48 | def test_nearest_neighbour_search():
|
|
0 commit comments