Skip to content

Commit 91c8173

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent: 369a6b2 · commit: 91c8173

File tree

1 file changed: +12 −12 lines (neural_network/lstm.py)

neural_network/lstm.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class LSTM:
4242
def __init__(self, data: str, hidden_dim: int = 25, epochs: int = 1000, lr: float = 0.05) -> None:
4343
"""
4444
Initialize the LSTM network with the given data and hyperparameters.
45-
45+
4646
:param data: The input data as a string.
4747
:param hidden_dim: The number of hidden units in the LSTM layer.
4848
:param epochs: The number of training epochs.
@@ -69,7 +69,7 @@ def __init__(self, data: str, hidden_dim: int = 25, epochs: int = 1000, lr: floa
6969
def one_hot_encode(self, char: str) -> np.ndarray:
7070
"""
7171
One-hot encode a character.
72-
72+
7373
:param char: The character to encode.
7474
:return: A one-hot encoded vector.
7575
"""
@@ -99,7 +99,7 @@ def initialize_weights(self) -> None:
9999
def init_weights(self, input_dim: int, output_dim: int) -> np.ndarray:
100100
"""
101101
Initialize weights with random values.
102-
102+
103103
:param input_dim: The input dimension.
104104
:param output_dim: The output dimension.
105105
:return: A matrix of initialized weights.
@@ -110,7 +110,7 @@ def init_weights(self, input_dim: int, output_dim: int) -> np.ndarray:
110110
def sigmoid(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
111111
"""
112112
Sigmoid activation function.
113-
113+
114114
:param x: The input array.
115115
:param derivative: Whether to compute the derivative.
116116
:return: The sigmoid activation or its derivative.
@@ -122,7 +122,7 @@ def sigmoid(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
122122
def tanh(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
123123
"""
124124
Tanh activation function.
125-
125+
126126
:param x: The input array.
127127
:param derivative: Whether to compute the derivative.
128128
:return: The tanh activation or its derivative.
@@ -134,7 +134,7 @@ def tanh(self, x: np.ndarray, derivative: bool = False) -> np.ndarray:
134134
def softmax(self, x: np.ndarray) -> np.ndarray:
135135
"""
136136
Softmax activation function.
137-
137+
138138
:param x: The input array.
139139
:return: The softmax activation.
140140
"""
@@ -161,7 +161,7 @@ def reset(self) -> None:
161161
def forward(self, inputs: list) -> list:
162162
"""
163163
Perform forward propagation through the LSTM network.
164-
164+
165165
:param inputs: The input data as a list of one-hot encoded vectors.
166166
:return: The outputs of the network.
167167
"""
@@ -186,7 +186,7 @@ def forward(self, inputs: list) -> list:
186186
def backward(self, errors: list, inputs: list) -> None:
187187
"""
188188
Perform backpropagation through time to compute gradients and update weights.
189-
189+
190190
:param errors: The errors at each time step.
191191
:param inputs: The input data as a list of one-hot encoded vectors.
192192
"""
@@ -224,7 +224,7 @@ def backward(self, errors: list, inputs: list) -> None:
224224
d_i = d_cs * self.candidate_gates[t] * self.sigmoid(self.input_gates[t], derivative=True)
225225
d_wi += np.dot(d_i, inputs[t].T)
226226
d_bi += d_i
227-
227+
228228
# Candidate Gate Weights and Biases Errors
229229
d_c = d_cs * self.input_gates[t] * self.tanh(self.candidate_gates[t], derivative=True)
230230
d_wc += np.dot(d_c, inputs[t].T)
@@ -270,7 +270,7 @@ def train(self) -> None:
270270
errors[-1][self.char_to_idx[self.train_y[t]]] += 1
271271

272272
self.backward(errors, self.concat_inputs)
273-
273+
274274
def test(self) -> None:
275275
"""
276276
Test the trained LSTM network on the input data and print the accuracy.
@@ -289,7 +289,7 @@ def test(self) -> None:
289289

290290
print(f'Ground Truth:\n{self.train_y}\n')
291291
print(f'Predictions:\n{output}\n')
292-
292+
293293
print(f'Accuracy: {round(accuracy * 100 / len(self.train_X), 2)}%')
294294

295295
##### Data #####
@@ -314,4 +314,4 @@ def test(self) -> None:
314314
##### Testing #####
315315
# lstm.test()
316316

317-
# testing can be done by uncommenting the above lines of code.
317+
# testing can be done by uncommenting the above lines of code.

Comments (0)