Skip to content

Commit e2d0e50

Browse files
committed
Add type hints and doctests for quantum k-means clustering functions
1 parent 7d1d891 commit e2d0e50

File tree

1 file changed

+74
-12
lines changed

1 file changed

+74
-12
lines changed

quantum/quantum_kmeans_clustering.py

+74-12
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,34 @@
44
from sklearn.datasets import make_blobs
55
from sklearn.preprocessing import MinMaxScaler
66

7-
def generate_data(n_samples=100, n_features=2, n_clusters=2):
7+
def generate_data(n_samples: int = 100, n_features: int = 2, n_clusters: int = 2) -> tuple[np.ndarray, np.ndarray]:
8+
"""
9+
Generates synthetic data using the make_blobs function and normalizes it.
10+
11+
:param n_samples: Number of samples to generate.
12+
:param n_features: Number of features for each sample.
13+
:param n_clusters: Number of clusters to generate.
14+
:return: A tuple containing normalized data and labels.
15+
16+
>>> data, labels = generate_data(10, 2, 2)
17+
>>> assert data.shape == (10, 2)
18+
>>> assert len(labels) == 10
19+
"""
820
data, labels = make_blobs(n_samples=n_samples, centers=n_clusters, n_features=n_features, random_state=42)
921
return MinMaxScaler().fit_transform(data), labels
1022

11-
def quantum_distance(point1, point2):
23+
def quantum_distance(point1: np.ndarray, point2: np.ndarray) -> float:
1224
"""
13-
Quantum circuit explanation:
14-
1. Use a single qubit to encode the distance between two points.
15-
2. Apply Ry rotation based on the normalized Euclidean distance.
16-
3. Measure the qubit to get a probabilistic distance metric.
17-
The probability of measuring |1> correlates with the distance between points.
25+
Computes the quantum distance between two points.
26+
27+
:param point1: First point as a numpy array.
28+
:param point2: Second point as a numpy array.
29+
:return: Quantum distance between the two points.
30+
31+
>>> point_a = np.array([1.0, 2.0])
32+
>>> point_b = np.array([1.5, 2.5])
33+
>>> result = quantum_distance(point_a, point_b)
34+
>>> assert isinstance(result, float)
1835
"""
1936
qubit = cirq.LineQubit(0)
2037
diff = np.clip(np.linalg.norm(point1 - point2), 0, 1)
@@ -28,20 +45,65 @@ def quantum_distance(point1, point2):
2845
result = cirq.Simulator().run(circuit, repetitions=1000)
2946
return result.histogram(key='result').get(1, 0) / 1000
3047

31-
def initialize_centroids(data, k):
48+
def initialize_centroids(data: np.ndarray, k: int) -> np.ndarray:
49+
"""
50+
Initializes centroids for k-means clustering.
51+
52+
:param data: The dataset from which to initialize centroids.
53+
:param k: The number of centroids to initialize.
54+
:return: An array of initialized centroids.
55+
56+
>>> data = np.array([[1, 2], [3, 4], [5, 6]])
57+
>>> centroids = initialize_centroids(data, 2)
58+
>>> assert centroids.shape == (2, 2)
59+
"""
3260
return data[np.random.choice(len(data), k, replace=False)]
3361

34-
def assign_clusters(data, centroids):
62+
def assign_clusters(data: np.ndarray, centroids: np.ndarray) -> list[list[np.ndarray]]:
63+
"""
64+
Assigns data points to the nearest centroid.
65+
66+
:param data: The dataset to cluster.
67+
:param centroids: The current centroids.
68+
:return: A list of clusters, each containing points assigned to it.
69+
70+
>>> data = np.array([[1, 2], [3, 4], [5, 6]])
71+
>>> centroids = np.array([[1, 2], [5, 6]])
72+
>>> clusters = assign_clusters(data, centroids)
73+
>>> assert len(clusters) == 2
74+
"""
3575
clusters = [[] for _ in range(len(centroids))]
3676
for point in data:
3777
closest = min(range(len(centroids)), key=lambda i: quantum_distance(point, centroids[i]))
3878
clusters[closest].append(point)
3979
return clusters
4080

41-
def recompute_centroids(clusters):
81+
def recompute_centroids(clusters: list[list[np.ndarray]]) -> np.ndarray:
82+
"""
83+
Recomputes the centroids based on the assigned clusters.
84+
85+
:param clusters: A list of clusters, each containing points assigned to it.
86+
:return: An array of newly computed centroids.
87+
88+
>>> clusters = [[np.array([1, 2]), np.array([1, 3])], [np.array([5, 6]), np.array([5, 7])]]
89+
>>> centroids = recompute_centroids(clusters)
90+
>>> assert centroids.shape == (2, 2)
91+
"""
4292
return np.array([np.mean(cluster, axis=0) for cluster in clusters if cluster])
4393

44-
def quantum_kmeans(data, k, max_iters=10):
94+
def quantum_kmeans(data: np.ndarray, k: int, max_iters: int = 10) -> tuple[np.ndarray, list[list[np.ndarray]]]:
95+
"""
96+
Applies the quantum k-means clustering algorithm.
97+
98+
:param data: The dataset to cluster.
99+
:param k: The number of clusters.
100+
:param max_iters: The maximum number of iterations.
101+
:return: A tuple containing final centroids and clusters.
102+
103+
>>> data = np.array([[1, 2], [3, 4], [5, 6]])
104+
>>> centroids, clusters = quantum_kmeans(data, 2)
105+
>>> assert centroids.shape[0] == 2
106+
"""
45107
centroids = initialize_centroids(data, k)
46108

47109
for _ in range(max_iters):
@@ -76,4 +138,4 @@ def quantum_kmeans(data, k, max_iters=10):
76138
plt.tight_layout()
77139
plt.show()
78140

79-
print(f"Final Centroids:\n{final_centroids}")
141+
print(f"Final Centroids:\n{final_centroids}")

0 commit comments

Comments
 (0)