Skip to content

Commit 241981a

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent d9a0134 commit 241981a

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

neural_network/adaptive_resonance_theory.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,32 @@
11
import numpy as np
22

3+
34
class ART1:
45
"""
56
Adaptive Resonance Theory 1 (ART1) model for binary data clustering.
67
7-
The ART1 algorithm is a type of neural network used for unsupervised
8-
learning and clustering of binary input data. It continuously learns
9-
to categorize inputs based on similarity while preserving previously
10-
learned categories. The vigilance parameter controls the degree of
8+
The ART1 algorithm is a type of neural network used for unsupervised
9+
learning and clustering of binary input data. It continuously learns
10+
to categorize inputs based on similarity while preserving previously
11+
learned categories. The vigilance parameter controls the degree of
1112
similarity required to assign an input to an existing category,
1213
allowing for flexible and adaptive clustering.
1314
1415
Attributes:
1516
num_features (int): Number of features in the input data.
16-
vigilance (float): Threshold for similarity that determines whether
17+
vigilance (float): Threshold for similarity that determines whether
1718
an input matches an existing cluster.
1819
weights (list): List of cluster weights representing the learned categories.
1920
"""
20-
21+
2122
def __init__(self, num_features: int, vigilance: float = 0.7) -> None:
2223
"""
2324
Initialize the ART1 model with the given number of features and vigilance parameter.
2425
2526
Args:
2627
num_features (int): Number of features in the input data.
2728
vigilance (float): Threshold for similarity (default is 0.7).
28-
29+
2930
Examples:
3031
>>> model = ART1(num_features=4, vigilance=0.5)
3132
>>> model.num_features
@@ -35,8 +36,8 @@ def __init__(self, num_features: int, vigilance: float = 0.7) -> None:
3536
"""
3637
self.vigilance = vigilance # Controls cluster strictness
3738
self.num_features = num_features
38-
self.weights = [] # List of cluster weights
39-
39+
self.weights = [] # List of cluster weights
40+
4041
def train(self, data: np.ndarray) -> None:
4142
"""
4243
Train the ART1 model on the provided data.
@@ -80,8 +81,10 @@ def _similarity(self, w: np.ndarray, x: np.ndarray) -> float:
8081
0.25
8182
"""
8283
return np.dot(w, x) / (self.num_features)
83-
84-
def _learn(self, w: np.ndarray, x: np.ndarray, learning_rate: float = 0.5) -> np.ndarray:
84+
85+
def _learn(
86+
self, w: np.ndarray, x: np.ndarray, learning_rate: float = 0.5
87+
) -> np.ndarray:
8588
"""
8689
Update cluster weights using the learning rate.
8790
@@ -121,7 +124,9 @@ def predict(self, x: np.ndarray) -> int:
121124
-1
122125
"""
123126
similarities = [self._similarity(w, x) for w in self.weights]
124-
return np.argmax(similarities) if max(similarities) >= self.vigilance else -1 # -1 if no match
127+
return (
128+
np.argmax(similarities) if max(similarities) >= self.vigilance else -1
129+
) # -1 if no match
125130

126131

127132
# Example usage for ART1
@@ -146,5 +151,6 @@ def art1_example() -> None:
146151
cluster = model.predict(x)
147152
print(f"Data point {i} assigned to cluster: {cluster}")
148153

154+
149155
if __name__ == "__main__":
150156
art1_example()

0 commit comments

Comments
 (0)