Skip to content

Commit bb7ac35

Browse files
Update radial_basis_function_neural_network.py
1 parent f66b55f commit bb7ac35

File tree

1 file changed

+24
-21
lines changed

1 file changed

+24
-21
lines changed

neural_network/radial_basis_function_neural_network.py

+24-21
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
from typing import List, Optional
32

43

54
class RadialBasisFunctionNeuralNetwork:
@@ -10,6 +9,7 @@ class RadialBasisFunctionNeuralNetwork:
109
centers (np.ndarray): Centers of the radial basis functions.
1110
weights (np.ndarray): Weights for the output layer.
1211
sigma (float): Spread of the radial basis functions.
12+
1313
Reference:
1414
Radial Basis Function Network: https://en.wikipedia.org/wiki/Radial_basis_function_network
1515
"""
@@ -24,8 +24,8 @@ def __init__(self, n_centers: int, sigma: float):
2424
"""
2525
self.n_centers = n_centers
2626
self.sigma = sigma
27-
self.centers: Optional[np.ndarray] = None # To be initialized during training
28-
self.weights: Optional[np.ndarray] = None # To be initialized during training
27+
self.centers: np.ndarray | None = None # To be initialized during training
28+
self.weights: np.ndarray | None = None # To be initialized during training
2929

3030
def _gaussian(self, x: np.ndarray, center: np.ndarray) -> float:
3131
"""
@@ -45,57 +45,57 @@ def _gaussian(self, x: np.ndarray, center: np.ndarray) -> float:
4545
"""
4646
return np.exp(-(np.linalg.norm(x - center) ** 2) / (2 * self.sigma**2))
4747

48-
def _compute_rbf(self, X: np.ndarray) -> np.ndarray:
48+
def _compute_rbf(self, x: np.ndarray) -> np.ndarray:
4949
"""
5050
Compute the output of the radial basis functions for input data.
5151
5252
Args:
53-
X (np.ndarray): Input data matrix (num_samples x num_features).
53+
x (np.ndarray): Input data matrix (num_samples x num_features).
5454
5555
Returns:
5656
np.ndarray: A matrix of shape (num_samples x n_centers) containing the RBF outputs.
5757
"""
58-
rbf_outputs = np.zeros((X.shape[0], self.n_centers))
58+
rbf_outputs = np.zeros((x.shape[0], self.n_centers))
5959
for i, center in enumerate(self.centers):
60-
for j in range(X.shape[0]):
61-
rbf_outputs[j, i] = self._gaussian(X[j], center)
60+
for j in range(x.shape[0]):
61+
rbf_outputs[j, i] = self._gaussian(x[j], center)
6262
return rbf_outputs
6363

64-
def fit(self, X: np.ndarray, y: np.ndarray):
64+
def fit(self, x: np.ndarray, y: np.ndarray):
6565
"""
6666
Train the RBFNN on the provided data.
6767
6868
Args:
69-
X (np.ndarray): Input data matrix (num_samples x num_features).
69+
x (np.ndarray): Input data matrix (num_samples x num_features).
7070
y (np.ndarray): Target values (num_samples x output_dim).
7171
7272
Raises:
73-
ValueError: If number of samples in X and y do not match.
73+
ValueError: If number of samples in x and y do not match.
7474
"""
75-
if X.shape[0] != y.shape[0]:
76-
raise ValueError("Number of samples in X and y must match.")
75+
if x.shape[0] != y.shape[0]:
76+
raise ValueError("Number of samples in x and y must match.")
7777

78-
# Initialize centers using random samples from X
79-
random_indices = np.random.choice(X.shape[0], self.n_centers, replace=False)
80-
self.centers = X[random_indices]
78+
# Initialize centers using random samples from x
79+
random_indices = np.random.choice(x.shape[0], self.n_centers, replace=False)
80+
self.centers = x[random_indices]
8181

8282
# Compute the RBF outputs for the training data
83-
rbf_outputs = self._compute_rbf(X)
83+
rbf_outputs = self._compute_rbf(x)
8484

8585
# Calculate weights using the pseudo-inverse
8686
self.weights = np.linalg.pinv(rbf_outputs).dot(y)
8787

88-
def predict(self, X: np.ndarray) -> np.ndarray:
88+
def predict(self, x: np.ndarray) -> np.ndarray:
8989
"""
9090
Predict the output for the given input data.
9191
9292
Args:
93-
X (np.ndarray): Input data matrix (num_samples x num_features).
93+
x (np.ndarray): Input data matrix (num_samples x num_features).
9494
9595
Returns:
9696
np.ndarray: Predicted values (num_samples x output_dim).
9797
"""
98-
rbf_outputs = self._compute_rbf(X)
98+
rbf_outputs = self._compute_rbf(x)
9999
return rbf_outputs.dot(self.weights)
100100

101101

@@ -113,9 +113,12 @@ def predict(self, X: np.ndarray) -> np.ndarray:
113113
predictions = rbf_nn.predict(X)
114114
print("Predictions:\n", predictions)
115115

116-
# Expected Output:
116+
# Sample Expected Output:
117117
# Predictions:
118118
# [[0.24826229]
119119
# [0.06598867]
120120
# [0.06598867]
121121
# [0.24826229]]
122+
123+
124+

0 commit comments

Comments
 (0)