Skip to content

Commit caf89c4

Browse files
committed
Made relevant changes specified
1 parent 89f5f80 commit caf89c4

File tree

1 file changed

+63
-27
lines changed

1 file changed

+63
-27
lines changed

quantum/quantum_kmeans_clustering.py

Lines changed: 63 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,23 @@
44
from sklearn.datasets import make_blobs
55
from sklearn.preprocessing import MinMaxScaler
66

7+
def generate_data(n_samples: int = 100, n_features: int = 2, n_clusters: int = 2) -> tuple[np.ndarray, np.ndarray]:
8+
"""
9+
Generates synthetic data using the make_blobs function and normalizes it.
10+
11+
:param n_samples: Number of samples to generate.
12+
:param n_features: Number of features for each sample.
13+
:param n_clusters: Number of clusters to generate.
14+
:return: A tuple containing normalized data and labels.
715
8-
def generate_data(n_samples=100, n_features=2, n_clusters=2):
16+
>>> data, labels = generate_data(10, 2, 2)
17+
>>> assert data.shape == (10, 2)
18+
>>> assert len(labels) == 10
19+
"""
920
data, labels = make_blobs(n_samples=n_samples, centers=n_clusters, n_features=n_features, random_state=42)
1021
return MinMaxScaler().fit_transform(data), labels
1122

12-
def quantum_distance(point1, point2):
23+
def quantum_distance(point1: np.ndarray, point2: np.ndarray) -> float:
1324
"""
1425
Computes the quantum distance between two points.
1526
@@ -25,12 +36,14 @@ def quantum_distance(point1, point2):
2536
qubit = cirq.LineQubit(0)
2637
diff = np.clip(np.linalg.norm(point1 - point2), 0, 1)
2738
theta = 2 * np.arcsin(diff)
28-
29-
circuit = cirq.Circuit(cirq.ry(theta)(qubit), cirq.measure(qubit, key="result"))
30-
39+
40+
circuit = cirq.Circuit(
41+
cirq.ry(theta)(qubit),
42+
cirq.measure(qubit, key='result')
43+
)
44+
3145
result = cirq.Simulator().run(circuit, repetitions=1000)
32-
return result.histogram(key="result").get(1, 0) / 1000
33-
46+
return result.histogram(key='result').get(1, 0) / 1000
3447

3548
def initialize_centroids(data: np.ndarray, k: int) -> np.ndarray:
3649
"""
@@ -46,31 +59,62 @@ def initialize_centroids(data: np.ndarray, k: int) -> np.ndarray:
4659
"""
4760
return data[np.random.choice(len(data), k, replace=False)]
4861

49-
def assign_clusters(data, centroids):
62+
def assign_clusters(data: np.ndarray, centroids: np.ndarray) -> list[list[np.ndarray]]:
63+
"""
64+
Assigns data points to the nearest centroid.
65+
66+
:param data: The dataset to cluster.
67+
:param centroids: The current centroids.
68+
:return: A list of clusters, each containing points assigned to it.
69+
70+
>>> data = np.array([[1, 2], [3, 4], [5, 6]])
71+
>>> centroids = np.array([[1, 2], [5, 6]])
72+
>>> clusters = assign_clusters(data, centroids)
73+
>>> assert len(clusters) == 2
74+
"""
5075
clusters = [[] for _ in range(len(centroids))]
5176
for point in data:
52-
closest = min(
53-
range(len(centroids)), key=lambda i: quantum_distance(point, centroids[i])
54-
)
77+
closest = min(range(len(centroids)), key=lambda i: quantum_distance(point, centroids[i]))
5578
clusters[closest].append(point)
5679
return clusters
5780

58-
def recompute_centroids(clusters):
81+
def recompute_centroids(clusters: list[list[np.ndarray]]) -> np.ndarray:
82+
"""
83+
Recomputes the centroids based on the assigned clusters.
84+
85+
:param clusters: A list of clusters, each containing points assigned to it.
86+
:return: An array of newly computed centroids.
87+
88+
>>> clusters = [[np.array([1, 2]), np.array([1, 3])], [np.array([5, 6]), np.array([5, 7])]]
89+
>>> centroids = recompute_centroids(clusters)
90+
>>> assert centroids.shape == (2, 2)
91+
"""
5992
return np.array([np.mean(cluster, axis=0) for cluster in clusters if cluster])
6093

61-
def quantum_kmeans(data, k, max_iters=10):
62-
centroids = initialize_centroids(data, k)
94+
def quantum_kmeans(data: np.ndarray, k: int, max_iters: int = 10) -> tuple[np.ndarray, list[list[np.ndarray]]]:
95+
"""
96+
Applies the quantum k-means clustering algorithm.
6397
98+
:param data: The dataset to cluster.
99+
:param k: The number of clusters.
100+
:param max_iters: The maximum number of iterations.
101+
:return: A tuple containing final centroids and clusters.
102+
103+
>>> data = np.array([[1, 2], [3, 4], [5, 6]])
104+
>>> centroids, clusters = quantum_kmeans(data, 2)
105+
>>> assert centroids.shape[0] == 2
106+
"""
107+
centroids = initialize_centroids(data, k)
108+
64109
for _ in range(max_iters):
65110
clusters = assign_clusters(data, centroids)
66111
new_centroids = recompute_centroids(clusters)
67112
if np.allclose(new_centroids, centroids):
68113
break
69114
centroids = new_centroids
70-
115+
71116
return centroids, clusters
72117

73-
74118
# Main execution
75119
n_samples, n_clusters = 10, 2
76120
data, labels = generate_data(n_samples, n_clusters=n_clusters)
@@ -86,20 +130,12 @@ def quantum_kmeans(data, k, max_iters=10):
86130
plt.subplot(122)
87131
for i, cluster in enumerate(final_clusters):
88132
cluster = np.array(cluster)
89-
plt.scatter(cluster[:, 0], cluster[:, 1], label=f"Cluster {i+1}")
90-
plt.scatter(
91-
final_centroids[:, 0],
92-
final_centroids[:, 1],
93-
color="red",
94-
marker="x",
95-
s=200,
96-
linewidths=3,
97-
label="Centroids",
98-
)
133+
plt.scatter(cluster[:, 0], cluster[:, 1], label=f'Cluster {i+1}')
134+
plt.scatter(final_centroids[:, 0], final_centroids[:, 1], color='red', marker='x', s=200, linewidths=3, label='Centroids')
99135
plt.title("Quantum k-Means Clustering with Cirq")
100136
plt.legend()
101137

102138
plt.tight_layout()
103139
plt.show()
104140

105-
print(f"Final Centroids:\n{final_centroids}")
141+
print(f"Final Centroids:\n{final_centroids}")

0 commit comments

Comments
 (0)