@@ -42,7 +42,7 @@ class LSTM:
42
42
def __init__ (self , data : str , hidden_dim : int = 25 , epochs : int = 1000 , lr : float = 0.05 ) -> None :
43
43
"""
44
44
Initialize the LSTM network with the given data and hyperparameters.
45
-
45
+
46
46
:param data: The input data as a string.
47
47
:param hidden_dim: The number of hidden units in the LSTM layer.
48
48
:param epochs: The number of training epochs.
@@ -69,7 +69,7 @@ def __init__(self, data: str, hidden_dim: int = 25, epochs: int = 1000, lr: floa
69
69
def one_hot_encode (self , char : str ) -> np .ndarray :
70
70
"""
71
71
One-hot encode a character.
72
-
72
+
73
73
:param char: The character to encode.
74
74
:return: A one-hot encoded vector.
75
75
"""
@@ -99,7 +99,7 @@ def initialize_weights(self) -> None:
99
99
def init_weights (self , input_dim : int , output_dim : int ) -> np .ndarray :
100
100
"""
101
101
Initialize weights with random values.
102
-
102
+
103
103
:param input_dim: The input dimension.
104
104
:param output_dim: The output dimension.
105
105
:return: A matrix of initialized weights.
@@ -110,7 +110,7 @@ def init_weights(self, input_dim: int, output_dim: int) -> np.ndarray:
110
110
def sigmoid (self , x : np .ndarray , derivative : bool = False ) -> np .ndarray :
111
111
"""
112
112
Sigmoid activation function.
113
-
113
+
114
114
:param x: The input array.
115
115
:param derivative: Whether to compute the derivative.
116
116
:return: The sigmoid activation or its derivative.
@@ -122,7 +122,7 @@ def sigmoid(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
122
122
def tanh (self , x : np .ndarray , derivative : bool = False ) -> np .ndarray :
123
123
"""
124
124
Tanh activation function.
125
-
125
+
126
126
:param x: The input array.
127
127
:param derivative: Whether to compute the derivative.
128
128
:return: The tanh activation or its derivative.
@@ -134,7 +134,7 @@ def tanh(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
134
134
def softmax (self , x : np .ndarray ) -> np .ndarray :
135
135
"""
136
136
Softmax activation function.
137
-
137
+
138
138
:param x: The input array.
139
139
:return: The softmax activation.
140
140
"""
@@ -161,7 +161,7 @@ def reset(self) -> None:
161
161
def forward (self , inputs : list ) -> list :
162
162
"""
163
163
Perform forward propagation through the LSTM network.
164
-
164
+
165
165
:param inputs: The input data as a list of one-hot encoded vectors.
166
166
:return: The outputs of the network.
167
167
"""
@@ -186,7 +186,7 @@ def forward(self, inputs: list) -> list:
186
186
def backward (self , errors : list , inputs : list ) -> None :
187
187
"""
188
188
Perform backpropagation through time to compute gradients and update weights.
189
-
189
+
190
190
:param errors: The errors at each time step.
191
191
:param inputs: The input data as a list of one-hot encoded vectors.
192
192
"""
@@ -224,7 +224,7 @@ def backward(self, errors: list, inputs: list) -> None:
224
224
d_i = d_cs * self .candidate_gates [t ] * self .sigmoid (self .input_gates [t ], derivative = True )
225
225
d_wi += np .dot (d_i , inputs [t ].T )
226
226
d_bi += d_i
227
-
227
+
228
228
# Candidate Gate Weights and Biases Errors
229
229
d_c = d_cs * self .input_gates [t ] * self .tanh (self .candidate_gates [t ], derivative = True )
230
230
d_wc += np .dot (d_c , inputs [t ].T )
@@ -270,7 +270,7 @@ def train(self) -> None:
270
270
errors [- 1 ][self .char_to_idx [self .train_y [t ]]] += 1
271
271
272
272
self .backward (errors , self .concat_inputs )
273
-
273
+
274
274
def test (self ) -> None :
275
275
"""
276
276
Test the trained LSTM network on the input data and print the accuracy.
@@ -289,7 +289,7 @@ def test(self) -> None:
289
289
290
290
print (f'Ground Truth:\n { self .train_y } \n ' )
291
291
print (f'Predictions:\n { output } \n ' )
292
-
292
+
293
293
print (f'Accuracy: { round (accuracy * 100 / len (self .train_X ), 2 )} %' )
294
294
295
295
##### Data #####
@@ -314,4 +314,4 @@ def test(self) -> None:
314
314
##### Testing #####
315
315
# lstm.test()
316
316
317
- # testing can be done by uncommenting the above lines of code.
317
+ # testing can be done by uncommenting the above lines of code.
0 commit comments