
Commit aced9cf

Update artificial_neural_network.py

Authored Oct 7, 2024
1 parent 9024013 · commit aced9cf

File tree

1 file changed: +33 −21 lines changed

neural_network/artificial_neural_network.py

Lines changed: 33 additions & 21 deletions
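
Beyond the docstring and naming cleanups, the main functional change swaps the legacy np.random.randn call for NumPy's Generator API. A minimal sketch of the equivalence, assuming NumPy >= 1.17 (the seed argument is added here only for illustration; the commit constructs the generator unseeded):

import numpy as np

rng = np.random.default_rng(seed=42)   # new-style Generator; seed shown only for reproducibility
weights = rng.standard_normal((2, 3))  # replaces the legacy np.random.randn(2, 3)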
@@ -1,18 +1,30 @@
 """
 Simple Artificial Neural Network (ANN)
 - Feedforward Neural Network with 1 hidden layer and Sigmoid activation.
-- Uses Gradient Descent for backpropagation and Mean Squared Error (MSE) as the loss function.
+- Uses Gradient Descent for backpropagation and Mean Squared Error (MSE)
+as the loss function.
 - Example demonstrates solving the XOR problem.
 """
 
 import numpy as np
 
 
 class ANN:
+    """
+    Artificial Neural Network (ANN)
+
+    - Feedforward Neural Network with 1 hidden layer
+    and Sigmoid activation.
+    - Uses Gradient Descent for backpropagation.
+    - Example demonstrates solving the XOR problem.
+    """
+
     def __init__(self, input_size, hidden_size, output_size, learning_rate=0.1):
-        # Initialize weights
-        self.weights_input_hidden = np.random.randn(input_size, hidden_size)
-        self.weights_hidden_output = np.random.randn(hidden_size, output_size)
+        # Initialize weights using np.random.Generator
+        rng = np.random.default_rng()
+        self.weights_input_hidden = rng.standard_normal((input_size, hidden_size))
+        self.weights_hidden_output = rng.standard_normal((hidden_size, output_size))
+
         # Initialize biases
         self.bias_hidden = np.zeros((1, hidden_size))
         self.bias_output = np.zeros((1, output_size))
@@ -21,54 +33,54 @@ def __init__(self, input_size, hidden_size, output_size, learning_rate=0.1):
         self.learning_rate = learning_rate
 
     def sigmoid(self, x):
+        """Sigmoid activation function."""
         return 1 / (1 + np.exp(-x))
 
     def sigmoid_derivative(self, x):
+        """Derivative of the sigmoid function."""
         return x * (1 - x)
 
-    def feedforward(self, X):
-        # Hidden layer
-        self.hidden_input = np.dot(X, self.weights_input_hidden) + self.bias_hidden
+    def feedforward(self, x):
+        """Forward pass."""
+        self.hidden_input = np.dot(x, self.weights_input_hidden) + self.bias_hidden
         self.hidden_output = self.sigmoid(self.hidden_input)
-
-        # Output layer
         self.final_input = (
             np.dot(self.hidden_output, self.weights_hidden_output) + self.bias_output
         )
         self.final_output = self.sigmoid(self.final_input)
-
         return self.final_output
 
-    def backpropagation(self, X, y, output):
-        # Calculate the error (Mean Squared Error)
+    def backpropagation(self, x, y, output):
+        """Backpropagation to adjust weights."""
         error = y - output
-        # Gradient for output layer
         output_gradient = error * self.sigmoid_derivative(output)
-        # Error for hidden layer
         hidden_error = output_gradient.dot(self.weights_hidden_output.T)
         hidden_gradient = hidden_error * self.sigmoid_derivative(self.hidden_output)
-        # Update weights and biases
+
         self.weights_hidden_output += (
             self.hidden_output.T.dot(output_gradient) * self.learning_rate
         )
         self.bias_output += (
             np.sum(output_gradient, axis=0, keepdims=True) * self.learning_rate
        )
-        self.weights_input_hidden += X.T.dot(hidden_gradient) * self.learning_rate
+
+        self.weights_input_hidden += x.T.dot(hidden_gradient) * self.learning_rate
         self.bias_hidden += (
             np.sum(hidden_gradient, axis=0, keepdims=True) * self.learning_rate
         )
 
-    def train(self, X, y, epochs=10000):
+    def train(self, x, y, epochs=10000):
+        """Train the network."""
         for epoch in range(epochs):
-            output = self.feedforward(X)
-            self.backpropagation(X, y, output)
+            output = self.feedforward(x)
+            self.backpropagation(x, y, output)
             if epoch % 1000 == 0:
                 loss = np.mean(np.square(y - output))
                 print(f"Epoch {epoch}, Loss: {loss}")
 
-    def predict(self, X):
-        return self.feedforward(X)
+    def predict(self, x):
+        """Make predictions."""
+        return self.feedforward(x)
 
 
 if __name__ == "__main__":
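
The diff's __main__ block is truncated above. For context, a hypothetical usage sketch of the XOR example the docstring describes; the arrays and hyperparameters here are illustrative assumptions, not the commit's actual main block:

import numpy as np

# Hypothetical XOR demo; assumes the ANN class exactly as shown in the diff.
x = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])  # XOR inputs
y = np.array([[0], [1], [1], [0]])              # XOR targets

ann = ANN(input_size=2, hidden_size=4, output_size=1, learning_rate=0.1)
ann.train(x, y, epochs=10000)
print(ann.predict(x))  # outputs should move toward [0, 1, 1, 0]

With a sigmoid output and MSE loss, training can occasionally stall in a poor local minimum on XOR; a wider hidden layer (4 units here) makes the sketch more likely to converge.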
