@@ -175,12 +175,20 @@ def init_weights(self, input_dim: int, output_dim: int) -> np.ndarray:
175
175
:param input_dim: The input dimension.
176
176
:param output_dim: The output dimension.
177
177
:return: A matrix of initialized weights.
178
+
179
+ Example:
180
+ >>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10)
181
+ >>> weights = lstm.init_weights(5, 10)
182
+ >>> isinstance(weights, np.ndarray)
183
+ True
184
+ >>> weights.shape
185
+ (10, 5)
178
186
"""
179
187
return self .random_generator .uniform (- 1 , 1 , (output_dim , input_dim )) * np .sqrt (
180
188
6 / (input_dim + output_dim )
181
189
)
182
190
183
- def sigmoid (self , x : np .ndarray , derivative : bool = False ) -> np .ndarray :
191
+ def sigmoid (self , input_array : np .ndarray , derivative : bool = False ) -> np .ndarray :
184
192
"""
185
193
Sigmoid activation function.
186
194
@@ -199,10 +207,10 @@ def sigmoid(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
199
207
array([[0.197, 0.105, 0.045]])
200
208
"""
201
209
if derivative :
202
- return x * (1 - x )
203
- return 1 / (1 + np .exp (- x ))
210
+ return input_array * (1 - input_array )
211
+ return 1 / (1 + np .exp (- input_array ))
204
212
205
- def tanh (self , x : np .ndarray , derivative : bool = False ) -> np .ndarray :
213
+ def tanh (self , input_array : np .ndarray , derivative : bool = False ) -> np .ndarray :
206
214
"""
207
215
Tanh activation function.
208
216
@@ -221,10 +229,10 @@ def tanh(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
221
229
array([[0.42 , 0.071, 0.01 ]])
222
230
"""
223
231
if derivative :
224
- return 1 - x ** 2
225
- return np .tanh (x )
232
+ return 1 - input_array ** 2
233
+ return np .tanh (input_array )
226
234
227
- def softmax (self , x : np .ndarray ) -> np .ndarray :
235
+ def softmax (self , input_array : np .ndarray ) -> np .ndarray :
228
236
"""
229
237
Softmax activation function.
230
238
@@ -238,7 +246,7 @@ def softmax(self, x: np.ndarray) -> np.ndarray:
238
246
>>> np.round(output, 3)
239
247
array([0.09 , 0.245, 0.665])
240
248
"""
241
- exp_x = np .exp (x - np .max (x ))
249
+ exp_x = np .exp (input_array - np .max (input_array ))
242
250
return exp_x / exp_x .sum (axis = 0 )
243
251
244
252
def reset_network_state (self ) -> None :
@@ -270,17 +278,14 @@ def reset_network_state(self) -> None:
270
278
271
279
def forward_pass (self , inputs : list [np .ndarray ]) -> list [np .ndarray ]:
272
280
"""
273
- Perform forward propagation through the LSTM network.
281
+ Perform a forward pass through the LSTM network for the given inputs .
274
282
275
- :param inputs: The input data as a list of one-hot encoded vectors.
276
- :return: The outputs of the network.
277
- """
278
- """
279
- Forward pass through the LSTM network.
283
+ :param inputs: A list of input arrays (sequences).
284
+ :return: A list of network outputs.
280
285
281
- >>> lstm = LongShortTermMemory(input_data="abcde", hidden_layer_size=10,
282
- training_epochs=1, learning_rate=0.01 )
283
- >>> inputs = [lstm.one_hot_encode(char ) for char in lstm.input_sequence ]
286
+ Example:
287
+ >>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10 )
288
+ >>> inputs = [np.random.rand(5, 1 ) for _ in range(5) ]
284
289
>>> outputs = lstm.forward_pass(inputs)
285
290
>>> len(outputs) == len(inputs)
286
291
True
@@ -326,6 +331,21 @@ def forward_pass(self, inputs: list[np.ndarray]) -> list[np.ndarray]:
326
331
return outputs
327
332
328
333
def backward_pass (self , errors : list [np .ndarray ], inputs : list [np .ndarray ]) -> None :
334
+ """
335
+ Perform the backward pass for the LSTM model, adjusting weights and biases.
336
+
337
+ :param errors: A list of errors computed from the output layer.
338
+ :param inputs: A list of input one-hot encoded vectors.
339
+
340
+ Example:
341
+ >>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10)
342
+ >>> inputs = [lstm.one_hot_encode(char) for char in lstm.input_sequence]
343
+ >>> predictions = lstm.forward_pass(inputs)
344
+ >>> errors = [-lstm.softmax(predictions[t]) for t in range(len(predictions))]
345
+ >>> for t in range(len(predictions)):
346
+ ... errors[t][lstm.char_to_index[lstm.target_sequence[t]]] += 1
347
+ >>> lstm.backward_pass(errors, inputs) # Should run without any errors
348
+ """
329
349
d_forget_gate_weights , d_forget_gate_bias = 0 , 0
330
350
d_input_gate_weights , d_input_gate_bias = 0 , 0
331
351
d_cell_candidate_weights , d_cell_candidate_bias = 0 , 0
@@ -422,6 +442,13 @@ def backward_pass(self, errors: list[np.ndarray], inputs: list[np.ndarray]) -> N
422
442
self .output_layer_bias += d_output_layer_bias * self .learning_rate
423
443
424
444
def train (self ) -> None :
445
+ """
446
+ Train the LSTM model.
447
+
448
+ Example:
449
+ >>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10)
450
+ >>> lstm.train()
451
+ """
425
452
inputs = [self .one_hot_encode (char ) for char in self .input_sequence ]
426
453
427
454
for _ in range (self .training_epochs ):
@@ -434,12 +461,20 @@ def train(self) -> None:
434
461
435
462
self .backward_pass (errors , inputs )
436
463
437
- def test (self ):
464
+ def test (self ) -> None :
438
465
"""
439
466
Test the LSTM model.
440
467
441
468
Returns:
442
469
str: The output predictions.
470
+
471
+ Example:
472
+ >>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10)
473
+ >>> output = lstm.test()
474
+ >>> isinstance(output, str)
475
+ True
476
+ >>> len(output) == len(lstm.input_sequence)
477
+ True
443
478
"""
444
479
accuracy = 0
445
480
probabilities = self .forward_pass (
@@ -461,9 +496,9 @@ def test(self):
461
496
if prediction == self .target_sequence [t ]:
462
497
accuracy += 1
463
498
464
- print (f"Ground Truth:\n { self .target_sequence } \n " )
465
- print (f"Predictions:\n { output } \n " )
466
- print (f"Accuracy: { round (accuracy * 100 / len (self .input_sequence ), 2 )} %" )
499
+ # print(f"Ground Truth:\n{self.target_sequence}\n")
500
+ # print(f"Predictions:\n{output}\n")
501
+ # print(f"Accuracy: {round(accuracy * 100 / len(self.input_sequence), 2)}%")
467
502
468
503
return output
469
504
0 commit comments