Skip to content

Commit fc3d7dd

Browse files
Update adaptive_resonance_theory.py
1 parent 241981a commit fc3d7dd

File tree

1 file changed

+18
-51
lines changed

1 file changed

+18
-51
lines changed

neural_network/adaptive_resonance_theory.py

+18-51
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,7 @@ class ART1:
55
"""
66
Adaptive Resonance Theory 1 (ART1) model for binary data clustering.
77
8-
The ART1 algorithm is a type of neural network used for unsupervised
9-
learning and clustering of binary input data. It continuously learns
10-
to categorize inputs based on similarity while preserving previously
11-
learned categories. The vigilance parameter controls the degree of
12-
similarity required to assign an input to an existing category,
13-
allowing for flexible and adaptive clustering.
8+
...
149
1510
Attributes:
1611
num_features (int): Number of features in the input data.
@@ -26,62 +21,34 @@ def __init__(self, num_features: int, vigilance: float = 0.7) -> None:
2621
Args:
2722
num_features (int): Number of features in the input data.
2823
vigilance (float): Threshold for similarity (default is 0.7).
29-
30-
Examples:
31-
>>> model = ART1(num_features=4, vigilance=0.5)
32-
>>> model.num_features
33-
4
34-
>>> model.vigilance
35-
0.5
24+
25+
Raises:
26+
ValueError: If num_features is not positive or if vigilance is not between 0 and 1.
3627
"""
37-
self.vigilance = vigilance # Controls cluster strictness
28+
if num_features <= 0:
29+
raise ValueError("Number of features must be a positive integer.")
30+
if not (0 <= vigilance <= 1):
31+
raise ValueError("Vigilance parameter must be between 0 and 1.")
32+
33+
self.vigilance = vigilance
3834
self.num_features = num_features
39-
self.weights = [] # List of cluster weights
40-
41-
def train(self, data: np.ndarray) -> None:
42-
"""
43-
Train the ART1 model on the provided data.
35+
self.weights = []
4436

45-
Args:
46-
data (np.ndarray): A 2D array of binary input data (num_samples x num_features).
47-
48-
Examples:
49-
>>> model = ART1(num_features=4, vigilance=0.5)
50-
>>> data = np.array([[1, 1, 0, 0], [1, 1, 1, 0]])
51-
>>> model.train(data)
52-
>>> len(model.weights)
53-
2
54-
"""
55-
for x in data:
56-
match = False
57-
for i, w in enumerate(self.weights):
58-
if self._similarity(w, x) >= self.vigilance:
59-
self.weights[i] = self._learn(w, x)
60-
match = True
61-
break
62-
if not match:
63-
self.weights.append(x.copy()) # Add a new cluster
64-
65-
def _similarity(self, w: np.ndarray, x: np.ndarray) -> float:
37+
def _similarity(self, weight_vector: np.ndarray, input_vector: np.ndarray) -> float:
6638
"""
6739
Calculate similarity between weight and input.
6840
6941
Args:
70-
w (np.ndarray): Weight vector representing a cluster.
71-
x (np.ndarray): Input vector.
42+
weight_vector (np.ndarray): Weight vector representing a cluster.
43+
input_vector (np.ndarray): Input vector.
7244
7345
Returns:
7446
float: The similarity score between the weight and the input.
75-
76-
Examples:
77-
>>> model = ART1(num_features=4)
78-
>>> w = np.array([1, 1, 0, 0])
79-
>>> x = np.array([1, 0, 0, 0])
80-
>>> model._similarity(w, x)
81-
0.25
8247
"""
83-
return np.dot(w, x) / (self.num_features)
84-
48+
if len(weight_vector) != self.num_features or len(input_vector) != self.num_features:
49+
raise ValueError(f"Both weight_vector and input_vector must have {self.num_features} features.")
50+
51+
return np.dot(weight_vector, input_vector) / self.num_features
8552
def _learn(
8653
self, w: np.ndarray, x: np.ndarray, learning_rate: float = 0.5
8754
) -> np.ndarray:

0 commit comments

Comments
 (0)