"""

##### Explanation #####
-# This script implements a Long Short-Term Memory (LSTM) network to learn
+# This script implements a Long Short-Term Memory (LSTM) network to learn
# and predict sequences of characters.
# It uses numpy for numerical operations and tqdm for progress visualization.

-# The data is a paragraph about LSTM, converted to lowercase and split into
+# The data is a paragraph about LSTM, converted to lowercase and split into
# characters. Each character is one-hot encoded for training.

-# The LSTM class initializes weights and biases for the forget, input, candidate,
+# The LSTM class initializes weights and biases for the forget, input, candidate,
# and output gates. It also initializes weights and biases for the final output layer.

-# The forward method performs forward propagation through the LSTM network,
-# computing hidden and cell states. It uses sigmoid and tanh activation
+# The forward method performs forward propagation through the LSTM network,
+# computing hidden and cell states. It uses sigmoid and tanh activation
# functions for the gates and cell states.

-# The backward method performs backpropagation through time, computing gradients
-# for the weights and biases. It updates the weights and biases using
+# The backward method performs backpropagation through time, computing gradients
+# for the weights and biases. It updates the weights and biases using
# the computed gradients and the learning rate.

-# The train method trains the LSTM network on the input data for a specified
-# number of epochs. It uses one-hot encoded inputs and computes errors
+# The train method trains the LSTM network on the input data for a specified
+# number of epochs. It uses one-hot encoded inputs and computes errors
# using the softmax function.

-# The test method evaluates the trained LSTM network on the input data,
+# The test method evaluates the trained LSTM network on the input data,
# computing accuracy based on predictions.

-# The script initializes the LSTM network with specified hyperparameters
-# and trains it on the input data. Finally, it tests the trained network
+# The script initializes the LSTM network with specified hyperparameters
+# and trains it on the input data. Finally, it tests the trained network
# and prints the accuracy of the predictions.

##### Imports #####
from tqdm import tqdm
import numpy as np

+
class LSTM:
-    def __init__(self, data: str, hidden_dim: int = 25,
-                 epochs: int = 1000, lr: float = 0.05) -> None:
+    def __init__(
+        self, data: str, hidden_dim: int = 25, epochs: int = 1000, lr: float = 0.05
+    ) -> None:
        """
        Initialize the LSTM network with the given data and hyperparameters.
-
+
        :param data: The input data as a string.
        :param hidden_dim: The number of hidden units in the LSTM layer.
        :param epochs: The number of training epochs.
@@ -63,7 +65,7 @@ def __init__(self, data: str, hidden_dim: int = 25,
        self.chars = set(self.data)
        self.data_size, self.char_size = len(self.data), len(self.chars)

-        print(f'Data size: {self.data_size}, Char Size: {self.char_size}')
+        print(f"Data size: {self.data_size}, Char Size: {self.char_size}")

        self.char_to_idx = {c: i for i, c in enumerate(self.chars)}
        self.idx_to_char = {i: c for i, c in enumerate(self.chars)}
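
For readers following along: the two lookup tables above are what make the one-hot encoding below work. A minimal standalone sketch of the same idea (the toy string and names here are illustrative, not from the script):

import numpy as np

text = "hello"  # toy stand-in for the training paragraph
chars = sorted(set(text))  # ['e', 'h', 'l', 'o']
char_to_idx = {c: i for i, c in enumerate(chars)}
idx_to_char = {i: c for i, c in enumerate(chars)}

def one_hot(char: str) -> np.ndarray:
    # Column vector of shape (char_size, 1), matching the class's convention
    v = np.zeros((len(chars), 1))
    v[char_to_idx[char]] = 1
    return v

print(one_hot("l").ravel())  # [0. 0. 1. 0.]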
@@ -76,7 +78,7 @@ def __init__(self, data: str, hidden_dim: int = 25,
    def one_hot_encode(self, char: str) -> np.ndarray:
        """
        One-hot encode a character.
-
+
        :param char: The character to encode.
        :return: A one-hot encoded vector.
        """
@@ -88,20 +90,16 @@ def initialize_weights(self) -> None:
        """
        Initialize the weights and biases for the LSTM network.
        """
-        self.wf = self.init_weights(self.char_size + self.hidden_dim,
-                                    self.hidden_dim)
+        self.wf = self.init_weights(self.char_size + self.hidden_dim, self.hidden_dim)
        self.bf = np.zeros((self.hidden_dim, 1))

-        self.wi = self.init_weights(self.char_size + self.hidden_dim,
-                                    self.hidden_dim)
+        self.wi = self.init_weights(self.char_size + self.hidden_dim, self.hidden_dim)
        self.bi = np.zeros((self.hidden_dim, 1))

-        self.wc = self.init_weights(self.char_size + self.hidden_dim,
-                                    self.hidden_dim)
+        self.wc = self.init_weights(self.char_size + self.hidden_dim, self.hidden_dim)
        self.bc = np.zeros((self.hidden_dim, 1))

-        self.wo = self.init_weights(self.char_size + self.hidden_dim,
-                                    self.hidden_dim)
+        self.wo = self.init_weights(self.char_size + self.hidden_dim, self.hidden_dim)
        self.bo = np.zeros((self.hidden_dim, 1))

        self.wy = self.init_weights(self.hidden_dim, self.char_size)
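
Each gate reads the concatenation of the previous hidden state and the current input, which is why every gate matrix is created with input width char_size + hidden_dim (while wy maps the hidden state back to vocabulary size). A quick shape check, as a sketch with illustrative sizes and the same Glorot-style scaling that init_weights applies in the next hunk:

import numpy as np

hidden_dim, char_size = 25, 40  # illustrative sizes
fan_in, fan_out = char_size + hidden_dim, hidden_dim
wf = np.random.uniform(-1, 1, (fan_out, fan_in)) * np.sqrt(6 / (fan_in + fan_out))

h_prev = np.zeros((hidden_dim, 1))
x_t = np.zeros((char_size, 1))
z_t = np.concatenate((h_prev, x_t))  # shape (65, 1)
print((wf @ z_t).shape)  # (25, 1): one pre-activation per hidden unit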
@@ -110,19 +108,20 @@ def initialize_weights(self) -> None:
    def init_weights(self, input_dim: int, output_dim: int) -> np.ndarray:
        """
        Initialize weights with random values.
-
+
        :param input_dim: The input dimension.
        :param output_dim: The output dimension.
        :return: A matrix of initialized weights.
        """
-        return np.random.uniform(-1, 1, (output_dim, input_dim)) * \
-            np.sqrt(6 / (input_dim + output_dim))
+        return np.random.uniform(-1, 1, (output_dim, input_dim)) * np.sqrt(
+            6 / (input_dim + output_dim)
+        )

    ##### Activation Functions #####
    def sigmoid(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
        """
        Sigmoid activation function.
-
+
        :param x: The input array.
        :param derivative: Whether to compute the derivative.
        :return: The sigmoid activation or its derivative.
@@ -134,19 +133,19 @@ def sigmoid(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
    def tanh(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
        """
        Tanh activation function.
-
+
        :param x: The input array.
        :param derivative: Whether to compute the derivative.
        :return: The tanh activation or its derivative.
        """
        if derivative:
-            return 1 - x ** 2
+            return 1 - x**2
        return np.tanh(x)

    def softmax(self, x: np.ndarray) -> np.ndarray:
        """
        Softmax activation function.
-
+
        :param x: The input array.
        :return: The softmax activation.
        """
@@ -173,7 +172,7 @@ def reset(self) -> None:
    def forward(self, inputs: list) -> list:
        """
        Perform forward propagation through the LSTM network.
-
+
        :param inputs: The input data as a list of one-hot encoded vectors.
        :return: The outputs of the network.
        """
@@ -182,21 +181,29 @@ def forward(self, inputs: list) -> list:
        outputs = []
        for t in range(len(inputs)):
            self.concat_inputs[t] = np.concatenate(
-                (self.hidden_states[t - 1], inputs[t]))
-
-            self.forget_gates[t] = self.sigmoid(np.dot(self.wf,
-                self.concat_inputs[t]) + self.bf)
-            self.input_gates[t] = self.sigmoid(np.dot(self.wi,
-                self.concat_inputs[t]) + self.bi)
-            self.candidate_gates[t] = self.tanh(np.dot(self.wc,
-                self.concat_inputs[t]) + self.bc)
-            self.output_gates[t] = self.sigmoid(np.dot(self.wo,
-                self.concat_inputs[t]) + self.bo)
-
-            self.cell_states[t] = self.forget_gates[t] * self.cell_states[t - 1] + \
-                self.input_gates[t] * self.candidate_gates[t]
-            self.hidden_states[t] = self.output_gates[t] * \
-                self.tanh(self.cell_states[t])
+                (self.hidden_states[t - 1], inputs[t])
+            )
+
+            self.forget_gates[t] = self.sigmoid(
+                np.dot(self.wf, self.concat_inputs[t]) + self.bf
+            )
+            self.input_gates[t] = self.sigmoid(
+                np.dot(self.wi, self.concat_inputs[t]) + self.bi
+            )
+            self.candidate_gates[t] = self.tanh(
+                np.dot(self.wc, self.concat_inputs[t]) + self.bc
+            )
+            self.output_gates[t] = self.sigmoid(
+                np.dot(self.wo, self.concat_inputs[t]) + self.bo
+            )
+
+            self.cell_states[t] = (
+                self.forget_gates[t] * self.cell_states[t - 1]
+                + self.input_gates[t] * self.candidate_gates[t]
+            )
+            self.hidden_states[t] = self.output_gates[t] * self.tanh(
+                self.cell_states[t]
+            )

            outputs.append(np.dot(self.wy, self.hidden_states[t]) + self.by)
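
The block above is the standard LSTM cell update: with z_t = [h_{t-1}; x_t], the gates are f_t = sigmoid(Wf z_t + bf), i_t = sigmoid(Wi z_t + bi), c~_t = tanh(Wc z_t + bc), o_t = sigmoid(Wo z_t + bo), followed by c_t = f_t * c_{t-1} + i_t * c~_t and h_t = o_t * tanh(c_t), with element-wise products. One step outside the class, as a sketch with random weights and biases omitted for brevity:

import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

rng = np.random.default_rng(0)
hidden_dim, char_size = 4, 3
wf, wi, wc, wo = (rng.uniform(-1, 1, (hidden_dim, char_size + hidden_dim)) for _ in range(4))

h_prev = np.zeros((hidden_dim, 1))
c_prev = np.zeros((hidden_dim, 1))
x_t = np.eye(char_size)[:, [0]]  # one-hot input

z_t = np.concatenate((h_prev, x_t))
f_t, i_t, o_t = sigmoid(wf @ z_t), sigmoid(wi @ z_t), sigmoid(wo @ z_t)
c_t = f_t * c_prev + i_t * np.tanh(wc @ z_t)
h_t = o_t * np.tanh(c_t)
print(h_t.shape)  # (4, 1)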
@@ -205,7 +212,7 @@ def forward(self, inputs: list) -> list:
    def backward(self, errors: list, inputs: list) -> None:
        """
        Perform backpropagation through time to compute gradients and update weights.
-
+
        :param errors: The errors at each time step.
        :param inputs: The input data as a list of one-hot encoded vectors.
        """
@@ -215,8 +222,10 @@ def backward(self, errors: list, inputs: list) -> None:
        d_wo, d_bo = 0, 0
        d_wy, d_by = 0, 0

-        dh_next, dc_next = np.zeros_like(self.hidden_states[0]), \
-            np.zeros_like(self.cell_states[0])
+        dh_next, dc_next = (
+            np.zeros_like(self.hidden_states[0]),
+            np.zeros_like(self.cell_states[0]),
+        )
        for t in reversed(range(len(inputs))):
            error = errors[t]

@@ -228,45 +237,69 @@ def backward(self, errors: list, inputs: list) -> None:
            d_hs = np.dot(self.wy.T, error) + dh_next

            # Output Gate Weights and Biases Errors
-            d_o = self.tanh(self.cell_states[t]) * d_hs * \
-                self.sigmoid(self.output_gates[t], derivative=True)
+            d_o = (
+                self.tanh(self.cell_states[t])
+                * d_hs
+                * self.sigmoid(self.output_gates[t], derivative=True)
+            )
            d_wo += np.dot(d_o, inputs[t].T)
            d_bo += d_o

            # Cell State Error
-            d_cs = self.tanh(self.tanh(self.cell_states[t]),
-                derivative=True) * self.output_gates[t] * d_hs + dc_next
+            d_cs = (
+                self.tanh(self.tanh(self.cell_states[t]), derivative=True)
+                * self.output_gates[t]
+                * d_hs
+                + dc_next
+            )

            # Forget Gate Weights and Biases Errors
-            d_f = d_cs * self.cell_states[t - 1] * \
-                self.sigmoid(self.forget_gates[t], derivative=True)
+            d_f = (
+                d_cs
+                * self.cell_states[t - 1]
+                * self.sigmoid(self.forget_gates[t], derivative=True)
+            )
            d_wf += np.dot(d_f, inputs[t].T)
            d_bf += d_f

            # Input Gate Weights and Biases Errors
-            d_i = d_cs * self.candidate_gates[t] * \
-                self.sigmoid(self.input_gates[t], derivative=True)
+            d_i = (
+                d_cs
+                * self.candidate_gates[t]
+                * self.sigmoid(self.input_gates[t], derivative=True)
+            )
            d_wi += np.dot(d_i, inputs[t].T)
            d_bi += d_i

            # Candidate Gate Weights and Biases Errors
-            d_c = d_cs * self.input_gates[t] * self.tanh(self.candidate_gates[t],
-                derivative=True)
+            d_c = (
+                d_cs
+                * self.input_gates[t]
+                * self.tanh(self.candidate_gates[t], derivative=True)
+            )
            d_wc += np.dot(d_c, inputs[t].T)
            d_bc += d_c

            # Update the next hidden and cell state errors
-            dh_next = np.dot(self.wf.T, d_f) + np.dot(self.wi.T, d_i) + \
-                np.dot(self.wo.T, d_o) + np.dot(self.wc.T, d_c)
+            dh_next = (
+                np.dot(self.wf.T, d_f)
+                + np.dot(self.wi.T, d_i)
+                + np.dot(self.wo.T, d_o)
+                + np.dot(self.wc.T, d_c)
+            )
            dc_next = d_cs * self.forget_gates[t]

        # Apply gradients to weights and biases
-        for param, grad in zip([self.wf, self.wi, self.wc, self.wo, self.wy],
-                               [d_wf, d_wi, d_wc, d_wo, d_wy]):
+        for param, grad in zip(
+            [self.wf, self.wi, self.wc, self.wo, self.wy],
+            [d_wf, d_wi, d_wc, d_wo, d_wy],
+        ):
            param -= self.lr * grad

-        for param, grad in zip([self.bf, self.bi, self.bc, self.bo, self.by],
-                               [d_bf, d_bi, d_bc, d_bo, d_by]):
+        for param, grad in zip(
+            [self.bf, self.bi, self.bc, self.bo, self.by],
+            [d_bf, d_bi, d_bc, d_bo, d_by],
+        ):
            param -= self.lr * grad

    def train(self) -> None:
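
The pattern in this hunk: each gate's delta multiplies the upstream error by that gate's local derivative, the per-step gradients are accumulated over all time steps, and only then applied. A subtlety in the two closing loops is that param -= self.lr * grad updates each numpy array in place, so assigning through the temporary lists really does update the network's weights. A sketch of that behavior:

import numpy as np

lr = 0.05
weights = [np.ones((2, 2)), np.zeros((2, 2))]
grads = [np.full((2, 2), 0.1), np.full((2, 2), 0.2)]
for param, grad in zip(weights, grads):
    param -= lr * grad  # in-place on the array the list element refers to
print(weights[0][0, 0])  # 0.995 -- the original array was mutated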
@@ -289,7 +322,7 @@ def train(self) -> None:
    def predict(self, inputs: list) -> str:
        """
        Predict the next character in the sequence.
-
+
        :param inputs: The input data as a list of one-hot encoded vectors.
        :return: The predicted character.
        """
@@ -301,11 +334,13 @@ def test(self) -> None:
        Test the LSTM network on the input data and compute accuracy.
        """
        inputs = [self.one_hot_encode(char) for char in self.train_X]
-        correct_predictions = sum(self.idx_to_char[np.argmax(self.softmax(output))] == target
-                                  for output, target in zip(self.forward(inputs), self.train_y))
+        correct_predictions = sum(
+            self.idx_to_char[np.argmax(self.softmax(output))] == target
+            for output, target in zip(self.forward(inputs), self.train_y)
+        )

        accuracy = (correct_predictions / len(self.train_y)) * 100
-        print(f'Accuracy: {accuracy:.2f}%')
+        print(f"Accuracy: {accuracy:.2f}%")


if __name__ == "__main__":
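
The body of the __main__ block falls outside this diff. Based on the constructor and methods above, usage presumably looks like the following sketch (the sample text is illustrative):

data = """Long Short-Term Memory networks are a kind of recurrent neural
network capable of learning long-term dependencies.""".lower()

lstm = LSTM(data=data, hidden_dim=25, epochs=1000, lr=0.05)
lstm.train()
lstm.test()  # prints the accuracy, e.g. "Accuracy: 95.00%" (value illustrative)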