Skip to content

Commit 49272c7

Browse files
Update radial_basis_function_neural_network.py
1 parent 52e01d1 commit 49272c7

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

neural_network/radial_basis_function_neural_network.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
"""
22
Radial Basis Function Neural Network (RBFNN)
33
4-
A Radial Basis Function Neural Network (RBFNN) is a type of artificial
5-
neural network that uses radial basis functions as activation functions.
6-
RBFNNs are particularly effective for function approximation, regression,
7-
and classification tasks.
4+
A Radial Basis Function Neural Network (RBFNN) is a type of artificial neural
5+
network that uses radial basis functions as activation functions.
6+
RBFNNs are particularly effective for function approximation, regression, and
7+
classification tasks. The architecture typically consists of an input layer,
8+
a hidden layer with radial basis functions, and an output layer.
9+
10+
In an RBFNN:
11+
- The hidden layer applies a radial basis function (often Gaussian) to the
12+
input data, transforming it into a higher-dimensional space.
13+
- The output layer combines the results from the hidden layer using
14+
weighted sums to produce the final output.
815
916
#### Reference
10-
1117
- Wikipedia: https://en.wikipedia.org/wiki/Radial_basis_function_network
1218
"""
1319

@@ -34,7 +40,7 @@ def __init__(self, num_centers: int, spread: float) -> None:
3440
spread (float): Spread of the radial basis functions.
3541
3642
Examples:
37-
>>> rbf_nn = RadialBasisFunctionNeuralNetwork(num_centers=3, spread=1.0)
43+
>>> rbf_nn = RadialBasisFunctionNeuralNetwork(num_centers=3,spread=1.0)
3844
>>> rbf_nn.num_centers
3945
3
4046
"""
@@ -61,7 +67,7 @@ def _gaussian_rbf(self, input_vector: np.ndarray, center: np.ndarray) -> float:
6167
0.1353352832366127
6268
"""
6369
# Calculate the squared distances
64-
distances = np.linalg.norm(input_data[:, np.newaxis] - centers, axis=2) ** 2
70+
distances = np.linalg.norm(input_vector[:, np.newaxis] - center, axis=2)** 2
6571
return np.exp(-distances / (2 * self.spread**2))
6672

6773
def _compute_rbf_outputs(self, input_data: np.ndarray) -> np.ndarray:

0 commit comments

Comments
 (0)