Skip to content

Commit 03eed71

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent ca15518 commit 03eed71

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

neural_network/radial_basis_function_neural_network.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22

3+
34
class RadialBasisFunctionNeuralNetwork:
45
"""
56
A simple implementation of a Radial Basis Function Neural Network (RBFNN).
@@ -46,7 +47,9 @@ def _gaussian_rbf(self, input_vector: np.ndarray, center: np.ndarray) -> float:
4647
>>> rbf_nn._gaussian_rbf(np.array([0, 0]), center)
4748
0.1353352832366127
4849
"""
49-
return np.exp(-(np.linalg.norm(input_vector - center) ** 2) / (2 * self.spread ** 2))
50+
return np.exp(
51+
-(np.linalg.norm(input_vector - center) ** 2) / (2 * self.spread**2)
52+
)
5053

5154
def _compute_rbf_outputs(self, input_data: np.ndarray) -> np.ndarray:
5255
"""
@@ -91,10 +94,14 @@ def fit(self, input_data: np.ndarray, target_values: np.ndarray):
9194
True
9295
"""
9396
if input_data.shape[0] != target_values.shape[0]:
94-
raise ValueError("Number of samples in input_data and target_values must match.")
97+
raise ValueError(
98+
"Number of samples in input_data and target_values must match."
99+
)
95100

96101
# Initialize centers using random samples from input_data
97-
random_indices = np.random.choice(input_data.shape[0], self.num_centers, replace=False)
102+
random_indices = np.random.choice(
103+
input_data.shape[0], self.num_centers, replace=False
104+
)
98105
self.centers = input_data[random_indices]
99106

100107
# Compute the RBF outputs for the training data
@@ -128,6 +135,7 @@ def predict(self, input_data: np.ndarray) -> np.ndarray:
128135
# Example Usage
129136
if __name__ == "__main__":
130137
import doctest
138+
131139
doctest.testmod()
132140

133141
# Sample dataset for XOR problem

0 commit comments

Comments
 (0)