-
-
Notifications
You must be signed in to change notification settings - Fork 46.9k
Add Quantum k-Means Clustering Implementation #11664
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
7d1d891
027549b
e2d0e50
89f5f80
c46141d
caf89c4
facfce2
3655742
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import cirq | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from sklearn.datasets import make_blobs | ||
from sklearn.preprocessing import MinMaxScaler | ||
|
||
def generate_data(
    n_samples: int = 100,
    n_features: int = 2,
    n_clusters: int = 2,
    random_state: int = 42,
) -> tuple:
    """Generate a synthetic blob dataset rescaled with ``MinMaxScaler``.

    Args:
        n_samples: Total number of points to generate.
        n_features: Dimensionality of each point.
        n_clusters: Number of blob centers requested from ``make_blobs``.
        random_state: Seed forwarded to ``make_blobs``. The default of 42
            reproduces the dataset this function always produced before the
            seed became configurable.

    Returns:
        A ``(scaled_data, labels)`` tuple: the output of
        ``MinMaxScaler().fit_transform`` on the generated points, and the
        labels returned by ``make_blobs``.
    """
    data, labels = make_blobs(
        n_samples=n_samples,
        centers=n_clusters,
        n_features=n_features,
        random_state=random_state,
    )
    return MinMaxScaler().fit_transform(data), labels
|
||
def quantum_distance(
    point1: np.ndarray, point2: np.ndarray, repetitions: int = 1000
) -> float:
    """Estimate a distance-like score between two points with one qubit.

    Circuit:
        1. Compute the Euclidean distance ``d`` between the points and clip
           it to ``[0, 1]`` so that ``arcsin`` is defined.
        2. Apply ``Ry(theta)`` with ``theta = 2 * arcsin(d)`` to a qubit
           that starts in ``|0>``.
        3. Measure the qubit ``repetitions`` times.

    After the rotation the probability of measuring ``|1>`` is
    ``sin(theta / 2) ** 2 == d ** 2``, so the returned fraction is a sampled
    (noisy) estimate of the *squared* clipped distance. It grows
    monotonically with ``d``, which is all a nearest-centroid comparison
    needs.

    NOTE: every distance of 1 or more is clipped to 1, so such pairs are
    indistinguishable from one another.

    Args:
        point1: First point (array-like of numbers).
        point2: Second point, same shape as ``point1``.
        repetitions: Number of measurement shots. More shots give a less
            noisy estimate at a higher simulation cost.

    Returns:
        Fraction of shots that measured ``1``, in ``[0, 1]``.

    Raises:
        ValueError: If ``repetitions`` is less than 1.
    """
    if repetitions < 1:
        raise ValueError("repetitions must be a positive integer")

    qubit = cirq.LineQubit(0)
    separation = np.linalg.norm(
        np.asarray(point1, dtype=float) - np.asarray(point2, dtype=float)
    )
    # arcsin is only defined on [-1, 1]; the norm is non-negative, so
    # clipping to [0, 1] keeps theta within [0, pi].
    theta = 2 * np.arcsin(np.clip(separation, 0, 1))

    circuit = cirq.Circuit(
        cirq.ry(theta)(qubit),
        cirq.measure(qubit, key='result'),
    )

    result = cirq.Simulator().run(circuit, repetitions=repetitions)
    # The histogram maps measured value -> count; a missing key means |1>
    # was never observed.
    return result.histogram(key='result').get(1, 0) / repetitions
|
||
def initialize_centroids(data, k):
    """Pick ``k`` distinct rows of ``data`` to serve as starting centroids.

    Rows are drawn without replacement using NumPy's global random state,
    so seeding ``np.random`` beforehand makes the selection reproducible.

    Args:
        data: Indexable NumPy array of points, one per row.
        k: Number of centroids to draw.

    Returns:
        The ``k`` selected rows of ``data``.
    """
    row_indices = np.random.choice(len(data), size=k, replace=False)
    return data[row_indices]
|
||
def assign_clusters(data, centroids):
    """Group every point with the centroid that scores the lowest distance.

    Distances come from ``quantum_distance``, which is evaluated once per
    (point, centroid) pair, in centroid order. Ties go to the centroid with
    the lowest index.

    Args:
        data: Iterable of points.
        centroids: Sequence of centroid points.

    Returns:
        A list with one entry per centroid; entry ``i`` is the list of points
        assigned to ``centroids[i]`` (possibly empty).
    """
    clusters = [[] for _ in centroids]
    for point in data:
        scores = [quantum_distance(point, centroid) for centroid in centroids]
        # list.index returns the first occurrence, i.e. the lowest-index
        # centroid among equally close ones.
        nearest = scores.index(min(scores))
        clusters[nearest].append(point)
    return clusters
|
||
def recompute_centroids(clusters):
    """Return the mean point of every non-empty cluster.

    Empty clusters are skipped, so the result can have fewer rows than
    ``clusters`` has entries; rows keep the relative order of the clusters
    they came from.

    Args:
        clusters: List of clusters, each a list of points.

    Returns:
        NumPy array with one row per non-empty cluster.
    """
    means = []
    for members in clusters:
        if not members:
            continue
        means.append(np.mean(members, axis=0))
    return np.array(means)
|
||
def quantum_kmeans(data: np.ndarray, k: int, max_iters: int = 10) -> tuple:
    """Run k-means using ``quantum_distance`` as the assignment metric.

    Each iteration assigns every point to its closest centroid and then
    moves each centroid to the mean of its assigned points. Iteration stops
    early once the centroids no longer move (``np.allclose``).

    A centroid whose cluster ends up empty keeps its previous position, so
    the centroid array always has ``k`` rows and ``clusters[i]`` always
    corresponds to ``centroids[i]``.

    Args:
        data: NumPy array of points, one per row.
        k: Number of clusters.
        max_iters: Maximum number of assignment/update rounds (at least 1).

    Returns:
        A ``(centroids, clusters)`` tuple: the centroid array and a list of
        ``k`` lists holding the points assigned to each centroid in the
        last assignment round.

    Raises:
        ValueError: If ``max_iters`` is less than 1 (no assignment round
            would run, so there would be no clusters to return).
    """
    if max_iters < 1:
        raise ValueError("max_iters must be at least 1")

    centroids = initialize_centroids(data, k)

    for _ in range(max_iters):
        clusters = assign_clusters(data, centroids)

        # recompute_centroids drops empty clusters, which would shrink the
        # centroid array and shift indices. Update only the occupied slots
        # and leave the remaining centroids where they were.
        occupied = [i for i, members in enumerate(clusters) if members]
        new_centroids = np.array(centroids, dtype=float)
        if occupied:
            new_centroids[occupied] = recompute_centroids(
                [clusters[i] for i in occupied]
            )

        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids

    return centroids, clusters
|
||
# Main execution: build a small dataset, cluster it, and plot the input
# next to the clustering result.
n_samples, n_clusters = 10, 2
data, labels = generate_data(n_samples, n_clusters=n_clusters)

plt.figure(figsize=(12, 5))

# Left panel: the generated points coloured by their generated labels.
plt.subplot(121)
plt.scatter(data[:, 0], data[:, 1], c=labels)
plt.title("Generated Data")

final_centroids, final_clusters = quantum_kmeans(data, n_clusters)

# Right panel: points coloured by assigned cluster, plus the centroids.
plt.subplot(122)
for i, cluster in enumerate(final_clusters):
    if not cluster:
        # An empty cluster becomes a 1-D empty array, and indexing it with
        # [:, 0] raises IndexError; there is nothing to draw anyway.
        continue
    cluster_points = np.array(cluster)
    plt.scatter(cluster_points[:, 0], cluster_points[:, 1], label=f'Cluster {i+1}')
plt.scatter(
    final_centroids[:, 0],
    final_centroids[:, 1],
    color='red',
    marker='x',
    s=200,
    linewidths=3,
    label='Centroids',
)
plt.title("Quantum k-Means Clustering with Cirq")
plt.legend()

plt.tight_layout()
plt.show()

print(f"Final Centroids:\n{final_centroids}")
Uh oh!
There was an error while loading. Please reload this page.