1
1
import torch
2
2
import torch .nn as nn
3
3
from segmentation_models_pytorch .base .modules import Activation
4
- from typing import Optional , Sequence , Union , Callable
4
+ from typing import Optional , Sequence , Union , Callable , Literal
5
5
6
6
7
class ReadoutConcatBlock(nn.Module):
    """
    Concatenates the prefix (cls) tokens with the spatial features to make use
    of the global information aggregated in the prefix tokens, then projects
    the combined feature map back to the original embedding dimension with a
    Linear + GELU MLP.

    According to:
    https://github.com/isl-org/DPT/blob/cd3fe90bb4c48577535cc4d51b602acca688a2ee/dpt/vit.py#L79-L90
    """

    def __init__(self, embed_dim: int, has_prefix_tokens: bool):
        super().__init__()
        # Concatenation doubles the channel count, but only when the encoder
        # actually supplies prefix tokens at forward time.
        in_features = embed_dim * 2 if has_prefix_tokens else embed_dim
        out_features = embed_dim
        self.project = nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.GELU(),
        )

    def forward(
        self, features: torch.Tensor, prefix_tokens: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        batch_size, embed_dim, height, width = features.shape

        # Rearrange to (batch_size, height * width, embed_dim)
        features = features.view(batch_size, embed_dim, -1)
        features = features.transpose(1, 2).contiguous()

        if prefix_tokens is not None:
            # (batch_size, num_tokens, embed_dim) -> (batch_size, embed_dim):
            # only the first (cls) token is used, as in the DPT reference.
            prefix_tokens = prefix_tokens[:, 0]
            # unsqueeze(1) is required so expand_as broadcasts
            # (batch, embed) -> (batch, h*w, embed) for any batch size;
            # without it, expand aligns `batch` against `h*w` and fails
            # whenever batch_size > 1.
            prefix_tokens = prefix_tokens.unsqueeze(1).expand_as(features)
            features = torch.cat([features, prefix_tokens], dim=2)

        # Project to embedding dimension
        features = self.project(features)

        # Rearrange back to (batch_size, embed_dim, height, width)
        features = features.transpose(1, 2).contiguous()
        features = features.view(batch_size, embed_dim, height, width)

        return features
45
48
49
class ReadoutAddBlock(nn.Module):
    """
    Fuses the global information aggregated in the prefix (cls) tokens into the
    spatial features by broadcasting the token mean over the feature map and
    adding it.

    According to:
    https://github.com/isl-org/DPT/blob/cd3fe90bb4c48577535cc4d51b602acca688a2ee/dpt/vit.py#L71-L76
    """

    def forward(
        self, features: torch.Tensor, prefix_tokens: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # No prefix tokens -> nothing to add; pass features through unchanged.
        if prefix_tokens is None:
            return features
        batch_size, embed_dim, _, _ = features.shape
        # Average over the token axis, then broadcast-add across every
        # spatial position of the (batch, embed, h, w) map.
        pooled = prefix_tokens.mean(dim=1).view(batch_size, embed_dim, 1, 1)
        return features + pooled
66
+
67
+
68
class ReadoutIgnoreBlock(nn.Module):
    """
    Discards any prefix (cls) tokens and returns the spatial features untouched.
    """

    def forward(self, features: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Extra positional/keyword arguments (e.g. prefix_tokens) are accepted
        # only for call-signature compatibility with the other readout blocks.
        return features
75
+
76
+
46
77
class ReassembleBlock (nn .Module ):
47
78
"""
48
79
Processes the features such that they have progressively increasing embedding size and progressively decreasing
@@ -182,20 +213,30 @@ def __init__(
182
213
self ,
183
214
encoder_out_channels : Sequence [int ] = (756 , 756 , 756 , 756 ),
184
215
encoder_output_strides : Sequence [int ] = (16 , 16 , 16 , 16 ),
216
+ encoder_has_prefix_tokens : bool = True ,
217
+ readout : Literal ["cat" , "add" , "ignore" ] = "cat" ,
185
218
intermediate_channels : Sequence [int ] = (256 , 512 , 1024 , 1024 ),
186
219
fusion_channels : int = 256 ,
187
- has_cls_token : bool = False ,
188
220
):
189
221
super ().__init__ ()
190
222
191
223
num_blocks = len (encoder_output_strides )
192
224
193
- # If encoder has cls token, then concatenate it with the features along the embedding dimension and project it
194
- # back to the feature_dim dimension. Else, ignore the non-existent cls token
195
- blocks = [
196
- ProjectionBlock (in_channels , has_cls_token )
197
- for in_channels in encoder_out_channels
198
- ]
225
+ # If encoder has prefix tokens (e.g. cls_token), then we can concat/add/ignore them
226
+ # according to the readout mode
227
+ if readout == "cat" :
228
+ blocks = [
229
+ ReadoutConcatBlock (in_channels , encoder_has_prefix_tokens )
230
+ for in_channels in encoder_out_channels
231
+ ]
232
+ elif readout == "add" :
233
+ blocks = [ReadoutAddBlock () for _ in encoder_out_channels ]
234
+ elif readout == "ignore" :
235
+ blocks = [ReadoutIgnoreBlock () for _ in encoder_out_channels ]
236
+ else :
237
+ raise ValueError (
238
+ f"Invalid readout mode: { readout } , should be one of: 'cat', 'add', 'ignore'"
239
+ )
199
240
self .projection_blocks = nn .ModuleList (blocks )
200
241
201
242
# Upsample factors to resize features to [1/4, 1/8, 1/16, 1/32, ...] scales
0 commit comments