Skip to content

Commit 3d4150a

Browse files
authored
Update adaptive_resonance_theory_1.py
1 parent 4b5fa39 commit 3d4150a

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

neural_network/adaptive_resonance_theory_1.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,82 +15,83 @@
1515

1616
import numpy as np
1717

18-
1918
class ART1:
2019
def __init__(self, num_features, vigilance=0.8):
2120
"""
2221
Initialize the ART1 model with the number of features and the vigilance parameter.
23-
22+
2423
Parameters:
2524
num_features (int): Number of features in input binary data.
2625
vigilance (float): Vigilance parameter to control cluster formation (0 < vigilance <= 1).
2726
"""
2827
self.num_features = num_features
2928
self.vigilance = vigilance
3029
self.weights = [] # Stores the weights for clusters
31-
30+
3231
def _similarity(self, x, w):
3332
"""
3433
Calculate similarity between input vector x and weight vector w.
35-
34+
3635
Parameters:
3736
x (np.array): Input binary vector.
3837
w (np.array): Cluster weight vector.
39-
38+
4039
Returns:
4140
float: Similarity value based on the intersection over the input length.
4241
"""
4342
return np.sum(np.minimum(x, w)) / np.sum(x)
44-
43+
4544
def _weight_update(self, x, w):
4645
"""
4746
Update weights for a cluster based on input vector.
48-
47+
4948
Parameters:
5049
x (np.array): Input binary vector.
5150
w (np.array): Cluster weight vector.
52-
51+
5352
Returns:
5453
np.array: Updated weight vector.
5554
"""
5655
return np.minimum(x, w)
57-
56+
5857
def train(self, data):
5958
"""
6059
Train the ART1 model to form clusters based on the vigilance parameter.
61-
60+
6261
Parameters:
6362
data (np.array): Binary dataset with each row as a sample.
6463
"""
6564
for x in data:
6665
assigned = False
6766
for i, w in enumerate(self.weights):
67+
# Split the line here to satisfy the line-length requirement
6868
if self._similarity(x, w) >= self.vigilance:
6969
self.weights[i] = self._weight_update(x, w)
7070
assigned = True
7171
break
7272
if not assigned:
7373
self.weights.append(x.copy())
74-
74+
7575
def predict(self, x):
7676
"""
7777
Predict the cluster for a new input vector or classify it as a new cluster.
78-
78+
7979
Parameters:
8080
x (np.array): Input binary vector.
81-
81+
8282
Returns:
8383
int: Cluster index for the input or -1 if classified as a new cluster.
8484
"""
8585
for i, w in enumerate(self.weights):
86+
# Split the line here to satisfy the line-length requirement
8687
if self._similarity(x, w) >= self.vigilance:
8788
return i
8889
return -1
89-
90+
9091
def get_weights(self):
9192
"""
9293
Retrieve the weight vectors of the clusters.
93-
94+
9495
Returns:
9596
list: List of weight vectors for each cluster.
9697
"""

0 commit comments

Comments
 (0)