15
15
16
16
import numpy as np
17
17
18
+
18
19
class ART1 :
19
20
def __init__ (self , num_features , vigilance = 0.8 ):
20
21
"""
21
22
Initialize the ART1 model with the number of features and the vigilance parameter.
22
-
23
+
23
24
Parameters:
24
25
num_features (int): Number of features in input binary data.
25
26
vigilance (float): Vigilance parameter to control cluster formation (0 < vigilance <= 1).
26
27
"""
27
28
self .num_features = num_features
28
29
self .vigilance = vigilance
29
30
self .weights = [] # Stores the weights for clusters
30
-
31
+
31
32
def _similarity (self , x , w ):
32
33
"""
33
34
Calculate similarity between input vector x and weight vector w.
34
-
35
+
35
36
Parameters:
36
37
x (np.array): Input binary vector.
37
38
w (np.array): Cluster weight vector.
38
-
39
+
39
40
Returns:
40
41
float: Similarity value based on the intersection over the input length.
41
42
"""
42
43
return np .sum (np .minimum (x , w )) / np .sum (x )
43
-
44
+
44
45
def _weight_update (self , x , w ):
45
46
"""
46
47
Update weights for a cluster based on input vector.
47
-
48
+
48
49
Parameters:
49
50
x (np.array): Input binary vector.
50
51
w (np.array): Cluster weight vector.
51
-
52
+
52
53
Returns:
53
54
np.array: Updated weight vector.
54
55
"""
55
56
return np .minimum (x , w )
56
-
57
+
57
58
def train (self , data ):
58
59
"""
59
60
Train the ART1 model to form clusters based on the vigilance parameter.
60
-
61
+
61
62
Parameters:
62
63
data (np.array): Binary dataset with each row as a sample.
63
64
"""
@@ -71,14 +72,14 @@ def train(self, data):
71
72
break
72
73
if not assigned :
73
74
self .weights .append (x .copy ())
74
-
75
+
75
76
def predict (self , x ):
76
77
"""
77
78
Predict the cluster for a new input vector or classify it as a new cluster.
78
-
79
+
79
80
Parameters:
80
81
x (np.array): Input binary vector.
81
-
82
+
82
83
Returns:
83
84
int: Cluster index for the input or -1 if classified as a new cluster.
84
85
"""
@@ -87,11 +88,11 @@ def predict(self, x):
87
88
if self ._similarity (x , w ) >= self .vigilance :
88
89
return i
89
90
return - 1
90
-
91
+
91
92
def get_weights (self ):
92
93
"""
93
94
Retrieve the weight vectors of the clusters.
94
-
95
+
95
96
Returns:
96
97
list: List of weight vectors for each cluster.
97
98
"""
0 commit comments