
Commit 39fd713

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 4c2ec80 commit 39fd713

1 file changed: neural_network/lstm.py (+106 −71 lines)
@@ -11,45 +11,47 @@
 """
 
 ##### Explanation #####
-# This script implements a Long Short-Term Memory (LSTM) network to learn
+# This script implements a Long Short-Term Memory (LSTM) network to learn
 # and predict sequences of characters.
 # It uses numpy for numerical operations and tqdm for progress visualization.
 
-# The data is a paragraph about LSTM, converted to lowercase and split into
+# The data is a paragraph about LSTM, converted to lowercase and split into
 # characters. Each character is one-hot encoded for training.
 
-# The LSTM class initializes weights and biases for the forget, input, candidate,
+# The LSTM class initializes weights and biases for the forget, input, candidate,
 # and output gates. It also initializes weights and biases for the final output layer.
 
-# The forward method performs forward propagation through the LSTM network,
-# computing hidden and cell states. It uses sigmoid and tanh activation
+# The forward method performs forward propagation through the LSTM network,
+# computing hidden and cell states. It uses sigmoid and tanh activation
 # functions for the gates and cell states.
 
-# The backward method performs backpropagation through time, computing gradients
-# for the weights and biases. It updates the weights and biases using
+# The backward method performs backpropagation through time, computing gradients
+# for the weights and biases. It updates the weights and biases using
 # the computed gradients and the learning rate.
 
-# The train method trains the LSTM network on the input data for a specified
-# number of epochs. It uses one-hot encoded inputs and computes errors
+# The train method trains the LSTM network on the input data for a specified
+# number of epochs. It uses one-hot encoded inputs and computes errors
 # using the softmax function.
 
-# The test method evaluates the trained LSTM network on the input data,
+# The test method evaluates the trained LSTM network on the input data,
 # computing accuracy based on predictions.
 
-# The script initializes the LSTM network with specified hyperparameters
-# and trains it on the input data. Finally, it tests the trained network
+# The script initializes the LSTM network with specified hyperparameters
+# and trains it on the input data. Finally, it tests the trained network
 # and prints the accuracy of the predictions.
 
 ##### Imports #####
 from tqdm import tqdm
 import numpy as np
 
+
 class LSTM:
-    def __init__(self, data: str, hidden_dim: int = 25,
-                 epochs: int = 1000, lr: float = 0.05) -> None:
+    def __init__(
+        self, data: str, hidden_dim: int = 25, epochs: int = 1000, lr: float = 0.05
+    ) -> None:
         """
         Initialize the LSTM network with the given data and hyperparameters.
-
+
         :param data: The input data as a string.
         :param hidden_dim: The number of hidden units in the LSTM layer.
         :param epochs: The number of training epochs.
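The signature change above is formatting-only; the constructor's interface is unchanged. As a point of reference, a minimal usage sketch of that interface, assuming the script is importable as a module named `lstm` (the training string below is a stand-in for the LSTM paragraph the script actually uses):

```python
# Hypothetical usage sketch; the import path and sample text are assumptions.
from lstm import LSTM

sample_text = "long short-term memory networks learn character sequences"
network = LSTM(data=sample_text, hidden_dim=25, epochs=1000, lr=0.05)
network.train()  # one-hot encodes the characters and runs BPTT each epoch
network.test()   # prints prediction accuracy over the training sequence
```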
@@ -63,7 +65,7 @@ def __init__(self, data: str, hidden_dim: int = 25,
         self.chars = set(self.data)
         self.data_size, self.char_size = len(self.data), len(self.chars)
 
-        print(f'Data size: {self.data_size}, Char Size: {self.char_size}')
+        print(f"Data size: {self.data_size}, Char Size: {self.char_size}")
 
         self.char_to_idx = {c: i for i, c in enumerate(self.chars)}
         self.idx_to_char = {i: c for i, c in enumerate(self.chars)}
@@ -76,7 +78,7 @@ def __init__(self, data: str, hidden_dim: int = 25,
     def one_hot_encode(self, char: str) -> np.ndarray:
         """
         One-hot encode a character.
-
+
         :param char: The character to encode.
         :return: A one-hot encoded vector.
         """
@@ -88,20 +90,16 @@ def initialize_weights(self) -> None:
         """
         Initialize the weights and biases for the LSTM network.
         """
-        self.wf = self.init_weights(self.char_size + self.hidden_dim,
-                                    self.hidden_dim)
+        self.wf = self.init_weights(self.char_size + self.hidden_dim, self.hidden_dim)
         self.bf = np.zeros((self.hidden_dim, 1))
 
-        self.wi = self.init_weights(self.char_size + self.hidden_dim,
-                                    self.hidden_dim)
+        self.wi = self.init_weights(self.char_size + self.hidden_dim, self.hidden_dim)
         self.bi = np.zeros((self.hidden_dim, 1))
 
-        self.wc = self.init_weights(self.char_size + self.hidden_dim,
-                                    self.hidden_dim)
+        self.wc = self.init_weights(self.char_size + self.hidden_dim, self.hidden_dim)
         self.bc = np.zeros((self.hidden_dim, 1))
 
-        self.wo = self.init_weights(self.char_size + self.hidden_dim,
-                                    self.hidden_dim)
+        self.wo = self.init_weights(self.char_size + self.hidden_dim, self.hidden_dim)
         self.bo = np.zeros((self.hidden_dim, 1))
 
         self.wy = self.init_weights(self.hidden_dim, self.char_size)
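A quick shape check clarifies why each gate's weight matrix is built with `char_size + hidden_dim` input columns: every gate multiplies the concatenation of the previous hidden state and the current one-hot input. The sizes below are illustrative, not taken from the commit:

```python
import numpy as np

hidden_dim, char_size = 25, 40  # assumed sizes for illustration
h_prev = np.zeros((hidden_dim, 1))
x_t = np.zeros((char_size, 1))
concat = np.concatenate((h_prev, x_t))  # shape (char_size + hidden_dim, 1)

wf = np.random.uniform(-1, 1, (hidden_dim, char_size + hidden_dim))
bf = np.zeros((hidden_dim, 1))
# One pre-activation value per hidden unit:
assert (np.dot(wf, concat) + bf).shape == (hidden_dim, 1)
```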
@@ -110,19 +108,20 @@ def initialize_weights(self) -> None:
     def init_weights(self, input_dim: int, output_dim: int) -> np.ndarray:
         """
         Initialize weights with random values.
-
+
         :param input_dim: The input dimension.
         :param output_dim: The output dimension.
         :return: A matrix of initialized weights.
         """
-        return np.random.uniform(-1, 1, (output_dim, input_dim)) * \
-            np.sqrt(6 / (input_dim + output_dim))
+        return np.random.uniform(-1, 1, (output_dim, input_dim)) * np.sqrt(
+            6 / (input_dim + output_dim)
+        )
 
     ##### Activation Functions #####
     def sigmoid(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
         """
         Sigmoid activation function.
-
+
         :param x: The input array.
         :param derivative: Whether to compute the derivative.
         :return: The sigmoid activation or its derivative.
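The reformatted `init_weights` keeps the original sampling rule intact; it resembles a Xavier/Glorot-style uniform initialization, scaled by the fan-in and fan-out of the layer:

```latex
W \sim \mathcal{U}(-1,\, 1) \cdot \sqrt{\frac{6}{n_{\mathrm{in}} + n_{\mathrm{out}}}}
```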
@@ -134,19 +133,19 @@ def sigmoid(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
     def tanh(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
         """
         Tanh activation function.
-
+
         :param x: The input array.
         :param derivative: Whether to compute the derivative.
         :return: The tanh activation or its derivative.
         """
         if derivative:
-            return 1 - x ** 2
+            return 1 - x**2
         return np.tanh(x)
 
     def softmax(self, x: np.ndarray) -> np.ndarray:
         """
         Softmax activation function.
-
+
         :param x: The input array.
         :return: The softmax activation.
         """
@@ -173,7 +172,7 @@ def reset(self) -> None:
     def forward(self, inputs: list) -> list:
         """
         Perform forward propagation through the LSTM network.
-
+
         :param inputs: The input data as a list of one-hot encoded vectors.
         :return: The outputs of the network.
         """
@@ -182,21 +181,29 @@ def forward(self, inputs: list) -> list:
         outputs = []
         for t in range(len(inputs)):
             self.concat_inputs[t] = np.concatenate(
-                (self.hidden_states[t - 1], inputs[t]))
-
-            self.forget_gates[t] = self.sigmoid(np.dot(self.wf,
-                                                       self.concat_inputs[t]) + self.bf)
-            self.input_gates[t] = self.sigmoid(np.dot(self.wi,
-                                                      self.concat_inputs[t]) + self.bi)
-            self.candidate_gates[t] = self.tanh(np.dot(self.wc,
-                                                       self.concat_inputs[t]) + self.bc)
-            self.output_gates[t] = self.sigmoid(np.dot(self.wo,
-                                                       self.concat_inputs[t]) + self.bo)
-
-            self.cell_states[t] = self.forget_gates[t] * self.cell_states[t - 1] + \
-                self.input_gates[t] * self.candidate_gates[t]
-            self.hidden_states[t] = self.output_gates[t] * \
-                self.tanh(self.cell_states[t])
+                (self.hidden_states[t - 1], inputs[t])
+            )
+
+            self.forget_gates[t] = self.sigmoid(
+                np.dot(self.wf, self.concat_inputs[t]) + self.bf
+            )
+            self.input_gates[t] = self.sigmoid(
+                np.dot(self.wi, self.concat_inputs[t]) + self.bi
+            )
+            self.candidate_gates[t] = self.tanh(
+                np.dot(self.wc, self.concat_inputs[t]) + self.bc
+            )
+            self.output_gates[t] = self.sigmoid(
+                np.dot(self.wo, self.concat_inputs[t]) + self.bo
+            )
+
+            self.cell_states[t] = (
+                self.forget_gates[t] * self.cell_states[t - 1]
+                + self.input_gates[t] * self.candidate_gates[t]
+            )
+            self.hidden_states[t] = self.output_gates[t] * self.tanh(
+                self.cell_states[t]
+            )
 
             outputs.append(np.dot(self.wy, self.hidden_states[t]) + self.by)
 
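Despite the reformatting, the loop is still a direct transcription of the standard LSTM recurrences, with [h_{t-1}; x_t] denoting the concatenated previous hidden state and current input:

```latex
\begin{aligned}
f_t &= \sigma(W_f [h_{t-1}; x_t] + b_f), &
i_t &= \sigma(W_i [h_{t-1}; x_t] + b_i), \\
\tilde{c}_t &= \tanh(W_c [h_{t-1}; x_t] + b_c), &
o_t &= \sigma(W_o [h_{t-1}; x_t] + b_o), \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t, &
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```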
@@ -205,7 +212,7 @@ def forward(self, inputs: list) -> list:
     def backward(self, errors: list, inputs: list) -> None:
         """
         Perform backpropagation through time to compute gradients and update weights.
-
+
         :param errors: The errors at each time step.
         :param inputs: The input data as a list of one-hot encoded vectors.
         """
@@ -215,8 +222,10 @@ def backward(self, errors: list, inputs: list) -> None:
         d_wo, d_bo = 0, 0
         d_wy, d_by = 0, 0
 
-        dh_next, dc_next = np.zeros_like(self.hidden_states[0]), \
-            np.zeros_like(self.cell_states[0])
+        dh_next, dc_next = (
+            np.zeros_like(self.hidden_states[0]),
+            np.zeros_like(self.cell_states[0]),
+        )
         for t in reversed(range(len(inputs))):
             error = errors[t]
 
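The zero initializations seed the backward recursion: no gradient flows in from beyond the final time step, and each iteration of the reversed loop then combines the output-layer error with the carried-over hidden-state gradient, matching `d_hs = np.dot(self.wy.T, error) + dh_next` in the next hunk:

```latex
\delta h^{(t)} = W_y^{\top} e^{(t)} + \delta h_{\mathrm{next}},
\qquad \delta h_{\mathrm{next}} = \delta c_{\mathrm{next}} = 0 \ \text{at}\ t = T
```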
@@ -228,45 +237,69 @@ def backward(self, errors: list, inputs: list) -> None:
             d_hs = np.dot(self.wy.T, error) + dh_next
 
             # Output Gate Weights and Biases Errors
-            d_o = self.tanh(self.cell_states[t]) * d_hs * \
-                self.sigmoid(self.output_gates[t], derivative=True)
+            d_o = (
+                self.tanh(self.cell_states[t])
+                * d_hs
+                * self.sigmoid(self.output_gates[t], derivative=True)
+            )
             d_wo += np.dot(d_o, inputs[t].T)
             d_bo += d_o
 
             # Cell State Error
-            d_cs = self.tanh(self.tanh(self.cell_states[t]),
-                             derivative=True) * self.output_gates[t] * d_hs + dc_next
+            d_cs = (
+                self.tanh(self.tanh(self.cell_states[t]), derivative=True)
+                * self.output_gates[t]
+                * d_hs
+                + dc_next
+            )
 
             # Forget Gate Weights and Biases Errors
-            d_f = d_cs * self.cell_states[t - 1] * \
-                self.sigmoid(self.forget_gates[t], derivative=True)
+            d_f = (
+                d_cs
+                * self.cell_states[t - 1]
+                * self.sigmoid(self.forget_gates[t], derivative=True)
+            )
             d_wf += np.dot(d_f, inputs[t].T)
             d_bf += d_f
 
             # Input Gate Weights and Biases Errors
-            d_i = d_cs * self.candidate_gates[t] * \
-                self.sigmoid(self.input_gates[t], derivative=True)
+            d_i = (
+                d_cs
+                * self.candidate_gates[t]
+                * self.sigmoid(self.input_gates[t], derivative=True)
+            )
             d_wi += np.dot(d_i, inputs[t].T)
             d_bi += d_i
 
             # Candidate Gate Weights and Biases Errors
-            d_c = d_cs * self.input_gates[t] * self.tanh(self.candidate_gates[t],
-                                                         derivative=True)
+            d_c = (
+                d_cs
+                * self.input_gates[t]
+                * self.tanh(self.candidate_gates[t], derivative=True)
+            )
             d_wc += np.dot(d_c, inputs[t].T)
             d_bc += d_c
 
             # Update the next hidden and cell state errors
-            dh_next = np.dot(self.wf.T, d_f) + np.dot(self.wi.T, d_i) + \
-                np.dot(self.wo.T, d_o) + np.dot(self.wc.T, d_c)
+            dh_next = (
+                np.dot(self.wf.T, d_f)
+                + np.dot(self.wi.T, d_i)
+                + np.dot(self.wo.T, d_o)
+                + np.dot(self.wc.T, d_c)
+            )
             dc_next = d_cs * self.forget_gates[t]
 
         # Apply gradients to weights and biases
-        for param, grad in zip([self.wf, self.wi, self.wc, self.wo, self.wy],
-                               [d_wf, d_wi, d_wc, d_wo, d_wy]):
+        for param, grad in zip(
+            [self.wf, self.wi, self.wc, self.wo, self.wy],
+            [d_wf, d_wi, d_wc, d_wo, d_wy],
+        ):
             param -= self.lr * grad
 
-        for param, grad in zip([self.bf, self.bi, self.bc, self.bo, self.by],
-                               [d_bf, d_bi, d_bc, d_bo, d_by]):
+        for param, grad in zip(
+            [self.bf, self.bi, self.bc, self.bo, self.by],
+            [d_bf, d_bi, d_bc, d_bo, d_by],
+        ):
             param -= self.lr * grad
 
     def train(self) -> None:
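In gate notation, the reformatted expressions are the usual truncated-BPTT gradients, assuming `sigmoid(..., derivative=True)` computes x(1 − x) on the cached post-activation value, consistent with the `tanh` convention shown earlier:

```latex
\begin{aligned}
\delta o_t &= \tanh(c_t) \odot \delta h_t \odot o_t (1 - o_t) \\
\delta c_t &= \delta h_t \odot o_t \odot \bigl(1 - \tanh^2(c_t)\bigr) + \delta c_{t+1} \\
\delta f_t &= \delta c_t \odot c_{t-1} \odot f_t (1 - f_t) \\
\delta i_t &= \delta c_t \odot \tilde{c}_t \odot i_t (1 - i_t) \\
\delta \tilde{c}_t &= \delta c_t \odot i_t \odot \bigl(1 - \tilde{c}_t^{\,2}\bigr) \\
\delta h_{t-1} &= W_f^{\top} \delta f_t + W_i^{\top} \delta i_t
  + W_o^{\top} \delta o_t + W_c^{\top} \delta \tilde{c}_t,
\qquad \delta c_{t-1} = \delta c_t \odot f_t
\end{aligned}
```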
@@ -289,7 +322,7 @@ def train(self) -> None:
     def predict(self, inputs: list) -> str:
         """
         Predict the next character in the sequence.
-
+
         :param inputs: The input data as a list of one-hot encoded vectors.
         :return: The predicted character.
         """
@@ -301,11 +334,13 @@ def test(self) -> None:
         Test the LSTM network on the input data and compute accuracy.
         """
         inputs = [self.one_hot_encode(char) for char in self.train_X]
-        correct_predictions = sum(self.idx_to_char[np.argmax(self.softmax(output))] == target
-                                  for output, target in zip(self.forward(inputs), self.train_y))
+        correct_predictions = sum(
+            self.idx_to_char[np.argmax(self.softmax(output))] == target
+            for output, target in zip(self.forward(inputs), self.train_y)
+        )
 
         accuracy = (correct_predictions / len(self.train_y)) * 100
-        print(f'Accuracy: {accuracy:.2f}%')
+        print(f"Accuracy: {accuracy:.2f}%")
 
 
 if __name__ == "__main__":
