Skip to content

Commit 115ac6b

Browse files
Update adaptive_resonance_theory.py
1 parent 1deb38b commit 115ac6b

File tree

1 file changed

+60
-55
lines changed

1 file changed

+60
-55
lines changed

neural_network/adaptive_resonance_theory.py

+60-55
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,36 @@
1-
import numpy as np
1+
"""
2+
adaptive_resonance_theory.py
3+
4+
This module implements the Adaptive Resonance Theory 1 (ART1) model, a type
5+
of neural network designed for unsupervised learning and clustering of binary
6+
input data. The ART1 algorithm continuously learns to categorize inputs based
7+
on their similarity while preserving previously learned categories. This is
8+
achieved through a vigilance parameter that controls the strictness of
9+
category matching, allowing for flexible and adaptive clustering.
10+
11+
ART1 is particularly useful in applications where it is critical to learn new
12+
patterns without forgetting previously learned ones, making it suitable for
13+
real-time data clustering and pattern recognition tasks.
14+
15+
References:
16+
1. Carpenter, G. A., & Grossberg, S. (1987). "Adaptive Resonance Theory."
17+
In: Neural Networks for Pattern Recognition, Oxford University Press, pp..
18+
2. Carpenter, G. A., & Grossberg, S. (1988). "The ART of Adaptive Pattern
19+
Recognition by a Self-Organizing Neural Network." IEEE Transactions on
20+
Neural Networks, 1(2) . DOI: 10.1109/TNN.1988.82656
21+
"""
222

23+
import numpy as np
324

425
class ART1:
526
"""
627
Adaptive Resonance Theory 1 (ART1) model for binary data clustering.
728
8-
This model is designed for unsupervised learning and clustering of binary
9-
input data. The ART1 algorithm continuously learns to categorize inputs based
10-
on their similarity while preserving previously learned categories. This is
11-
achieved through a vigilance parameter that controls the strictness of
12-
category matching, allowing for flexible and adaptive clustering.
13-
14-
ART1 is particularly useful in applications where it is critical to learn new
15-
patterns without forgetting previously learned ones, making it suitable for
16-
real-time data clustering and pattern recognition tasks.
17-
18-
References:
19-
1. Carpenter, G. A., & Grossberg, S. (1987). "A Adaptive Resonance Theory."
20-
In: Neural Networks for Pattern Recognition, Oxford University Press.
21-
2. Carpenter, G. A., & Grossberg, S. (1988). "The ART of Adaptive Pattern
22-
Recognition by a Self-Organizing Neural Network." IEEE Transactions on
23-
Neural Networks, 1(2). DOI: 10.1109/TNN.1988.82656
29+
Attributes:
30+
num_features (int): Number of features in the input data.
31+
vigilance (float): Threshold for similarity that determines whether
32+
an input matches an existing cluster.
33+
weights (list): List of cluster weights representing the learned categories.
2434
"""
2535

2636
def __init__(self, num_features: int, vigilance: float = 0.7) -> None:
@@ -32,7 +42,7 @@ def __init__(self, num_features: int, vigilance: float = 0.7) -> None:
3242
vigilance (float): Threshold for similarity (default is 0.7).
3343
3444
Raises:
35-
ValueError: If num_features not positive or vigilance not between 0 and 1.
45+
ValueError: If num_features is not positive or vigilance is not between 0 and 1.
3646
"""
3747
if num_features <= 0:
3848
raise ValueError("Number of features must be a positive integer.")
@@ -54,69 +64,64 @@ def _similarity(self, weight_vector: np.ndarray, input_vector: np.ndarray) -> fl
5464
Returns:
5565
float: The similarity score between the weight and the input.
5666
"""
57-
if (
58-
len(weight_vector) != self.num_features
59-
or len(input_vector) != self.num_features
60-
):
61-
raise ValueError(
62-
"Both weight_vector and input_vector must have the same number."
63-
)
67+
if len(weight_vector) != self.num_features or len(input_vector) != self.num_features:
68+
raise ValueError("Both weight_vector and input_vector must have the same number of features.")
6469

6570
return np.dot(weight_vector, input_vector) / self.num_features
6671

67-
def _learn(
68-
self, w: np.ndarray, x: np.ndarray, learning_rate: float = 0.5
69-
) -> np.ndarray:
72+
def _learn(self, current_weights: np.ndarray, input_vector: np.ndarray, learning_rate: float = 0.5) -> np.ndarray:
7073
"""
7174
Update cluster weights using the learning rate.
7275
7376
Args:
74-
w (np.ndarray): Current weight vector for the cluster.
75-
x (np.ndarray): Input vector.
77+
current_weights (np.ndarray): Current weight vector for the cluster.
78+
input_vector (np.ndarray): Input vector.
7679
learning_rate (float): Learning rate for weight update (default is 0.5).
7780
7881
Returns:
7982
np.ndarray: Updated weight vector.
83+
"""
84+
return learning_rate * input_vector + (1 - learning_rate) * current_weights
8085

81-
Examples:
82-
>>> model = ART1(num_features=4)
83-
>>> w = np.array([1, 1, 0, 0])
84-
>>> x = np.array([0, 1, 1, 0])
85-
>>> model._learn(w, x)
86-
array([0.5, 1. , 0.5, 0. ])
86+
def train(self, input_data: np.ndarray) -> None:
8787
"""
88-
return learning_rate * x + (1 - learning_rate) * w
88+
Train the ART1 model on the provided input data.
89+
90+
Args:
91+
input_data (np.ndarray): Array of input vectors to train on.
8992
90-
def predict(self, x: np.ndarray) -> int:
93+
Returns:
94+
None
95+
"""
96+
for input_vector in input_data:
97+
assigned_cluster_index = self.predict(input_vector)
98+
if assigned_cluster_index == -1:
99+
# No matching cluster, create a new one
100+
self.weights.append(input_vector)
101+
else:
102+
# Update the weights of the assigned cluster
103+
self.weights[assigned_cluster_index] = self._learn(self.weights[assigned_cluster_index], input_vector)
104+
105+
def predict(self, input_vector: np.ndarray) -> int:
91106
"""
92107
Assign data to the closest cluster.
93108
94109
Args:
95-
x (np.ndarray): Input vector.
110+
input_vector (np.ndarray): Input vector.
96111
97112
Returns:
98113
int: Index of the assigned cluster, or -1 if no match.
99-
100-
Examples:
101-
>>> model = ART1(num_features=4)
102-
>>> model.weights = [np.array([1, 1, 0, 0])]
103-
>>> model.predict(np.array([1, 1, 0, 0]))
104-
0
105-
>>> model.predict(np.array([0, 0, 0, 0]))
106-
-1
107114
"""
108-
similarities = [self._similarity(w, x) for w in self.weights]
109-
return (
110-
np.argmax(similarities) if max(similarities) >= self.vigilance else -1
111-
) # -1 if no match
115+
similarities = [self._similarity(weight, input_vector) for weight in self.weights]
116+
return np.argmax(similarities) if max(similarities) >= self.vigilance else -1 # -1 if no match
112117

113118

114119
# Example usage for ART1
115120
def art1_example() -> None:
116121
"""
117122
Example function demonstrating the usage of the ART1 model.
118123
119-
This function creates dataset, trains ART1 model, and prints assigned clusters.
124+
This function creates a dataset, trains the ART1 model, and prints assigned clusters.
120125
121126
Examples:
122127
>>> art1_example()
@@ -127,10 +132,10 @@ def art1_example() -> None:
127132
"""
128133
data = np.array([[1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1], [0, 1, 0, 1]])
129134
model = ART1(num_features=4, vigilance=0.5)
130-
# model.train(data) # Ensure this method is defined in ART1
135+
model.train(data)
131136

132-
for i, x in enumerate(data):
133-
cluster = model.predict(x)
137+
for i, input_vector in enumerate(data):
138+
cluster = model.predict(input_vector)
134139
print(f"Data point {i} assigned to cluster: {cluster}")
135140

136141

0 commit comments

Comments
 (0)