Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit a65dcc2

Browse files
committedOct 31, 2024·
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 3d4150a commit a65dcc2

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed
 

‎neural_network/adaptive_resonance_theory_1.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,49 +15,50 @@
1515

1616
import numpy as np
1717

18+
1819
class ART1:
1920
def __init__(self, num_features, vigilance=0.8):
2021
"""
2122
Initialize the ART1 model with the number of features and the vigilance parameter.
22-
23+
2324
Parameters:
2425
num_features (int): Number of features in input binary data.
2526
vigilance (float): Vigilance parameter to control cluster formation (0 < vigilance <= 1).
2627
"""
2728
self.num_features = num_features
2829
self.vigilance = vigilance
2930
self.weights = [] # Stores the weights for clusters
30-
31+
3132
def _similarity(self, x, w):
3233
"""
3334
Calculate similarity between input vector x and weight vector w.
34-
35+
3536
Parameters:
3637
x (np.array): Input binary vector.
3738
w (np.array): Cluster weight vector.
38-
39+
3940
Returns:
4041
float: Similarity value based on the intersection over the input length.
4142
"""
4243
return np.sum(np.minimum(x, w)) / np.sum(x)
43-
44+
4445
def _weight_update(self, x, w):
4546
"""
4647
Update weights for a cluster based on input vector.
47-
48+
4849
Parameters:
4950
x (np.array): Input binary vector.
5051
w (np.array): Cluster weight vector.
51-
52+
5253
Returns:
5354
np.array: Updated weight vector.
5455
"""
5556
return np.minimum(x, w)
56-
57+
5758
def train(self, data):
5859
"""
5960
Train the ART1 model to form clusters based on the vigilance parameter.
60-
61+
6162
Parameters:
6263
data (np.array): Binary dataset with each row as a sample.
6364
"""
@@ -71,14 +72,14 @@ def train(self, data):
7172
break
7273
if not assigned:
7374
self.weights.append(x.copy())
74-
75+
7576
def predict(self, x):
7677
"""
7778
Predict the cluster for a new input vector or classify it as a new cluster.
78-
79+
7980
Parameters:
8081
x (np.array): Input binary vector.
81-
82+
8283
Returns:
8384
int: Cluster index for the input or -1 if classified as a new cluster.
8485
"""
@@ -87,11 +88,11 @@ def predict(self, x):
8788
if self._similarity(x, w) >= self.vigilance:
8889
return i
8990
return -1
90-
91+
9192
def get_weights(self):
9293
"""
9394
Retrieve the weight vectors of the clusters.
94-
95+
9596
Returns:
9697
list: List of weight vectors for each cluster.
9798
"""

0 commit comments

Comments
 (0)
Please sign in to comment.