Skip to content

Commit 89f5f80

Browse files
committed
Merge branch 'quantum-kmeans-clustering' of https://github.com/RahulPatnaik/Python into quantum-kmeans-clustering
2 parents e2d0e50 + 027549b commit 89f5f80

File tree

1 file changed

+27
-63
lines changed

1 file changed

+27
-63
lines changed

quantum/quantum_kmeans_clustering.py

+27-63
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,12 @@
44
from sklearn.datasets import make_blobs
55
from sklearn.preprocessing import MinMaxScaler
66

7-
def generate_data(n_samples: int = 100, n_features: int = 2, n_clusters: int = 2) -> tuple[np.ndarray, np.ndarray]:
8-
"""
9-
Generates synthetic data using the make_blobs function and normalizes it.
10-
11-
:param n_samples: Number of samples to generate.
12-
:param n_features: Number of features for each sample.
13-
:param n_clusters: Number of clusters to generate.
14-
:return: A tuple containing normalized data and labels.
157

16-
>>> data, labels = generate_data(10, 2, 2)
17-
>>> assert data.shape == (10, 2)
18-
>>> assert len(labels) == 10
19-
"""
8+
def generate_data(n_samples=100, n_features=2, n_clusters=2):
209
data, labels = make_blobs(n_samples=n_samples, centers=n_clusters, n_features=n_features, random_state=42)
2110
return MinMaxScaler().fit_transform(data), labels
2211

23-
def quantum_distance(point1: np.ndarray, point2: np.ndarray) -> float:
12+
def quantum_distance(point1, point2):
2413
"""
2514
Computes the quantum distance between two points.
2615
@@ -36,14 +25,12 @@ def quantum_distance(point1: np.ndarray, point2: np.ndarray) -> float:
3625
qubit = cirq.LineQubit(0)
3726
diff = np.clip(np.linalg.norm(point1 - point2), 0, 1)
3827
theta = 2 * np.arcsin(diff)
39-
40-
circuit = cirq.Circuit(
41-
cirq.ry(theta)(qubit),
42-
cirq.measure(qubit, key='result')
43-
)
44-
28+
29+
circuit = cirq.Circuit(cirq.ry(theta)(qubit), cirq.measure(qubit, key="result"))
30+
4531
result = cirq.Simulator().run(circuit, repetitions=1000)
46-
return result.histogram(key='result').get(1, 0) / 1000
32+
return result.histogram(key="result").get(1, 0) / 1000
33+
4734

4835
def initialize_centroids(data: np.ndarray, k: int) -> np.ndarray:
4936
"""
@@ -59,62 +46,31 @@ def initialize_centroids(data: np.ndarray, k: int) -> np.ndarray:
5946
"""
6047
return data[np.random.choice(len(data), k, replace=False)]
6148

62-
def assign_clusters(data: np.ndarray, centroids: np.ndarray) -> list[list[np.ndarray]]:
63-
"""
64-
Assigns data points to the nearest centroid.
65-
66-
:param data: The dataset to cluster.
67-
:param centroids: The current centroids.
68-
:return: A list of clusters, each containing points assigned to it.
69-
70-
>>> data = np.array([[1, 2], [3, 4], [5, 6]])
71-
>>> centroids = np.array([[1, 2], [5, 6]])
72-
>>> clusters = assign_clusters(data, centroids)
73-
>>> assert len(clusters) == 2
74-
"""
49+
def assign_clusters(data, centroids):
7550
clusters = [[] for _ in range(len(centroids))]
7651
for point in data:
77-
closest = min(range(len(centroids)), key=lambda i: quantum_distance(point, centroids[i]))
52+
closest = min(
53+
range(len(centroids)), key=lambda i: quantum_distance(point, centroids[i])
54+
)
7855
clusters[closest].append(point)
7956
return clusters
8057

81-
def recompute_centroids(clusters: list[list[np.ndarray]]) -> np.ndarray:
82-
"""
83-
Recomputes the centroids based on the assigned clusters.
84-
85-
:param clusters: A list of clusters, each containing points assigned to it.
86-
:return: An array of newly computed centroids.
87-
88-
>>> clusters = [[np.array([1, 2]), np.array([1, 3])], [np.array([5, 6]), np.array([5, 7])]]
89-
>>> centroids = recompute_centroids(clusters)
90-
>>> assert centroids.shape == (2, 2)
91-
"""
58+
def recompute_centroids(clusters):
9259
return np.array([np.mean(cluster, axis=0) for cluster in clusters if cluster])
9360

94-
def quantum_kmeans(data: np.ndarray, k: int, max_iters: int = 10) -> tuple[np.ndarray, list[list[np.ndarray]]]:
95-
"""
96-
Applies the quantum k-means clustering algorithm.
97-
98-
:param data: The dataset to cluster.
99-
:param k: The number of clusters.
100-
:param max_iters: The maximum number of iterations.
101-
:return: A tuple containing final centroids and clusters.
102-
103-
>>> data = np.array([[1, 2], [3, 4], [5, 6]])
104-
>>> centroids, clusters = quantum_kmeans(data, 2)
105-
>>> assert centroids.shape[0] == 2
106-
"""
61+
def quantum_kmeans(data, k, max_iters=10):
10762
centroids = initialize_centroids(data, k)
108-
63+
10964
for _ in range(max_iters):
11065
clusters = assign_clusters(data, centroids)
11166
new_centroids = recompute_centroids(clusters)
11267
if np.allclose(new_centroids, centroids):
11368
break
11469
centroids = new_centroids
115-
70+
11671
return centroids, clusters
11772

73+
11874
# Main execution
11975
n_samples, n_clusters = 10, 2
12076
data, labels = generate_data(n_samples, n_clusters=n_clusters)
@@ -130,12 +86,20 @@ def quantum_kmeans(data: np.ndarray, k: int, max_iters: int = 10) -> tuple[np.nd
13086
plt.subplot(122)
13187
for i, cluster in enumerate(final_clusters):
13288
cluster = np.array(cluster)
133-
plt.scatter(cluster[:, 0], cluster[:, 1], label=f'Cluster {i+1}')
134-
plt.scatter(final_centroids[:, 0], final_centroids[:, 1], color='red', marker='x', s=200, linewidths=3, label='Centroids')
89+
plt.scatter(cluster[:, 0], cluster[:, 1], label=f"Cluster {i+1}")
90+
plt.scatter(
91+
final_centroids[:, 0],
92+
final_centroids[:, 1],
93+
color="red",
94+
marker="x",
95+
s=200,
96+
linewidths=3,
97+
label="Centroids",
98+
)
13599
plt.title("Quantum k-Means Clustering with Cirq")
136100
plt.legend()
137101

138102
plt.tight_layout()
139103
plt.show()
140104

141-
print(f"Final Centroids:\n{final_centroids}")
105+
print(f"Final Centroids:\n{final_centroids}")

0 commit comments

Comments
 (0)