Skip to content

Commit 4b5fa39

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 6d8e719 commit 4b5fa39

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

neural_network/adaptive_resonance_theory_1.py

+18-16
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
- - - - - -- - - - - - - - - - - - - - - - - - - - - - -
2+
- - - - - -- - - - - - - - - - - - - - - - - - - - - - -
33
Name - - ART1 - Adaptive Resonance Theory 1
44
Goal - - Cluster Binary Data
55
Detail: Unsupervised clustering model using a vigilance parameter
@@ -10,53 +10,55 @@
1010
Author: Your Name
1111
1212
Date: 2024.10.31
13-
- - - - - -- - - - - - - - - - - - - - - - - - - - - - -
13+
- - - - - -- - - - - - - - - - - - - - - - - - - - - - -
1414
"""
15+
1516
import numpy as np
1617

18+
1719
class ART1:
1820
def __init__(self, num_features, vigilance=0.8):
1921
"""
2022
Initialize the ART1 model with the number of features and the vigilance parameter.
21-
23+
2224
Parameters:
2325
num_features (int): Number of features in input binary data.
2426
vigilance (float): Vigilance parameter to control cluster formation (0 < vigilance <= 1).
2527
"""
2628
self.num_features = num_features
2729
self.vigilance = vigilance
2830
self.weights = [] # Stores the weights for clusters
29-
31+
3032
def _similarity(self, x, w):
3133
"""
3234
Calculate similarity between input vector x and weight vector w.
33-
35+
3436
Parameters:
3537
x (np.array): Input binary vector.
3638
w (np.array): Cluster weight vector.
37-
39+
3840
Returns:
3941
float: Similarity value based on the intersection over the input length.
4042
"""
4143
return np.sum(np.minimum(x, w)) / np.sum(x)
42-
44+
4345
def _weight_update(self, x, w):
4446
"""
4547
Update weights for a cluster based on input vector.
46-
48+
4749
Parameters:
4850
x (np.array): Input binary vector.
4951
w (np.array): Cluster weight vector.
50-
52+
5153
Returns:
5254
np.array: Updated weight vector.
5355
"""
5456
return np.minimum(x, w)
55-
57+
5658
def train(self, data):
5759
"""
5860
Train the ART1 model to form clusters based on the vigilance parameter.
59-
61+
6062
Parameters:
6163
data (np.array): Binary dataset with each row as a sample.
6264
"""
@@ -69,26 +71,26 @@ def train(self, data):
6971
break
7072
if not assigned:
7173
self.weights.append(x.copy())
72-
74+
7375
def predict(self, x):
7476
"""
7577
Predict the cluster for a new input vector or classify it as a new cluster.
76-
78+
7779
Parameters:
7880
x (np.array): Input binary vector.
79-
81+
8082
Returns:
8183
int: Cluster index for the input or -1 if classified as a new cluster.
8284
"""
8385
for i, w in enumerate(self.weights):
8486
if self._similarity(x, w) >= self.vigilance:
8587
return i
8688
return -1
87-
89+
8890
def get_weights(self):
8991
"""
9092
Retrieve the weight vectors of the clusters.
91-
93+
9294
Returns:
9395
list: List of weight vectors for each cluster.
9496
"""

0 commit comments

Comments
 (0)