Skip to content

Commit 7d1d891

Browse files
committed
Add Quantum k-Means Clustering Implementation
1 parent 00e9d86 commit 7d1d891

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed

quantum/quantum_kmeans_clustering.py

+79
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,79 @@
1+
import cirq
2+
import numpy as np
3+
import matplotlib.pyplot as plt
4+
from sklearn.datasets import make_blobs
5+
from sklearn.preprocessing import MinMaxScaler
6+
7+
def generate_data(n_samples=100, n_features=2, n_clusters=2, random_state=42):
    """Generate a synthetic blob dataset rescaled feature-wise by MinMaxScaler.

    Args:
        n_samples: Total number of points to draw.
        n_features: Number of coordinates per point.
        n_clusters: Number of blobs; forwarded to ``make_blobs`` as ``centers``.
        random_state: Seed forwarded to ``make_blobs``. The default of 42
            reproduces the value that used to be hard-coded here.

    Returns:
        A ``(scaled_data, labels)`` tuple: the points after
        ``MinMaxScaler().fit_transform`` and the blob index of every point.
    """
    data, labels = make_blobs(
        n_samples=n_samples,
        centers=n_clusters,
        n_features=n_features,
        random_state=random_state,
    )
    return MinMaxScaler().fit_transform(data), labels
10+
11+
def quantum_distance(point1, point2, repetitions=1000):
    """Estimate a distance-like score between two points with one qubit.

    Circuit:
        1. The Euclidean distance is divided by ``sqrt(n_features)`` so that
           any two points inside the unit hypercube give a value ``d`` in
           [0, 1].
        2. ``d`` is encoded as the angle ``theta = 2 * arcsin(d)`` of an Ry
           rotation applied to a single qubit.
        3. The qubit is measured ``repetitions`` times.

    An Ry(theta) rotation of |0> yields |1> with probability
    ``sin(theta / 2) ** 2 == d ** 2``, so the returned frequency is a sampled,
    monotonic proxy for the distance rather than the distance itself.

    Args:
        point1: First point (array-like of coordinates).
        point2: Second point, same length as ``point1``.
        repetitions: Number of measurement shots used for the estimate.

    Returns:
        The fraction of shots, in [0, 1], that measured |1>.

    Raises:
        ValueError: If ``repetitions`` is smaller than 1.
    """
    if repetitions < 1:
        raise ValueError("repetitions must be a positive integer")

    offset = np.asarray(point1, dtype=float) - np.asarray(point2, dtype=float)
    # Without this normalisation every pair further apart than 1 is clipped to
    # the same value (theta = pi), so distant centroids become
    # indistinguishable. sqrt(n_features) is the hypercube diagonal.
    scale = np.sqrt(max(offset.size, 1))
    # The clip stays as a safety net for inputs outside the unit hypercube,
    # because arcsin is undefined above 1.
    diff = np.clip(np.linalg.norm(offset) / scale, 0, 1)
    theta = 2 * np.arcsin(diff)

    qubit = cirq.LineQubit(0)
    circuit = cirq.Circuit(
        cirq.ry(theta)(qubit),
        cirq.measure(qubit, key='result'),
    )

    result = cirq.Simulator().run(circuit, repetitions=repetitions)
    return result.histogram(key='result').get(1, 0) / repetitions
30+
31+
def initialize_centroids(data, k):
    """Pick ``k`` distinct rows of ``data`` at random as starting centroids.

    Args:
        data: Indexable array of points, one point per row.
        k: Number of centroids to draw (sampling is without replacement).

    Returns:
        The ``k`` selected rows of ``data``.
    """
    row_count = len(data)
    picked_rows = np.random.choice(row_count, k, replace=False)
    return data[picked_rows]
33+
34+
def assign_clusters(data, centroids):
    """Group every point under the centroid with the lowest quantum score.

    Args:
        data: Iterable of points.
        centroids: Sequence of centroid points.

    Returns:
        A list with one entry per centroid; entry ``i`` is the list of points
        whose ``quantum_distance`` to centroid ``i`` was the smallest. On a
        tie the lowest centroid index wins. Entries may be empty.
    """
    buckets = [[] for _ in range(len(centroids))]
    for sample in data:
        # Each centroid is scored exactly once, in index order.
        scores = [quantum_distance(sample, centre) for centre in centroids]
        # list.index returns the first occurrence, i.e. the lowest index on ties.
        nearest = scores.index(min(scores))
        buckets[nearest].append(sample)
    return buckets
40+
41+
def recompute_centroids(clusters):
    """Return the mean point of every non-empty cluster.

    Args:
        clusters: Sequence of clusters, each a list of points.

    Returns:
        An array holding one mean per non-empty cluster, in the original
        cluster order. Empty clusters are skipped, so the result can have
        fewer rows than ``clusters`` has entries.
    """
    means = []
    for members in clusters:
        if not members:
            continue
        means.append(np.mean(members, axis=0))
    return np.array(means)
43+
44+
def quantum_kmeans(data, k, max_iters=10):
    """Cluster ``data`` into ``k`` groups using the quantum distance score.

    Each iteration assigns every point to a centroid with ``assign_clusters``
    and then moves each centroid to the mean of its points. Iteration stops
    early once the centroids no longer change (``np.allclose``).

    Args:
        data: Array of points, one point per row.
        k: Number of clusters.
        max_iters: Upper bound on the number of assign/update iterations.

    Returns:
        A ``(centroids, clusters)`` tuple. ``centroids`` always has ``k`` rows
        and ``clusters`` is a list of ``k`` point lists, where entry ``i``
        belongs to centroid ``i``. If the loop ends without converging, the
        returned centroids are one update ahead of the returned assignment.
    """
    centroids = initialize_centroids(data, k)
    clusters = None

    for _ in range(max_iters):
        clusters = assign_clusters(data, centroids)

        # recompute_centroids drops empty clusters, which would change the
        # number of centroids and shift the remaining rows onto the wrong
        # index. Keep the previous centroid for any cluster with no points.
        occupied = [i for i, members in enumerate(clusters) if members]
        new_centroids = np.array(centroids, dtype=float)
        if occupied:
            new_centroids[occupied] = recompute_centroids(
                [clusters[i] for i in occupied]
            )

        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids

    if clusters is None:
        # max_iters <= 0: the loop never ran, so produce the assignment for
        # the initial centroids instead of failing on an unbound name.
        clusters = assign_clusters(data, centroids)

    return centroids, clusters
55+
56+
def main():
    """Run the demo: cluster a small synthetic dataset and plot the result.

    The left panel shows the generated points coloured by their true blob
    label. The right panel shows the clusters found by ``quantum_kmeans``
    together with the final centroids. The centroids are also printed.
    """
    n_samples, n_clusters = 10, 2
    data, labels = generate_data(n_samples, n_clusters=n_clusters)

    plt.figure(figsize=(12, 5))

    plt.subplot(121)
    plt.scatter(data[:, 0], data[:, 1], c=labels)
    plt.title("Generated Data")

    final_centroids, final_clusters = quantum_kmeans(data, n_clusters)

    plt.subplot(122)
    for i, cluster in enumerate(final_clusters):
        if not cluster:
            # An empty cluster becomes a 1-D array and cannot be column-sliced.
            continue
        cluster = np.array(cluster)
        plt.scatter(cluster[:, 0], cluster[:, 1], label=f'Cluster {i+1}')
    plt.scatter(
        final_centroids[:, 0],
        final_centroids[:, 1],
        color='red',
        marker='x',
        s=200,
        linewidths=3,
        label='Centroids',
    )
    plt.title("Quantum k-Means Clustering with Cirq")
    plt.legend()

    plt.tight_layout()
    plt.show()

    print(f"Final Centroids:\n{final_centroids}")


# Guarded so that importing this module does not open a plot window or run
# the simulator.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)