Skip to content

Commit 8d3ed4f

Browse files
committed
Add decoder_readout according to initial impl
1 parent 21a164a commit 8d3ed4f

File tree

3 files changed

+91
-34
lines changed

3 files changed

+91
-34
lines changed

segmentation_models_pytorch/decoders/dpt/decoder.py

Lines changed: 59 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,40 @@
11
import torch
22
import torch.nn as nn
33
from segmentation_models_pytorch.base.modules import Activation
4-
from typing import Optional, Sequence, Union, Callable
4+
from typing import Optional, Sequence, Union, Callable, Literal
55

66

7-
class ProjectionBlock(nn.Module):
7+
class ReadoutConcatBlock(nn.Module):
88
"""
9-
Concatenates the cls tokens with the features to make use of the global information aggregated in the cls token.
10-
Projects the combined feature map to the original embedding dimension using a MLP
9+
Concatenates the first prefix (cls) token with the features to make use of the global information aggregated in that token.
10+
Projects the combined feature map back to the original embedding dimension using a linear layer followed by GELU.
11+
12+
According to:
13+
https://github.com/isl-org/DPT/blob/cd3fe90bb4c48577535cc4d51b602acca688a2ee/dpt/vit.py#L79-L90
1114
"""
1215

13-
def __init__(self, embed_dim: int, has_cls_token: bool):
16+
def __init__(self, embed_dim: int, has_prefix_tokens: bool):
1417
super().__init__()
15-
in_features = embed_dim * 2 if has_cls_token else embed_dim
18+
in_features = embed_dim * 2 if has_prefix_tokens else embed_dim
1619
out_features = embed_dim
1720
self.project = nn.Sequential(
1821
nn.Linear(in_features, out_features),
1922
nn.GELU(),
2023
)
2124

2225
def forward(
23-
self, features: torch.Tensor, cls_token: Optional[torch.Tensor] = None
26+
self, features: torch.Tensor, prefix_tokens: Optional[torch.Tensor] = None
2427
) -> torch.Tensor:
2528
batch_size, embed_dim, height, width = features.shape
2629

2730
# Rearrange to (batch_size, height * width, embed_dim)
2831
features = features.view(batch_size, embed_dim, -1)
2932
features = features.transpose(1, 2).contiguous()
3033

31-
# Add CLS token
32-
if cls_token is not None:
33-
cls_token = cls_token.expand_as(features)
34-
features = torch.cat([features, cls_token], dim=2)
34+
if prefix_tokens is not None:
35+
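# (batch_size, num_tokens, embed_dim) -> (batch_size, embed_dim); NOTE(review): a 2-D tensor only expands to the 3-D shape of `features` when batch_size == 1 - likely needs `.unsqueeze(1)` before `expand_as`, confirm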
36+
prefix_tokens = prefix_tokens[:, 0].expand_as(features)
37+
features = torch.cat([features, prefix_tokens], dim=2)
3538

3639
# Project to embedding dimension
3740
features = self.project(features)
@@ -43,6 +46,34 @@ def forward(
4346
return features
4447

4548

49+
class ReadoutAddBlock(nn.Module):
    """
    Readout that merges the prefix (e.g. cls) tokens into the feature map by addition.

    The prefix tokens are averaged over dimension 1 and the resulting
    per-channel vector is broadcast-added to every spatial location of the
    features. When no prefix tokens are passed, the features are returned
    untouched.

    According to:
    https://github.com/isl-org/DPT/blob/cd3fe90bb4c48577535cc4d51b602acca688a2ee/dpt/vit.py#L71-L76
    """

    def forward(
        self, features: torch.Tensor, prefix_tokens: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Nothing to add: hand the input back as is
        if prefix_tokens is None:
            return features

        # Features are unpacked as (batch_size, embed_dim, height, width)
        batch_size, embed_dim, height, width = features.shape

        # Collapse the tokens into one vector per sample, then shape it as
        # (batch_size, embed_dim, 1, 1) so it broadcasts over height and width
        pooled = prefix_tokens.mean(dim=1).view(batch_size, embed_dim, 1, 1)
        return features + pooled
66+
67+
68+
class ReadoutIgnoreBlock(nn.Module):
    """
    No-op readout: the prefix tokens are discarded and the features pass through unchanged.
    """

    def forward(self, features: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Any extra positional/keyword arguments (e.g. prefix tokens) are
        # accepted only so the call signature matches the other readout blocks
        return features
75+
76+
4677
class ReassembleBlock(nn.Module):
4778
"""
4879
Processes the features such that they have progressively increasing embedding size and progressively decreasing
@@ -182,20 +213,30 @@ def __init__(
182213
self,
183214
encoder_out_channels: Sequence[int] = (756, 756, 756, 756),
184215
encoder_output_strides: Sequence[int] = (16, 16, 16, 16),
216+
encoder_has_prefix_tokens: bool = True,
217+
readout: Literal["cat", "add", "ignore"] = "cat",
185218
intermediate_channels: Sequence[int] = (256, 512, 1024, 1024),
186219
fusion_channels: int = 256,
187-
has_cls_token: bool = False,
188220
):
189221
super().__init__()
190222

191223
num_blocks = len(encoder_output_strides)
192224

193-
# If encoder has cls token, then concatenate it with the features along the embedding dimension and project it
194-
# back to the feature_dim dimension. Else, ignore the non-existent cls token
195-
blocks = [
196-
ProjectionBlock(in_channels, has_cls_token)
197-
for in_channels in encoder_out_channels
198-
]
225+
# If encoder has prefix tokens (e.g. cls_token), then we can concat/add/ignore them
226+
# according to the readout mode
227+
if readout == "cat":
228+
blocks = [
229+
ReadoutConcatBlock(in_channels, encoder_has_prefix_tokens)
230+
for in_channels in encoder_out_channels
231+
]
232+
elif readout == "add":
233+
blocks = [ReadoutAddBlock() for _ in encoder_out_channels]
234+
elif readout == "ignore":
235+
blocks = [ReadoutIgnoreBlock() for _ in encoder_out_channels]
236+
else:
237+
raise ValueError(
238+
f"Invalid readout mode: {readout}, should be one of: 'cat', 'add', 'ignore'"
239+
)
199240
self.projection_blocks = nn.ModuleList(blocks)
200241

201242
# Upsample factors to resize features to [1/4, 1/8, 1/16, 1/32, ...] scales

segmentation_models_pytorch/decoders/dpt/model.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import Any, Optional, Union, Callable, Sequence
1+
import warnings
2+
from typing import Any, Optional, Union, Callable, Sequence, Literal
3+
24
import torch
35

46
from segmentation_models_pytorch.base import (
@@ -43,6 +45,8 @@ class DPT(SegmentationModel):
4345
across the number of blocks in encoder, e.g. if number of blocks is 4 and encoder has 20 blocks, then
4446
encoder_output_indices will be (4, 9, 14, 19). If specified the number of indices should be equal to
4547
encoder_depth. Default is **None**.
48+
decoder_readout: The strategy to utilize the prefix tokens (e.g. cls_token) from the encoder.
49+
Can be one of **"cat"**, **"add"**, or **"ignore"**. Default is **"cat"**.
4650
decoder_intermediate_channels: The number of channels for the intermediate decoder layers. Reduce if you
4751
want to reduce the number of parameters in the decoder. Default is (256, 512, 1024, 1024).
4852
decoder_fusion_channels: The latent dimension to which the encoder features will be projected to before fusion.
@@ -78,6 +82,7 @@ def __init__(
7882
encoder_depth: int = 4,
7983
encoder_weights: Optional[str] = "imagenet",
8084
encoder_output_indices: Optional[list[int]] = None,
85+
decoder_readout: Literal["ignore", "add", "cat"] = "cat",
8186
decoder_intermediate_channels: Sequence[int] = (256, 512, 1024, 1024),
8287
decoder_fusion_channels: int = 256,
8388
in_channels: int = 3,
@@ -94,6 +99,11 @@ def __init__(
9499
f"Only Timm encoders are supported for DPT. Encoder name must start with 'tu-', got {encoder_name}"
95100
)
96101

102+
if decoder_readout not in ["ignore", "add", "cat"]:
103+
raise ValueError(
104+
f"Invalid decoder readout mode. Must be one of: 'ignore', 'add', 'cat'. Got: {decoder_readout}"
105+
)
106+
97107
self.encoder = TimmViTEncoder(
98108
name=encoder_name,
99109
in_channels=in_channels,
@@ -103,12 +113,20 @@ def __init__(
103113
**kwargs,
104114
)
105115

116+
if not self.encoder.has_prefix_tokens and decoder_readout != "ignore":
117+
warnings.warn(
118+
f"Encoder does not have prefix tokens (e.g. cls_token), but `decoder_readout` is set to '{decoder_readout}'. "
119+
f"It's recommended to set `decoder_readout='ignore'` when using a encoder without prefix tokens.",
120+
UserWarning,
121+
)
122+
106123
self.decoder = DPTDecoder(
107124
encoder_out_channels=self.encoder.out_channels,
125+
encoder_output_strides=self.encoder.output_strides,
126+
encoder_has_prefix_tokens=self.encoder.has_prefix_tokens,
127+
readout=decoder_readout,
108128
intermediate_channels=decoder_intermediate_channels,
109129
fusion_channels=decoder_fusion_channels,
110-
encoder_output_strides=self.encoder.output_strides,
111-
has_cls_token=self.encoder.has_class_token,
112130
)
113131

114132
self.segmentation_head = DPTSegmentationHead(

segmentation_models_pytorch/encoders/timm_vit.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,14 @@ def __init__(
129129

130130
# Private attributes for model forward
131131
self._num_prefix_tokens = getattr(self.model, "num_prefix_tokens", 0)
132+
self._has_cls_token = getattr(self.model, "has_cls_token", False)
132133
self._output_indices = output_indices
133134

134135
# Public attributes
135136
self.output_strides = [feature_info[i]["reduction"] for i in output_indices]
136137
self.output_stride = self.output_strides[-1]
137138
self.out_channels = [feature_info[i]["num_chs"] for i in output_indices]
138-
self.has_class_token = getattr(self.model, "has_class_token", False)
139+
self.has_prefix_tokens = self._num_prefix_tokens > 0
139140

140141
@property
141142
def is_fixed_input_size(self) -> bool:
@@ -145,25 +146,22 @@ def is_fixed_input_size(self) -> bool:
145146
def input_size(self) -> int:
146147
return self.model.pretrained_cfg.get("input_size", None)
147148

148-
def _forward_with_cls_token(
149+
def _forward_with_prefix_tokens(
149150
self, x: torch.Tensor
150151
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
151152
intermediate_outputs = self.model.forward_intermediates(
152153
x,
153154
indices=self._output_indices,
154-
return_prefix_tokens=True,
155155
intermediates_only=True,
156+
return_prefix_tokens=True,
156157
)
157158

158159
features = [output[0] for output in intermediate_outputs]
159-
cls_tokens = [output[1] for output in intermediate_outputs]
160-
161-
if self.has_class_token and self._num_prefix_tokens > 1:
162-
cls_tokens = [x[:, 0, :] for x in cls_tokens]
160+
prefix_tokens = [output[1] for output in intermediate_outputs]
163161

164-
return features, cls_tokens
162+
return features, prefix_tokens
165163

166-
def _forward_without_cls_token(self, x: torch.Tensor) -> list[torch.Tensor]:
164+
def _forward_without_prefix_tokens(self, x: torch.Tensor) -> list[torch.Tensor]:
167165
features = self.model.forward_intermediates(
168166
x,
169167
indices=self._output_indices,
@@ -184,10 +182,10 @@ def forward(
184182
tuple[list[torch.Tensor], list[torch.Tensor]]: Tuple of feature maps and prefix tokens (e.g. cls token; `None` entries when the encoder has no prefix tokens) at different scales.
185183
"""
186184

187-
if self.has_class_token:
188-
features, cls_tokens = self._forward_with_cls_token(x)
185+
if self.has_prefix_tokens:
186+
features, prefix_tokens = self._forward_with_prefix_tokens(x)
189187
else:
190188
features = self._forward_without_cls_token(x)
191-
cls_tokens = [None] * len(features)
189+
prefix_tokens = [None] * len(features)
192190

193-
return features, cls_tokens
191+
return features, prefix_tokens

0 commit comments

Comments
 (0)