Skip to content

Commit f054733

Browse files
committed
modified code to meet contribution.md file guidelines
1 parent a2222f1 commit f054733

File tree

1 file changed

+97
-14
lines changed

1 file changed

+97
-14
lines changed

neural_network/lstm.py

+97-14
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,15 @@
1+
"""
2+
Name: LSTM - Long Short-Term Memory Network for Sequence Prediction
3+
Goal: Predict sequences of data
4+
Detail: A neural network with 3 layers in total
5+
* Input layer
6+
* LSTM layer
7+
* Output layer
8+
Author: Shashank Tyagi
9+
Github: LEVII007
10+
Date: [Current Date]
11+
"""
12+
113
##### Explanation #####
214
# This script implements a Long Short-Term Memory (LSTM) network to learn and predict sequences of characters.
315
# It uses numpy for numerical operations and tqdm for progress visualization.
@@ -22,14 +34,20 @@
2234
# The script initializes the LSTM network with specified hyperparameters and trains it on the input data.
2335
# Finally, it tests the trained network and prints the accuracy of the predictions.
2436

25-
##### Data #####
26-
2737
##### Imports #####
2838
from tqdm import tqdm
2939
import numpy as np
3040

3141
class LSTM:
32-
def __init__(self, data, hidden_dim=25, epochs=1000, lr=0.05):
42+
def __init__(self, data: str, hidden_dim: int = 25, epochs: int = 1000, lr: float = 0.05) -> None:
43+
"""
44+
Initialize the LSTM network with the given data and hyperparameters.
45+
46+
:param data: The input data as a string.
47+
:param hidden_dim: The number of hidden units in the LSTM layer.
48+
:param epochs: The number of training epochs.
49+
:param lr: The learning rate.
50+
"""
3351
self.data = data.lower()
3452
self.hidden_dim = hidden_dim
3553
self.epochs = epochs
@@ -48,12 +66,21 @@ def __init__(self, data, hidden_dim=25, epochs=1000, lr=0.05):
4866
self.initialize_weights()
4967

5068
##### Helper Functions #####
51-
def one_hot_encode(self, char):
69+
def one_hot_encode(self, char: str) -> np.ndarray:
70+
"""
71+
One-hot encode a character.
72+
73+
:param char: The character to encode.
74+
:return: A one-hot encoded vector.
75+
"""
5276
vector = np.zeros((self.char_size, 1))
5377
vector[self.char_to_idx[char]] = 1
5478
return vector
5579

56-
def initialize_weights(self):
80+
def initialize_weights(self) -> None:
81+
"""
82+
Initialize the weights and biases for the LSTM network.
83+
"""
5784
self.wf = self.init_weights(self.char_size + self.hidden_dim, self.hidden_dim)
5885
self.bf = np.zeros((self.hidden_dim, 1))
5986

@@ -69,26 +96,56 @@ def initialize_weights(self):
6996
self.wy = self.init_weights(self.hidden_dim, self.char_size)
7097
self.by = np.zeros((self.char_size, 1))
7198

72-
def init_weights(self, input_dim, output_dim):
99+
def init_weights(self, input_dim: int, output_dim: int) -> np.ndarray:
100+
"""
101+
Initialize weights with random values.
102+
103+
:param input_dim: The input dimension.
104+
:param output_dim: The output dimension.
105+
:return: A matrix of initialized weights.
106+
"""
73107
return np.random.uniform(-1, 1, (output_dim, input_dim)) * np.sqrt(6 / (input_dim + output_dim))
74108

75109
##### Activation Functions #####
76-
def sigmoid(self, x, derivative=False):
110+
def sigmoid(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
111+
"""
112+
Sigmoid activation function.
113+
114+
:param x: The input array; when derivative is True, x must already be a sigmoid output.
115+
:param derivative: Whether to compute the derivative.
116+
:return: The sigmoid activation or its derivative.
117+
"""
77118
if derivative:
78119
return x * (1 - x)
79120
return 1 / (1 + np.exp(-x))
80121

81-
def tanh(self, x, derivative=False):
122+
def tanh(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
123+
"""
124+
Tanh activation function.
125+
126+
:param x: The input array; when derivative is True, x must already be a tanh output.
127+
:param derivative: Whether to compute the derivative.
128+
:return: The tanh activation or its derivative.
129+
"""
82130
if derivative:
83131
return 1 - x ** 2
84132
return np.tanh(x)
85133

86-
def softmax(self, x):
134+
def softmax(self, x: np.ndarray) -> np.ndarray:
135+
"""
136+
Softmax activation function.
137+
138+
:param x: The input array.
139+
:return: The softmax activation.
140+
"""
87141
exp_x = np.exp(x - np.max(x))
88142
return exp_x / exp_x.sum(axis=0)
89143

90144
##### LSTM Network Methods #####
91-
def reset(self):
145+
def reset(self) -> None:
146+
"""
147+
Reset the LSTM network states.
148+
"""
92149
self.concat_inputs = {}
93150

94151
self.hidden_states = {-1: np.zeros((self.hidden_dim, 1))}
@@ -101,7 +158,13 @@ def reset(self):
101158
self.input_gates = {}
102159
self.outputs = {}
103160

104-
def forward(self, inputs):
161+
def forward(self, inputs: list) -> list:
162+
"""
163+
Perform forward propagation through the LSTM network.
164+
165+
:param inputs: The input data as a list of one-hot encoded vectors.
166+
:return: The outputs of the network.
167+
"""
105168
self.reset()
106169

107170
outputs = []
@@ -120,7 +183,13 @@ def forward(self, inputs):
120183

121184
return outputs
122185

123-
def backward(self, errors, inputs):
186+
def backward(self, errors: list, inputs: list) -> None:
187+
"""
188+
Perform backpropagation through time to compute gradients and update weights.
189+
190+
:param errors: The errors at each time step.
191+
:param inputs: The per-time-step inputs; train() passes self.concat_inputs (a dict), not a list of one-hot encoded vectors.
192+
"""
124193
d_wf, d_bf = 0, 0
125194
d_wi, d_bi = 0, 0
126195
d_wc, d_bc = 0, 0
@@ -186,7 +255,10 @@ def backward(self, errors, inputs):
186255
self.wy += d_wy * self.lr
187256
self.by += d_by * self.lr
188257

189-
def train(self):
258+
def train(self) -> None:
259+
"""
260+
Train the LSTM network on the input data.
261+
"""
190262
inputs = [self.one_hot_encode(char) for char in self.train_X]
191263

192264
for _ in tqdm(range(self.epochs)):
@@ -199,7 +271,10 @@ def train(self):
199271

200272
self.backward(errors, self.concat_inputs)
201273

202-
def test(self):
274+
def test(self) -> None:
275+
"""
276+
Test the trained LSTM network on the input data and print the accuracy.
277+
"""
203278
accuracy = 0
204279
probabilities = self.forward([self.one_hot_encode(char) for char in self.train_X])
205280

@@ -229,6 +304,14 @@ def test(self):
229304
##### Testing #####
230305
# lstm.test()
231306

307+
if __name__ == "__main__":
308+
# Initialize Network
309+
# lstm = LSTM(data=data, hidden_dim=25, epochs=1000, lr=0.05)
310+
311+
##### Training #####
312+
# lstm.train()
232313

314+
##### Testing #####
315+
# lstm.test()
233316

234317
# testing can be done by uncommenting the above lines of code.

0 commit comments

Comments
 (0)