Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 3655742

Browse files
committedOct 2, 2024·
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent facfce2 commit 3655742

File tree

1 file changed

+29
-14
lines changed

1 file changed

+29
-14
lines changed
 

‎quantum/quantum_kmeans_clustering.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@
66

77

88
def generate_data(n_samples=100, n_features=2, n_clusters=2):
9-
data, labels = make_blobs(n_samples=n_samples, centers=n_clusters, n_features=n_features, random_state=42)
9+
data, labels = make_blobs(
10+
n_samples=n_samples, centers=n_clusters, n_features=n_features, random_state=42
11+
)
1012
return MinMaxScaler().fit_transform(data), labels
1113

14+
1215
def quantum_distance(point1, point2):
1316
"""
1417
Computes the quantum distance between two points.
@@ -25,14 +28,12 @@ def quantum_distance(point1, point2):
2528
qubit = cirq.LineQubit(0)
2629
diff = np.clip(np.linalg.norm(point1 - point2), 0, 1)
2730
theta = 2 * np.arcsin(diff)
28-
29-
circuit = cirq.Circuit(
30-
cirq.ry(theta)(qubit),
31-
cirq.measure(qubit, key='result')
32-
)
33-
31+
32+
circuit = cirq.Circuit(cirq.ry(theta)(qubit), cirq.measure(qubit, key="result"))
33+
3434
result = cirq.Simulator().run(circuit, repetitions=1000)
35-
return result.histogram(key='result').get(1, 0) / 1000
35+
return result.histogram(key="result").get(1, 0) / 1000
36+
3637

3738
def initialize_centroids(data: np.ndarray, k: int) -> np.ndarray:
3839
"""
@@ -48,28 +49,34 @@ def initialize_centroids(data: np.ndarray, k: int) -> np.ndarray:
4849
"""
4950
return data[np.random.choice(len(data), k, replace=False)]
5051

52+
5153
def assign_clusters(data, centroids):
5254
clusters = [[] for _ in range(len(centroids))]
5355
for point in data:
54-
closest = min(range(len(centroids)), key=lambda i: quantum_distance(point, centroids[i]))
56+
closest = min(
57+
range(len(centroids)), key=lambda i: quantum_distance(point, centroids[i])
58+
)
5559
clusters[closest].append(point)
5660
return clusters
5761

62+
5863
def recompute_centroids(clusters):
5964
return np.array([np.mean(cluster, axis=0) for cluster in clusters if cluster])
6065

66+
6167
def quantum_kmeans(data, k, max_iters=10):
6268
centroids = initialize_centroids(data, k)
63-
69+
6470
for _ in range(max_iters):
6571
clusters = assign_clusters(data, centroids)
6672
new_centroids = recompute_centroids(clusters)
6773
if np.allclose(new_centroids, centroids):
6874
break
6975
centroids = new_centroids
70-
76+
7177
return centroids, clusters
7278

79+
7380
# Main execution
7481
n_samples, n_clusters = 10, 2
7582
data, labels = generate_data(n_samples, n_clusters=n_clusters)
@@ -85,12 +92,20 @@ def quantum_kmeans(data, k, max_iters=10):
8592
plt.subplot(122)
8693
for i, cluster in enumerate(final_clusters):
8794
cluster = np.array(cluster)
88-
plt.scatter(cluster[:, 0], cluster[:, 1], label=f'Cluster {i+1}')
89-
plt.scatter(final_centroids[:, 0], final_centroids[:, 1], color='red', marker='x', s=200, linewidths=3, label='Centroids')
95+
plt.scatter(cluster[:, 0], cluster[:, 1], label=f"Cluster {i+1}")
96+
plt.scatter(
97+
final_centroids[:, 0],
98+
final_centroids[:, 1],
99+
color="red",
100+
marker="x",
101+
s=200,
102+
linewidths=3,
103+
label="Centroids",
104+
)
90105
plt.title("Quantum k-Means Clustering with Cirq")
91106
plt.legend()
92107

93108
plt.tight_layout()
94109
plt.show()
95110

96-
print(f"Final Centroids:\n{final_centroids}")
111+
print(f"Final Centroids:\n{final_centroids}")

0 commit comments

Comments
 (0)
Please sign in to comment.