Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit f0919fe

Browse files
committedOct 16, 2024·
written doc tests for backward pass and forward pass, fixed variable names in sigmoid function from x to input array
1 parent f3e974f commit f0919fe

File tree

1 file changed

+56
-21
lines changed

1 file changed

+56
-21
lines changed
 

‎neural_network/lstm.py

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,20 @@ def init_weights(self, input_dim: int, output_dim: int) -> np.ndarray:
175175
:param input_dim: The input dimension.
176176
:param output_dim: The output dimension.
177177
:return: A matrix of initialized weights.
178+
179+
Example:
180+
>>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10)
181+
>>> weights = lstm.init_weights(5, 10)
182+
>>> isinstance(weights, np.ndarray)
183+
True
184+
>>> weights.shape
185+
(10, 5)
178186
"""
179187
return self.random_generator.uniform(-1, 1, (output_dim, input_dim)) * np.sqrt(
180188
6 / (input_dim + output_dim)
181189
)
182190

183-
def sigmoid(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
191+
def sigmoid(self, input_array: np.ndarray, derivative: bool = False) -> np.ndarray:
184192
"""
185193
Sigmoid activation function.
186194
@@ -199,10 +207,10 @@ def sigmoid(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
199207
array([[0.197, 0.105, 0.045]])
200208
"""
201209
if derivative:
202-
return x * (1 - x)
203-
return 1 / (1 + np.exp(-x))
210+
return input_array * (1 - input_array)
211+
return 1 / (1 + np.exp(-input_array))
204212

205-
def tanh(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
213+
def tanh(self, input_array: np.ndarray, derivative: bool = False) -> np.ndarray:
206214
"""
207215
Tanh activation function.
208216
@@ -221,10 +229,10 @@ def tanh(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
221229
array([[0.42 , 0.071, 0.01 ]])
222230
"""
223231
if derivative:
224-
return 1 - x**2
225-
return np.tanh(x)
232+
return 1 - input_array**2
233+
return np.tanh(input_array)
226234

227-
def softmax(self, x: np.ndarray) -> np.ndarray:
235+
def softmax(self, input_array: np.ndarray) -> np.ndarray:
228236
"""
229237
Softmax activation function.
230238
@@ -238,7 +246,7 @@ def softmax(self, x: np.ndarray) -> np.ndarray:
238246
>>> np.round(output, 3)
239247
array([0.09 , 0.245, 0.665])
240248
"""
241-
exp_x = np.exp(x - np.max(x))
249+
exp_x = np.exp(input_array - np.max(input_array))
242250
return exp_x / exp_x.sum(axis=0)
243251

244252
def reset_network_state(self) -> None:
@@ -270,17 +278,14 @@ def reset_network_state(self) -> None:
270278

271279
def forward_pass(self, inputs: list[np.ndarray]) -> list[np.ndarray]:
272280
"""
273-
Perform forward propagation through the LSTM network.
281+
Perform a forward pass through the LSTM network for the given inputs.
274282
275-
:param inputs: The input data as a list of one-hot encoded vectors.
276-
:return: The outputs of the network.
277-
"""
278-
"""
279-
Forward pass through the LSTM network.
283+
:param inputs: A list of input arrays (sequences).
284+
:return: A list of network outputs.
280285
281-
>>> lstm = LongShortTermMemory(input_data="abcde", hidden_layer_size=10,
282-
training_epochs=1, learning_rate=0.01)
283-
>>> inputs = [lstm.one_hot_encode(char) for char in lstm.input_sequence]
286+
Example:
287+
>>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10)
288+
>>> inputs = [np.random.rand(5, 1) for _ in range(5)]
284289
>>> outputs = lstm.forward_pass(inputs)
285290
>>> len(outputs) == len(inputs)
286291
True
@@ -326,6 +331,21 @@ def forward_pass(self, inputs: list[np.ndarray]) -> list[np.ndarray]:
326331
return outputs
327332

328333
def backward_pass(self, errors: list[np.ndarray], inputs: list[np.ndarray]) -> None:
334+
"""
335+
Perform the backward pass for the LSTM model, adjusting weights and biases.
336+
337+
:param errors: A list of errors computed from the output layer.
338+
:param inputs: A list of input one-hot encoded vectors.
339+
340+
Example:
341+
>>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10)
342+
>>> inputs = [lstm.one_hot_encode(char) for char in lstm.input_sequence]
343+
>>> predictions = lstm.forward_pass(inputs)
344+
>>> errors = [-lstm.softmax(predictions[t]) for t in range(len(predictions))]
345+
>>> for t in range(len(predictions)):
346+
... errors[t][lstm.char_to_index[lstm.target_sequence[t]]] += 1
347+
>>> lstm.backward_pass(errors, inputs) # Should run without any errors
348+
"""
329349
d_forget_gate_weights, d_forget_gate_bias = 0, 0
330350
d_input_gate_weights, d_input_gate_bias = 0, 0
331351
d_cell_candidate_weights, d_cell_candidate_bias = 0, 0
@@ -422,6 +442,13 @@ def backward_pass(self, errors: list[np.ndarray], inputs: list[np.ndarray]) -> N
422442
self.output_layer_bias += d_output_layer_bias * self.learning_rate
423443

424444
def train(self) -> None:
445+
"""
446+
Train the LSTM model.
447+
448+
Example:
449+
>>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10)
450+
>>> lstm.train()
451+
"""
425452
inputs = [self.one_hot_encode(char) for char in self.input_sequence]
426453

427454
for _ in range(self.training_epochs):
@@ -434,12 +461,20 @@ def train(self) -> None:
434461

435462
self.backward_pass(errors, inputs)
436463

437-
def test(self):
464+
def test(self) -> None:
438465
"""
439466
Test the LSTM model.
440467
441468
Returns:
442469
str: The output predictions.
470+
471+
Example:
472+
>>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10)
473+
>>> output = lstm.test()
474+
>>> isinstance(output, str)
475+
True
476+
>>> len(output) == len(lstm.input_sequence)
477+
True
443478
"""
444479
accuracy = 0
445480
probabilities = self.forward_pass(
@@ -461,9 +496,9 @@ def test(self):
461496
if prediction == self.target_sequence[t]:
462497
accuracy += 1
463498

464-
print(f"Ground Truth:\n{self.target_sequence}\n")
465-
print(f"Predictions:\n{output}\n")
466-
print(f"Accuracy: {round(accuracy * 100 / len(self.input_sequence), 2)}%")
499+
# print(f"Ground Truth:\n{self.target_sequence}\n")
500+
# print(f"Predictions:\n{output}\n")
501+
# print(f"Accuracy: {round(accuracy * 100 / len(self.input_sequence), 2)}%")
467502

468503
return output
469504

0 commit comments

Comments
 (0)
Please sign in to comment.