@@ -1,7 +1,7 @@
 """
 - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
 Name - - sliding_window_attention.py
-Goal - - Implement a neural network architecture using sliding window attention for sequence
+Goal - - Implement a neural network architecture using sliding window attention for sequence
 modeling tasks.
 Detail: Total 5 layers neural network
     * Input layer
@@ -12,9 +12,9 @@
 
 Date: 2024.10.20
 References:
-    1. Choromanska, A., et al. (2020). "On the Importance of Initialization and Momentum in
+    1. Choromanska, A., et al. (2020). "On the Importance of Initialization and Momentum in
        Deep Learning." *Proceedings of the 37th International Conference on Machine Learning*.
-    2. Dai, Z., et al. (2020). "Transformers are RNNs: Fast Autoregressive Transformers
+    2. Dai, Z., et al. (2020). "Transformers are RNNs: Fast Autoregressive Transformers
       with Linear Attention." *arXiv preprint arXiv:2006.16236*.
    3. [Attention Mechanisms in Neural Networks](https://en.wikipedia.org/wiki/Attention_(machine_learning))
 - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
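The diff touches only the module docstring, the `forward` docstring, and the usage block, so the body of `SlidingWindowAttention` itself is not visible in this excerpt. For orientation, here is a minimal NumPy sketch of what a sliding-window attention forward pass with this interface could look like; the class name, constructor arguments, and `forward` signature come from the diff, but the internal logic below is an illustrative assumption, not the file's actual implementation:

```python
import numpy as np


class SlidingWindowAttention:
    """Sketch: each position attends only to a local window of neighbours."""

    def __init__(self, embed_dim: int, window_size: int) -> None:
        self.embed_dim = embed_dim
        self.window_size = window_size

    def forward(self, input_tensor: np.ndarray) -> np.ndarray:
        batch_size, seq_length, embed_dim = input_tensor.shape
        output = np.zeros_like(input_tensor)
        half = self.window_size // 2
        for i in range(seq_length):
            # Restrict attention to a window centred on position i.
            start = max(0, i - half)
            end = min(seq_length, i + half + 1)
            window = input_tensor[:, start:end, :]           # (batch, window, embed)
            query = input_tensor[:, i : i + 1, :]            # (batch, 1, embed)
            # Scaled dot-product scores within the window only.
            scores = query @ window.transpose(0, 2, 1) / np.sqrt(embed_dim)
            scores -= scores.max(axis=-1, keepdims=True)     # numerical stability
            weights = np.exp(scores)
            weights /= weights.sum(axis=-1, keepdims=True)   # softmax over the window
            output[:, i, :] = (weights @ window)[:, 0, :]
        return output
```

With `window_size=3`, as in the usage block further down, each position would attend to itself and one neighbour on each side under this sketch.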
@@ -52,7 +52,7 @@ def forward(self, input_tensor: np.ndarray) -> np.ndarray:
         Forward pass for the sliding window attention.
 
         Args:
-            input_tensor (np.ndarray): Input tensor of shape (batch_size, seq_length,
+            input_tensor (np.ndarray): Input tensor of shape (batch_size, seq_length,
                 embed_dim).
 
         Returns:
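The `Args:` shape pins down the tensor layout but not the cost. If `forward` restricts each query to a window of `window_size` positions, as in the sketch above, the number of attention scores grows linearly with sequence length rather than quadratically. A rough count using the shapes from the usage hunk below:

```python
# Back-of-the-envelope comparison, using the shapes from the usage example below.
seq_length, window_size = 10, 3
full_attention_scores = seq_length * seq_length     # 10 * 10 = 100 scores per batch element
sliding_window_scores = seq_length * window_size    # at most 10 * 3 = 30 scores per batch element
print(full_attention_scores, sliding_window_scores)  # 100 30
```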
@@ -93,7 +93,9 @@ def forward(self, input_tensor: np.ndarray) -> np.ndarray:
 
 # usage
 rng = np.random.default_rng()
-x = rng.standard_normal((2, 10, 4))  # Batch size 2, sequence length 10, embedding dimension 4
+x = rng.standard_normal(
+    (2, 10, 4)
+)  # Batch size 2, sequence length 10, embedding dimension 4
 attention = SlidingWindowAttention(embed_dim=4, window_size=3)
 output = attention.forward(x)
 print(output)
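If `forward` behaves like the sketch above, the printed `output` keeps the input layout, i.e. a `(2, 10, 4)` array for this example. A quick sanity check that could follow the usage block, hedged on that assumption about the return shape:

```python
# Assumes forward() preserves the (batch_size, seq_length, embed_dim) layout.
assert output.shape == (2, 10, 4)
```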