
Commit 0415717

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 9e9a313 commit 0415717

1 file changed (+7, -5)

neural_network/sliding_window_attention.py

+7 -5
@@ -1,7 +1,7 @@
 """
 - - - - - -- - - - - - - - - - - - - - - - - - - - - - -
 Name - - sliding_window_attention.py
-Goal - - Implement a neural network architecture using sliding window attention for sequence 
+Goal - - Implement a neural network architecture using sliding window attention for sequence
 modeling tasks.
 Detail: Total 5 layers neural network
     * Input layer
@@ -12,9 +12,9 @@
 
 Date: 2024.10.20
 References:
-    1. Choromanska, A., et al. (2020). "On the Importance of Initialization and Momentum in 
+    1. Choromanska, A., et al. (2020). "On the Importance of Initialization and Momentum in
     Deep Learning." *Proceedings of the 37th International Conference on Machine Learning*.
-    2. Dai, Z., et al. (2020). "Transformers are RNNs: Fast Autoregressive Transformers 
+    2. Dai, Z., et al. (2020). "Transformers are RNNs: Fast Autoregressive Transformers
     with Linear Attention." *arXiv preprint arXiv:2006.16236*.
     3. [Attention Mechanisms in Neural Networks](https://en.wikipedia.org/wiki/Attention_(machine_learning))
 - - - - - -- - - - - - - - - - - - - - - - - - - - - - -
@@ -52,7 +52,7 @@ def forward(self, input_tensor: np.ndarray) -> np.ndarray:
         Forward pass for the sliding window attention.
 
         Args:
-            input_tensor (np.ndarray): Input tensor of shape (batch_size, seq_length, 
+            input_tensor (np.ndarray): Input tensor of shape (batch_size, seq_length,
                 embed_dim).
 
         Returns:
@@ -93,7 +93,9 @@ def forward(self, input_tensor: np.ndarray) -> np.ndarray:
 
 # usage
 rng = np.random.default_rng()
-x = rng.standard_normal((2, 10, 4))  # Batch size 2, sequence length 10, embedding dimension 4
+x = rng.standard_normal(
+    (2, 10, 4)
+)  # Batch size 2, sequence length 10, embedding dimension 4
 attention = SlidingWindowAttention(embed_dim=4, window_size=3)
 output = attention.forward(x)
 print(output)
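
For context, here is a minimal NumPy sketch of the kind of module the usage lines above exercise: a SlidingWindowAttention class constructed with embed_dim and window_size, exposing forward() on a (batch_size, seq_length, embed_dim) array. Only that interface is taken from the diff; the internals below (a single shared projection, softmax over each local window) are assumptions for illustration, not the file's actual implementation.

import numpy as np


class SlidingWindowAttention:
    """Sketch: attention restricted to a local window around each position.

    Only the constructor signature and forward() interface come from the diff
    above; the projection and windowed softmax here are assumptions.
    """

    def __init__(self, embed_dim: int, window_size: int) -> None:
        self.embed_dim = embed_dim
        self.window_size = window_size
        rng = np.random.default_rng(0)
        # Single shared projection (assumed; the real file may use separate Q/K/V).
        self.weight = rng.standard_normal((embed_dim, embed_dim)) / np.sqrt(embed_dim)

    def forward(self, input_tensor: np.ndarray) -> np.ndarray:
        seq_length = input_tensor.shape[1]
        projected = input_tensor @ self.weight            # (batch, seq, embed)
        output = np.zeros_like(projected)
        half = self.window_size // 2
        for i in range(seq_length):
            # Each position attends only to neighbours inside its local window.
            start, end = max(0, i - half), min(seq_length, i + half + 1)
            window = projected[:, start:end, :]           # (batch, w, embed)
            scores = window @ projected[:, i, :, None]    # (batch, w, 1)
            scores = scores / np.sqrt(self.embed_dim)
            weights = np.exp(scores - scores.max(axis=1, keepdims=True))
            weights = weights / weights.sum(axis=1, keepdims=True)
            output[:, i, :] = (weights * window).sum(axis=1)
        return output


# Usage mirroring the diff: batch size 2, sequence length 10, embedding dimension 4.
rng = np.random.default_rng()
x = rng.standard_normal((2, 10, 4))
attention = SlidingWindowAttention(embed_dim=4, window_size=3)
print(attention.forward(x).shape)  # (2, 10, 4)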
