Skip to content

Commit 027549b

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 7d1d891 commit 027549b

File tree

1 file changed

+30
-14
lines changed

1 file changed

+30
-14
lines changed

quantum/quantum_kmeans_clustering.py

+30-14
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44
from sklearn.datasets import make_blobs
55
from sklearn.preprocessing import MinMaxScaler
66

7+
78
def generate_data(n_samples=100, n_features=2, n_clusters=2):
8-
data, labels = make_blobs(n_samples=n_samples, centers=n_clusters, n_features=n_features, random_state=42)
9+
data, labels = make_blobs(
10+
n_samples=n_samples, centers=n_clusters, n_features=n_features, random_state=42
11+
)
912
return MinMaxScaler().fit_transform(data), labels
1013

14+
1115
def quantum_distance(point1, point2):
1216
"""
1317
Quantum circuit explanation:
@@ -19,40 +23,44 @@ def quantum_distance(point1, point2):
1923
qubit = cirq.LineQubit(0)
2024
diff = np.clip(np.linalg.norm(point1 - point2), 0, 1)
2125
theta = 2 * np.arcsin(diff)
22-
23-
circuit = cirq.Circuit(
24-
cirq.ry(theta)(qubit),
25-
cirq.measure(qubit, key='result')
26-
)
27-
26+
27+
circuit = cirq.Circuit(cirq.ry(theta)(qubit), cirq.measure(qubit, key="result"))
28+
2829
result = cirq.Simulator().run(circuit, repetitions=1000)
29-
return result.histogram(key='result').get(1, 0) / 1000
30+
return result.histogram(key="result").get(1, 0) / 1000
31+
3032

3133
def initialize_centroids(data, k):
3234
return data[np.random.choice(len(data), k, replace=False)]
3335

36+
3437
def assign_clusters(data, centroids):
3538
clusters = [[] for _ in range(len(centroids))]
3639
for point in data:
37-
closest = min(range(len(centroids)), key=lambda i: quantum_distance(point, centroids[i]))
40+
closest = min(
41+
range(len(centroids)), key=lambda i: quantum_distance(point, centroids[i])
42+
)
3843
clusters[closest].append(point)
3944
return clusters
4045

46+
4147
def recompute_centroids(clusters):
4248
return np.array([np.mean(cluster, axis=0) for cluster in clusters if cluster])
4349

50+
4451
def quantum_kmeans(data, k, max_iters=10):
4552
centroids = initialize_centroids(data, k)
46-
53+
4754
for _ in range(max_iters):
4855
clusters = assign_clusters(data, centroids)
4956
new_centroids = recompute_centroids(clusters)
5057
if np.allclose(new_centroids, centroids):
5158
break
5259
centroids = new_centroids
53-
60+
5461
return centroids, clusters
5562

63+
5664
# Main execution
5765
n_samples, n_clusters = 10, 2
5866
data, labels = generate_data(n_samples, n_clusters=n_clusters)
@@ -68,12 +76,20 @@ def quantum_kmeans(data, k, max_iters=10):
6876
plt.subplot(122)
6977
for i, cluster in enumerate(final_clusters):
7078
cluster = np.array(cluster)
71-
plt.scatter(cluster[:, 0], cluster[:, 1], label=f'Cluster {i+1}')
72-
plt.scatter(final_centroids[:, 0], final_centroids[:, 1], color='red', marker='x', s=200, linewidths=3, label='Centroids')
79+
plt.scatter(cluster[:, 0], cluster[:, 1], label=f"Cluster {i+1}")
80+
plt.scatter(
81+
final_centroids[:, 0],
82+
final_centroids[:, 1],
83+
color="red",
84+
marker="x",
85+
s=200,
86+
linewidths=3,
87+
label="Centroids",
88+
)
7389
plt.title("Quantum k-Means Clustering with Cirq")
7490
plt.legend()
7591

7692
plt.tight_layout()
7793
plt.show()
7894

79-
print(f"Final Centroids:\n{final_centroids}")
95+
print(f"Final Centroids:\n{final_centroids}")

0 commit comments

Comments
 (0)