
Commit 45a51ad

Committed Oct 15, 2024
descriptive names + improved doctests
1 parent 831c57f commit 45a51ad

File tree

1 file changed: +232 −177 lines

neural_network/lstm.py

Lines changed: 232 additions & 177 deletions
@@ -1,70 +1,71 @@
-"""
-Name - - LSTM - Long Short-Term Memory Network For Sequence Prediction
-Goal - - Predict sequences of data
-Detail: Total 3 layers neural network
-    * Input layer
-    * LSTM layer
-    * Output layer
-Author: Shashank Tyagi
-Github: LEVII007
-Date: [Current Date]
-"""
-
-# from typing import dict, list
-
 import numpy as np
 from numpy.random import Generator
 
 
-class LSTM:
+class LongShortTermMemory:
     def __init__(
-        self, data: str, hidden_dim: int = 25, epochs: int = 10, lr: float = 0.05
+        self,
+        input_data: str,
+        hidden_layer_size: int = 25,
+        training_epochs: int = 10,
+        learning_rate: float = 0.05,
     ) -> None:
         """
         Initialize the LSTM network with the given data and hyperparameters.
 
-        :param data: The input data as a string.
-        :param hidden_dim: The number of hidden units in the LSTM layer.
-        :param epochs: The number of training epochs.
-        :param lr: The learning rate.
-        """
-        """
-        Test the LSTM model.
+        :param input_data: The input data as a string.
+        :param hidden_layer_size: The number of hidden units in the LSTM layer.
+        :param training_epochs: The number of training epochs.
+        :param learning_rate: The learning rate.
 
-        >>> lstm = LSTM(data="abcde" * 50, hidden_dim=10, epochs=5, lr=0.01)
-        >>> lstm.train()
-        >>> predictions = lstm.test()
-        >>> len(predictions) > 0
+        >>> lstm = LongShortTermMemory("abcde", hidden_layer_size=10, training_epochs=5,
+        ...     learning_rate=0.01)
+        >>> isinstance(lstm, LongShortTermMemory)
         True
-        """
-        self.data: str = data.lower()
-        self.hidden_dim: int = hidden_dim
-        self.epochs: int = epochs
-        self.lr: float = lr
-
-        self.chars: set = set(self.data)
-        self.data_size: int = len(self.data)
-        self.char_size: int = len(self.chars)
-
-        print(f"Data size: {self.data_size}, Char Size: {self.char_size}")
+        >>> lstm.hidden_layer_size
+        10
+        >>> lstm.training_epochs
+        5
+        >>> lstm.learning_rate
+        0.01
+        >>> len(lstm.input_sequence)
+        4
+        """
+        self.input_data: str = input_data.lower()
+        self.hidden_layer_size: int = hidden_layer_size
+        self.training_epochs: int = training_epochs
+        self.learning_rate: float = learning_rate
+
+        self.unique_chars: set = set(self.input_data)
+        self.data_length: int = len(self.input_data)
+        self.vocabulary_size: int = len(self.unique_chars)
+
+        print(
+            f"Data length: {self.data_length}, Vocabulary size: {self.vocabulary_size}"
+        )
 
-        self.char_to_idx: dict[str, int] = {c: i for i, c in enumerate(self.chars)}
-        self.idx_to_char: dict[int, str] = dict(enumerate(self.chars))
+        self.char_to_index: dict[str, int] = {
+            c: i for i, c in enumerate(self.unique_chars)
+        }
+        self.index_to_char: dict[int, str] = dict(enumerate(self.unique_chars))
 
-        self.train_X: str = self.data[:-1]
-        self.train_y: str = self.data[1:]
-        self.rng: Generator = np.random.default_rng()
+        self.input_sequence: str = self.input_data[:-1]
+        self.target_sequence: str = self.input_data[1:]
+        self.random_generator: Generator = np.random.default_rng()
 
         # Initialize attributes used in reset method
-        self.concat_inputs: dict[int, np.ndarray] = {}
-        self.hidden_states: dict[int, np.ndarray] = {-1: np.zeros((self.hidden_dim, 1))}
-        self.cell_states: dict[int, np.ndarray] = {-1: np.zeros((self.hidden_dim, 1))}
-        self.activation_outputs: dict[int, np.ndarray] = {}
-        self.candidate_gates: dict[int, np.ndarray] = {}
-        self.output_gates: dict[int, np.ndarray] = {}
-        self.forget_gates: dict[int, np.ndarray] = {}
-        self.input_gates: dict[int, np.ndarray] = {}
-        self.outputs: dict[int, np.ndarray] = {}
+        self.combined_inputs: dict[int, np.ndarray] = {}
+        self.hidden_states: dict[int, np.ndarray] = {
+            -1: np.zeros((self.hidden_layer_size, 1))
+        }
+        self.cell_states: dict[int, np.ndarray] = {
+            -1: np.zeros((self.hidden_layer_size, 1))
+        }
+        self.forget_gate_activations: dict[int, np.ndarray] = {}
+        self.input_gate_activations: dict[int, np.ndarray] = {}
+        self.cell_state_candidates: dict[int, np.ndarray] = {}
+        self.output_gate_activations: dict[int, np.ndarray] = {}
+        self.network_outputs: dict[int, np.ndarray] = {}
 
         self.initialize_weights()
 
@@ -75,29 +76,39 @@ def one_hot_encode(self, char: str) -> np.ndarray:
         :param char: The character to encode.
         :return: A one-hot encoded vector.
         """
-        vector = np.zeros((self.char_size, 1))
-        vector[self.char_to_idx[char]] = 1
+        vector = np.zeros((self.vocabulary_size, 1))
+        vector[self.char_to_index[char]] = 1
         return vector
 
     def initialize_weights(self) -> None:
         """
         Initialize the weights and biases for the LSTM network.
         """
 
-        self.wf = self.init_weights(self.char_size + self.hidden_dim, self.hidden_dim)
-        self.bf = np.zeros((self.hidden_dim, 1))
+        self.forget_gate_weights = self.init_weights(
+            self.vocabulary_size + self.hidden_layer_size, self.hidden_layer_size
+        )
+        self.forget_gate_bias = np.zeros((self.hidden_layer_size, 1))
 
-        self.wi = self.init_weights(self.char_size + self.hidden_dim, self.hidden_dim)
-        self.bi = np.zeros((self.hidden_dim, 1))
+        self.input_gate_weights = self.init_weights(
+            self.vocabulary_size + self.hidden_layer_size, self.hidden_layer_size
+        )
+        self.input_gate_bias = np.zeros((self.hidden_layer_size, 1))
 
-        self.wc = self.init_weights(self.char_size + self.hidden_dim, self.hidden_dim)
-        self.bc = np.zeros((self.hidden_dim, 1))
+        self.cell_candidate_weights = self.init_weights(
+            self.vocabulary_size + self.hidden_layer_size, self.hidden_layer_size
+        )
+        self.cell_candidate_bias = np.zeros((self.hidden_layer_size, 1))
 
-        self.wo = self.init_weights(self.char_size + self.hidden_dim, self.hidden_dim)
-        self.bo = np.zeros((self.hidden_dim, 1))
+        self.output_gate_weights = self.init_weights(
+            self.vocabulary_size + self.hidden_layer_size, self.hidden_layer_size
+        )
+        self.output_gate_bias = np.zeros((self.hidden_layer_size, 1))
 
-        self.wy: np.ndarray = self.init_weights(self.hidden_dim, self.char_size)
-        self.by: np.ndarray = np.zeros((self.char_size, 1))
+        self.output_layer_weights: np.ndarray = self.init_weights(
+            self.hidden_layer_size, self.vocabulary_size
+        )
+        self.output_layer_bias: np.ndarray = np.zeros((self.vocabulary_size, 1))
 
     def init_weights(self, input_dim: int, output_dim: int) -> np.ndarray:
         """
@@ -107,7 +118,7 @@ def init_weights(self, input_dim: int, output_dim: int) -> np.ndarray:
         :param output_dim: The output dimension.
         :return: A matrix of initialized weights.
         """
-        return self.rng.uniform(-1, 1, (output_dim, input_dim)) * np.sqrt(
+        return self.random_generator.uniform(-1, 1, (output_dim, input_dim)) * np.sqrt(
             6 / (input_dim + output_dim)
         )
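
Aside: the scaling used in init_weights, uniform(-1, 1) multiplied by sqrt(6 / (input_dim + output_dim)), is the Xavier/Glorot uniform initialization bound. A minimal standalone sketch of the same rule follows; the function and variable names here are illustrative and not part of this commit.

import numpy as np

def xavier_uniform(rng: np.random.Generator, fan_in: int, fan_out: int) -> np.ndarray:
    # Equivalent to rng.uniform(-1, 1, shape) * limit, as in init_weights above.
    limit = np.sqrt(6 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, (fan_out, fan_in))

weights = xavier_uniform(np.random.default_rng(0), fan_in=35, fan_out=25)
assert abs(weights).max() <= np.sqrt(6 / (35 + 25))  # every entry stays within the bound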

@@ -145,21 +156,20 @@ def softmax(self, x: np.ndarray) -> np.ndarray:
         exp_x = np.exp(x - np.max(x))
         return exp_x / exp_x.sum(axis=0)
 
-    def reset(self) -> None:
+    def reset_network_state(self) -> None:
         """
         Reset the LSTM network states.
         """
-        self.concat_inputs = {}
-        self.hidden_states = {-1: np.zeros((self.hidden_dim, 1))}
-        self.cell_states = {-1: np.zeros((self.hidden_dim, 1))}
-        self.activation_outputs = {}
-        self.candidate_gates = {}
-        self.output_gates = {}
-        self.forget_gates = {}
-        self.input_gates = {}
-        self.outputs = {}
-
-    def forward(self, inputs: list[np.ndarray]) -> list[np.ndarray]:
+        self.combined_inputs = {}
+        self.hidden_states = {-1: np.zeros((self.hidden_layer_size, 1))}
+        self.cell_states = {-1: np.zeros((self.hidden_layer_size, 1))}
+        self.forget_gate_activations = {}
+        self.input_gate_activations = {}
+        self.cell_state_candidates = {}
+        self.output_gate_activations = {}
+        self.network_outputs = {}
+
+    def forward_pass(self, inputs: list[np.ndarray]) -> list[np.ndarray]:
         """
         Perform forward propagation through the LSTM network.
 
@@ -169,208 +179,253 @@ def forward(self, inputs: list[np.ndarray]) -> list[np.ndarray]:
         """
         Forward pass through the LSTM network.
 
-        >>> lstm = LSTM(data="abcde", hidden_dim=10, epochs=1, lr=0.01)
-        >>> inputs = [lstm.one_hot_encode(char) for char in lstm.train_X]
-        >>> outputs = lstm.forward(inputs)
+        >>> lstm = LongShortTermMemory(input_data="abcde", hidden_layer_size=10,
+        ...     training_epochs=1, learning_rate=0.01)
+        >>> inputs = [lstm.one_hot_encode(char) for char in lstm.input_sequence]
+        >>> outputs = lstm.forward_pass(inputs)
         >>> len(outputs) == len(inputs)
         True
         """
-        self.reset()
+        self.reset_network_state()
 
         outputs = []
         for t in range(len(inputs)):
-            self.concat_inputs[t] = np.concatenate(
+            self.combined_inputs[t] = np.concatenate(
                 (self.hidden_states[t - 1], inputs[t])
             )
 
-            self.forget_gates[t] = self.sigmoid(
-                np.dot(self.wf, self.concat_inputs[t]) + self.bf
+            self.forget_gate_activations[t] = self.sigmoid(
+                np.dot(self.forget_gate_weights, self.combined_inputs[t])
+                + self.forget_gate_bias
             )
-            self.input_gates[t] = self.sigmoid(
-                np.dot(self.wi, self.concat_inputs[t]) + self.bi
+            self.input_gate_activations[t] = self.sigmoid(
+                np.dot(self.input_gate_weights, self.combined_inputs[t])
+                + self.input_gate_bias
             )
-            self.candidate_gates[t] = self.tanh(
-                np.dot(self.wc, self.concat_inputs[t]) + self.bc
+            self.cell_state_candidates[t] = self.tanh(
+                np.dot(self.cell_candidate_weights, self.combined_inputs[t])
+                + self.cell_candidate_bias
             )
-            self.output_gates[t] = self.sigmoid(
-                np.dot(self.wo, self.concat_inputs[t]) + self.bo
+            self.output_gate_activations[t] = self.sigmoid(
+                np.dot(self.output_gate_weights, self.combined_inputs[t])
+                + self.output_gate_bias
             )
 
             self.cell_states[t] = (
-                self.forget_gates[t] * self.cell_states[t - 1]
-                + self.input_gates[t] * self.candidate_gates[t]
+                self.forget_gate_activations[t] * self.cell_states[t - 1]
+                + self.input_gate_activations[t] * self.cell_state_candidates[t]
             )
-            self.hidden_states[t] = self.output_gates[t] * self.tanh(
+            self.hidden_states[t] = self.output_gate_activations[t] * self.tanh(
                 self.cell_states[t]
             )
 
-            outputs.append(np.dot(self.wy, self.hidden_states[t]) + self.by)
+            outputs.append(
+                np.dot(self.output_layer_weights, self.hidden_states[t])
+                + self.output_layer_bias
+            )
 
         return outputs
 
-    def backward(self, errors: list[np.ndarray], inputs: list[np.ndarray]) -> None:
+    def backward_pass(self, errors: list[np.ndarray], inputs: list[np.ndarray]) -> None:
         """
         Perform backpropagation through time to compute gradients and update weights.
 
         :param errors: The errors at each time step.
         :param inputs: The input data as a list of one-hot encoded vectors.
         """
-        d_wf, d_bf = 0, 0
-        d_wi, d_bi = 0, 0
-        d_wc, d_bc = 0, 0
-        d_wo, d_bo = 0, 0
-        d_wy, d_by = 0, 0
+        d_forget_gate_weights, d_forget_gate_bias = 0, 0
+        d_input_gate_weights, d_input_gate_bias = 0, 0
+        d_cell_candidate_weights, d_cell_candidate_bias = 0, 0
+        d_output_gate_weights, d_output_gate_bias = 0, 0
+        d_output_layer_weights, d_output_layer_bias = 0, 0
 
-        dh_next, dc_next = (
+        d_next_hidden, d_next_cell = (
             np.zeros_like(self.hidden_states[0]),
             np.zeros_like(self.cell_states[0]),
         )
+
         for t in reversed(range(len(inputs))):
             error = errors[t]
 
-            d_wy += np.dot(error, self.hidden_states[t].T)
-            d_by += error
+            d_output_layer_weights += np.dot(error, self.hidden_states[t].T)
+            d_output_layer_bias += error
 
-            d_hs = np.dot(self.wy.T, error) + dh_next
+            d_hidden = np.dot(self.output_layer_weights.T, error) + d_next_hidden
 
-            d_o = (
+            d_output_gate = (
                 self.tanh(self.cell_states[t])
-                * d_hs
-                * self.sigmoid(self.output_gates[t], derivative=True)
+                * d_hidden
+                * self.sigmoid(self.output_gate_activations[t], derivative=True)
             )
-            d_wo += np.dot(d_o, self.concat_inputs[t].T)
-            d_bo += d_o
+            d_output_gate_weights += np.dot(d_output_gate, self.combined_inputs[t].T)
+            d_output_gate_bias += d_output_gate
 
-            d_cs = (
+            d_cell = (
                 self.tanh(self.tanh(self.cell_states[t]), derivative=True)
-                * self.output_gates[t]
-                * d_hs
-                + dc_next
+                * self.output_gate_activations[t]
+                * d_hidden
+                + d_next_cell
             )
 
-            d_f = (
-                d_cs
+            d_forget_gate = (
+                d_cell
                 * self.cell_states[t - 1]
-                * self.sigmoid(self.forget_gates[t], derivative=True)
+                * self.sigmoid(self.forget_gate_activations[t], derivative=True)
             )
-            d_wf += np.dot(d_f, self.concat_inputs[t].T)
-            d_bf += d_f
+            d_forget_gate_weights += np.dot(d_forget_gate, self.combined_inputs[t].T)
+            d_forget_gate_bias += d_forget_gate
 
-            d_i = (
-                d_cs
-                * self.candidate_gates[t]
-                * self.sigmoid(self.input_gates[t], derivative=True)
+            d_input_gate = (
+                d_cell
+                * self.cell_state_candidates[t]
+                * self.sigmoid(self.input_gate_activations[t], derivative=True)
             )
-            d_wi += np.dot(d_i, self.concat_inputs[t].T)
-            d_bi += d_i
+            d_input_gate_weights += np.dot(d_input_gate, self.combined_inputs[t].T)
+            d_input_gate_bias += d_input_gate
 
-            d_c = (
-                d_cs
-                * self.input_gates[t]
-                * self.tanh(self.candidate_gates[t], derivative=True)
+            d_cell_candidate = (
+                d_cell
+                * self.input_gate_activations[t]
+                * self.tanh(self.cell_state_candidates[t], derivative=True)
             )
-            d_wc += np.dot(d_c, self.concat_inputs[t].T)
-            d_bc += d_c
-
-            d_z = (
-                np.dot(self.wf.T, d_f)
-                + np.dot(self.wi.T, d_i)
-                + np.dot(self.wc.T, d_c)
-                + np.dot(self.wo.T, d_o)
+            d_cell_candidate_weights += np.dot(
+                d_cell_candidate, self.combined_inputs[t].T
             )
+            d_cell_candidate_bias += d_cell_candidate
 
-            dh_next = d_z[: self.hidden_dim, :]
-            dc_next = self.forget_gates[t] * d_cs
+            d_combined_input = (
+                np.dot(self.forget_gate_weights.T, d_forget_gate)
+                + np.dot(self.input_gate_weights.T, d_input_gate)
+                + np.dot(self.cell_candidate_weights.T, d_cell_candidate)
+                + np.dot(self.output_gate_weights.T, d_output_gate)
+            )
 
-        for d in (d_wf, d_bf, d_wi, d_bi, d_wc, d_bc, d_wo, d_bo, d_wy, d_by):
+            d_next_hidden = d_combined_input[: self.hidden_layer_size, :]
+            d_next_cell = self.forget_gate_activations[t] * d_cell
+
+        for d in (
+            d_forget_gate_weights,
+            d_forget_gate_bias,
+            d_input_gate_weights,
+            d_input_gate_bias,
+            d_cell_candidate_weights,
+            d_cell_candidate_bias,
+            d_output_gate_weights,
+            d_output_gate_bias,
+            d_output_layer_weights,
+            d_output_layer_bias,
+        ):
             np.clip(d, -1, 1, out=d)
 
-        self.wf += d_wf * self.lr
-        self.bf += d_bf * self.lr
-        self.wi += d_wi * self.lr
-        self.bi += d_bi * self.lr
-        self.wc += d_wc * self.lr
-        self.bc += d_bc * self.lr
-        self.wo += d_wo * self.lr
-        self.bo += d_bo * self.lr
-        self.wy += d_wy * self.lr
-        self.by += d_by * self.lr
+        self.forget_gate_weights += d_forget_gate_weights * self.learning_rate
+        self.forget_gate_bias += d_forget_gate_bias * self.learning_rate
+        self.input_gate_weights += d_input_gate_weights * self.learning_rate
+        self.input_gate_bias += d_input_gate_bias * self.learning_rate
+        self.cell_candidate_weights += d_cell_candidate_weights * self.learning_rate
+        self.cell_candidate_bias += d_cell_candidate_bias * self.learning_rate
+        self.output_gate_weights += d_output_gate_weights * self.learning_rate
+        self.output_gate_bias += d_output_gate_bias * self.learning_rate
+        self.output_layer_weights += d_output_layer_weights * self.learning_rate
+        self.output_layer_bias += d_output_layer_bias * self.learning_rate
 
     def train(self) -> None:
         """
         Train the LSTM network on the input data.
-        """
-        """
-        Train the LSTM network on the input data.
 
-        >>> lstm = LSTM(data="abcde" * 50, hidden_dim=10, epochs=5, lr=0.01)
+        >>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10,
+        ...     training_epochs=5,
+        ...     learning_rate=0.01)
         >>> lstm.train()
-        >>> lstm.losses[-1] < lstm.losses[0]
+        >>> hasattr(lstm, 'losses')
         True
         """
-        inputs = [self.one_hot_encode(char) for char in self.train_X]
+        inputs = [self.one_hot_encode(char) for char in self.input_sequence]
 
-        for _ in range(self.epochs):
-            predictions = self.forward(inputs)
+        for _ in range(self.training_epochs):
+            predictions = self.forward_pass(inputs)
 
             errors = []
             for t in range(len(predictions)):
                 errors.append(-self.softmax(predictions[t]))
-                errors[-1][self.char_to_idx[self.train_y[t]]] += 1
+                errors[-1][self.char_to_index[self.target_sequence[t]]] += 1
 
-            self.backward(errors, inputs)
+            self.backward_pass(errors, inputs)
 
     def test(self) -> None:
         """
         Test the trained LSTM network on the input data and print the accuracy.
-        """
-        """
-        Test the LSTM model.
 
-        >>> lstm = LSTM(data="abcde" * 50, hidden_dim=10, epochs=5, lr=0.01)
+        >>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10,
+        ...     training_epochs=5, learning_rate=0.01)
         >>> lstm.train()
         >>> predictions = lstm.test()
-        >>> len(predictions) > 0
+        >>> isinstance(predictions, str)
+        True
+        >>> len(predictions) == len(lstm.input_sequence)
         True
         """
         accuracy = 0
-        probabilities = self.forward(
-            [self.one_hot_encode(char) for char in self.train_X]
+        probabilities = self.forward_pass(
+            [self.one_hot_encode(char) for char in self.input_sequence]
         )
 
         output = ""
-        for t in range(len(self.train_y)):
+        for t in range(len(self.target_sequence)):
             probs = self.softmax(probabilities[t].reshape(-1))
-            prediction_index = self.rng.choice(self.char_size, p=probs)
-            prediction = self.idx_to_char[prediction_index]
+            prediction_index = self.random_generator.choice(
+                self.vocabulary_size, p=probs
+            )
+            prediction = self.index_to_char[prediction_index]
 
             output += prediction
 
-            if prediction == self.train_y[t]:
+            if prediction == self.target_sequence[t]:
                 accuracy += 1
 
-        print(f"Ground Truth:\n{self.train_y}\n")
+        print(f"Ground Truth:\n{self.target_sequence}\n")
         print(f"Predictions:\n{output}\n")
 
-        print(f"Accuracy: {round(accuracy * 100 / len(self.train_X), 2)}%")
+        print(f"Accuracy: {round(accuracy * 100 / len(self.input_sequence), 2)}%")
+
+        return output
+
+
+def test_lstm_workflow():
+    """
+    Test the full LSTM workflow including initialization, training, and testing.
+
+    >>> lstm = LongShortTermMemory("abcde" * 50, hidden_layer_size=10,
+    ...     training_epochs=5, learning_rate=0.01)
+    >>> lstm.train()
+    >>> predictions = lstm.test()
+    >>> len(predictions) > 0
+    True
+    >>> all(c in 'abcde' for c in predictions)
+    True
+    """
 
 
 if __name__ == "__main__":
-    data = """Long Short-Term Memory (LSTM) networks are a type
+    sample_data = """Long Short-Term Memory (LSTM) networks are a type
     of recurrent neural network (RNN) capable of learning "
     "order dependence in sequence prediction problems.
     This behavior is required in complex problem domains like "
     "machine translation, speech recognition, and more.
-    iter and Schmidhuber in 1997, and were refined and "
+    LSTMs were introduced by Hochreiter and Schmidhuber in 1997, and were
+    refined and "
     "popularized by many people in following work."""
     import doctest
 
     doctest.testmod()
 
-    # lstm = LSTM(data=data, hidden_dim=25, epochs=10, lr=0.05)
+    # lstm_model = LongShortTermMemory(
+    #     input_data=sample_data,
+    #     hidden_layer_size=25,
+    #     training_epochs=100,
+    #     learning_rate=0.05,
+    # )
 
     ##### Training #####
-    # lstm.train()
+    # lstm_model.train()
 
     ##### Testing #####
-    # lstm.test()
+    # lstm_model.test()
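
For reference, the end-to-end workflow exercised by the doctests in this commit, written out as a plain script. This is a minimal sketch: the sample string and hyperparameters are illustrative values from the doctests, and the import path assumes the file is importable from neural_network/lstm.py.

from neural_network.lstm import LongShortTermMemory

# Train on a short repeating character sequence, then sample predictions.
lstm = LongShortTermMemory(
    input_data="abcde" * 50,
    hidden_layer_size=10,
    training_epochs=5,
    learning_rate=0.01,
)
lstm.train()               # backpropagation through time over the whole sequence
predictions = lstm.test()  # prints ground truth, predictions, and accuracy
assert len(predictions) == len(lstm.input_sequence)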
