Skip to content

Commit 24666d3

Browse files
Adding a 3D plot to the k-means clustering algorithm
1 parent e3f3d66 commit 24666d3

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

machine_learning/k_means_clust.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,13 @@
3737
heterogeneity,
3838
k
3939
)
40-
5. Transfers Dataframe into excel format it must have feature called
40+
5. 3D Plot of the labeled data points with centroids.
41+
plot_kmeans(
42+
X,
43+
centroids,
44+
cluster_assignment
45+
)
46+
6. Transfers Dataframe into excel format it must have feature called
4147
'Clust' with k means clustering numbers in it.
4248
"""
4349

@@ -126,6 +132,19 @@ def plot_heterogeneity(heterogeneity, k):
126132
plt.show()
127133

128134

135+
def plot_kmeans(data, centroids, cluster_assignment):
136+
ax = plt.axes(projection="3d")
137+
ax.scatter(data[:, 0], data[:, 1], data[:, 2], c=cluster_assignment, cmap="viridis")
138+
ax.scatter(
139+
centroids[:, 0], centroids[:, 1], centroids[:, 2], c="red", s=100, marker="x"
140+
)
141+
ax.set_xlabel("X")
142+
ax.set_ylabel("Y")
143+
ax.set_zlabel("Z")
144+
ax.set_title("3D K-Means Clustering Visualization")
145+
plt.show()
146+
147+
129148
def kmeans(
130149
data, k, initial_centroids, maxiter=500, record_heterogeneity=None, verbose=False
131150
):
@@ -193,6 +212,7 @@ def kmeans(
193212
verbose=True,
194213
)
195214
plot_heterogeneity(heterogeneity, k)
215+
plot_kmeans(dataset["data"], centroids, cluster_assignment)
196216

197217

198218
def report_generator(

0 commit comments

Comments
 (0)