Skip to content

Commit c47d8e3

Browse files
adjust sorting libraries radial_basis_function_neural_network.py
1 parent ae74131 commit c47d8e3

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

neural_network/radial_basis_function_neural_network.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
22

3-
43
class RadialBasisFunctionNeuralNetwork:
54
"""
65
A simple implementation of a Radial Basis Function Neural Network (RBFNN).
@@ -32,11 +31,11 @@ def __init__(self, num_centers: int, spread: float) -> None:
3231

3332
def _gaussian_rbf(self, input_vector: np.ndarray, center: np.ndarray) -> float:
3433
"""
35-
Calculate the Gaussian radial basis function output for a given input vector and center.
34+
Calculate Gaussian radial basis function output for input vector and center.
3635
3736
Args:
38-
input_vector (np.ndarray): The input vector for which to calculate the RBF output.
39-
center (np.ndarray): The center of the radial basis function.
37+
input_vector (np.ndarray): Input vector for which to calculate the RBF output.
38+
center (np.ndarray): Center of the radial basis function.
4039
4140
Returns:
4241
float: The output of the radial basis function evaluated at the input vector.
@@ -48,7 +47,7 @@ def _gaussian_rbf(self, input_vector: np.ndarray, center: np.ndarray) -> float:
4847
0.1353352832366127
4948
"""
5049
return np.exp(
51-
-(np.linalg.norm(input_vector - center) ** 2) / (2 * self.spread**2)
50+
-(np.linalg.norm(input_vector - center) ** 2) / (2 * self.spread ** 2)
5251
)
5352

5453
def _compute_rbf_outputs(self, input_data: np.ndarray) -> np.ndarray:
@@ -59,7 +58,7 @@ def _compute_rbf_outputs(self, input_data: np.ndarray) -> np.ndarray:
5958
input_data (np.ndarray): Input data matrix (num_samples x num_features).
6059
6160
Returns:
62-
np.ndarray: A matrix of shape (num_samples x num_centers) containing the RBF outputs.
61+
np.ndarray: A matrix of shape (num_samples x num_centers) with RBF outputs.
6362
6463
Examples:
6564
>>> rbf_nn = RadialBasisFunctionNeuralNetwork(num_centers=2, spread=1.0)
@@ -83,7 +82,7 @@ def fit(self, input_data: np.ndarray, target_values: np.ndarray) -> None:
8382
target_values (np.ndarray): Target values (num_samples x output_dim).
8483
8584
Raises:
86-
ValueError: If the number of samples in input_data and target_values do not match.
85+
ValueError: If number of samples in input_data and target_values not match.
8786
8887
Examples:
8988
>>> rbf_nn = RadialBasisFunctionNeuralNetwork(num_centers=2, spread=1.0)
@@ -136,7 +135,6 @@ def predict(self, input_data: np.ndarray) -> np.ndarray:
136135
# Example Usage
137136
if __name__ == "__main__":
138137
import doctest
139-
140138
doctest.testmod()
141139

142140
# Sample dataset for XOR problem

0 commit comments

Comments
 (0)