
Commit 71e2acb

Code refactor
1 parent c47bdfb commit 71e2acb

8 files changed, +233 -127 lines changed


segmentation_models_pytorch/__init__.py

+1 -1
@@ -35,7 +35,7 @@
     PAN,
     UPerNet,
     Segformer,
-    DPT
+    DPT,
 ]
 MODEL_ARCHITECTURES_MAPPING = {a.__name__.lower(): a for a in _MODEL_ARCHITECTURES}
 

segmentation_models_pytorch/decoders/dpt/decoder.py

+4 -5
@@ -196,16 +196,16 @@ def __init__(
         encoder_output_stride: int,
         feature_dim: int = 256,
         encoder_depth: int = 4,
-        prefix_token_supported: bool = False,
+        cls_token_supported: bool = False,
     ):
         super().__init__()
 
-        self.prefix_token_supported = prefix_token_supported
+        self.cls_token_supported = cls_token_supported
 
         # If encoder has cls token, then concatenate it with the features along the embedding dimension and project it
         # back to the feature_dim dimension. Else, ignore the non-existent cls token
 
-        if prefix_token_supported:
+        if cls_token_supported:
             self.readout_blocks = nn.ModuleList(
                 [
                     ProjectionReadout(
@@ -246,9 +246,8 @@ def __init__(
         )
 
     def forward(
-        self, encoder_output: list[list[torch.Tensor], list[torch.Tensor]]
+        self, features: list[torch.Tensor], cls_tokens: list[torch.Tensor]
     ) -> torch.Tensor:
-        features, cls_tokens = encoder_output
         processed_features = []
 
         # Process the encoder features to scale of [1/32,1/16,1/8,1/4]
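The decoder's forward now takes the spatial features and the cls tokens as two separate arguments instead of one nested list, mirroring the encoder's new tuple return type. A minimal sketch of the updated call site, purely illustrative: `encoder`, `decoder`, and `x` are assumed to be a matching TimmViTEncoder/DPTDecoder pair and an input batch built elsewhere.

# Illustrative sketch only; `encoder`, `decoder`, and `x` are assumed to exist.
features, cls_tokens = encoder(x)               # two separate lists now

# before this commit: decoder_output = decoder([features, cls_tokens])
decoder_output = decoder(features, cls_tokens)  # new two-argument signature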

segmentation_models_pytorch/decoders/dpt/model.py

+34 -8
@@ -1,11 +1,13 @@
 from typing import Any, Optional, Union, Callable
+import torch
 
 from segmentation_models_pytorch.base import (
     ClassificationHead,
     SegmentationHead,
     SegmentationModel,
 )
 from segmentation_models_pytorch.encoders import get_encoder
+from segmentation_models_pytorch.base.utils import is_torch_compiling
 from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
 from .decoder import DPTDecoder
 
@@ -46,8 +48,8 @@ class DPT(SegmentationModel):
             (could be **None** to return logits)
         kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with
             ``None`` values are pruned before passing.
-        allow_downsampling : Allow ViT encoder to have progressive downsampling. Set to False for DPT as the architecture
-            requires all encoder feature outputs to have the same spatial shape.
+        allow_downsampling : Allow ViT encoder to have progressive spatial downsampling for its representations.
+            Set to False for DPT as the architecture requires all encoder feature outputs to have the same spatial shape.
         allow_output_stride_not_power_of_two : Allow ViT encoders with output_stride not being a power of 2. This
             is set False for DPT as the architecture requires the encoder output features to have an output stride of
             [1/32,1/16,1/8,1/4]
@@ -58,6 +60,10 @@ class DPT(SegmentationModel):
 
     """
 
+    _is_torch_scriptable = False
+    _is_torch_compilable = False
+    requires_divisible_input_shape = True
+
     @supports_config_loading
     def __init__(
         self,
@@ -84,17 +90,17 @@ def __init__(
             **kwargs,
         )
 
-        transformer_embed_dim = self.encoder.embed_dim
-        encoder_output_stride = self.encoder.output_stride
-        cls_token_supported = self.encoder.prefix_token_supported
+        self.transformer_embed_dim = self.encoder.embed_dim
+        self.encoder_output_stride = self.encoder.output_stride
+        self.cls_token_supported = self.encoder.cls_token_supported
 
         self.decoder = DPTDecoder(
             encoder_name=encoder_name,
-            transformer_embed_dim=transformer_embed_dim,
+            transformer_embed_dim=self.transformer_embed_dim,
             feature_dim=feature_dim,
             encoder_depth=encoder_depth,
-            encoder_output_stride=encoder_output_stride,
-            prefix_token_supported=cls_token_supported,
+            encoder_output_stride=self.encoder_output_stride,
+            cls_token_supported=self.cls_token_supported,
         )
 
         self.segmentation_head = SegmentationHead(
@@ -114,3 +120,23 @@ def __init__(
 
         self.name = "dpt-{}".format(encoder_name)
         self.initialize()
+
+    def forward(self, x):
+        """Sequentially pass `x` through the model's encoder, decoder and heads"""
+
+        if not (
+            torch.jit.is_scripting() or torch.jit.is_tracing() or is_torch_compiling()
+        ):
+            self.check_input_shape(x)
+
+        features, cls_tokens = self.encoder(x)
+
+        decoder_output = self.decoder(features, cls_tokens)
+
+        masks = self.segmentation_head(decoder_output)
+
+        if self.classification_head is not None:
+            labels = self.classification_head(features[-1])
+            return masks, labels
+
+        return masks
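The added forward follows the usual segmentation_models_pytorch flow: optional input-shape check, encoder, decoder, segmentation head, and an optional classification head fed from the last feature map. A minimal usage sketch under assumed names; the encoder identifier and input size are placeholders, not verified values.

import torch
import segmentation_models_pytorch as smp

# Placeholder encoder name; any timm ViT without progressive downsampling should fit.
model = smp.DPT(encoder_name="tu-vit_base_patch16_224", classes=2)
model.eval()

x = torch.rand(1, 3, 224, 224)      # spatial size assumed to satisfy check_input_shape
with torch.no_grad():
    masks = model(x)                # the shape check is skipped when scripting/tracing/compiling

print(masks.shape)                  # expected: (1, 2, 224, 224)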

segmentation_models_pytorch/encoders/__init__.py

+1 -1
@@ -92,7 +92,7 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **
             in_channels=in_channels,
             depth=depth,
             pretrained=weights is not None,
-            output_stride = output_stride,
+            output_stride=output_stride,
             **kwargs,
         )
         return encoder

segmentation_models_pytorch/encoders/timm_vit.py

+47 -44
@@ -1,4 +1,4 @@
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 import timm
 import torch
@@ -15,17 +15,17 @@ class TimmViTEncoder(nn.Module):
     - Ensures consistent multi-level feature extraction across all ViT models.
     """
 
-    _is_torch_scriptable = True
+    _is_torch_scriptable = False
     _is_torch_exportable = True
-    _is_torch_compilable = True
+    _is_torch_compilable = False
 
     def __init__(
         self,
         name: str,
         pretrained: bool = True,
         in_channels: int = 3,
         depth: int = 4,
-        output_indices: Optional[list[int] | int] = None,
+        output_indices: Optional[Union[list[int], int]] = None,
         **kwargs: dict[str, Any],
     ):
         """
@@ -49,16 +49,14 @@ def __init__(
         super().__init__()
         self.name = name
 
-        output_stride = kwargs.pop("output_stride",None)
+        output_stride = kwargs.pop("output_stride", None)
         if output_stride is not None:
-            raise ValueError(
-                "Dilated mode not supported, set output stride to None"
-            )
+            raise ValueError("Dilated mode not supported, set output stride to None")
 
         # Default model configuration for feature extraction
         common_kwargs = dict(
             in_chans=in_channels,
-            features_only=True,
+            features_only=False,
             pretrained=pretrained,
             out_indices=tuple(range(depth)),
         )
@@ -76,6 +74,23 @@ def __init__(
         feature_info = tmp_model.feature_info
         model_num_blocks = len(feature_info)
 
+        if output_indices is not None:
+            if isinstance(output_indices, int):
+                output_indices = [output_indices]
+
+            for output_index in output_indices:
+                if output_index < 0 or output_index > model_num_blocks:
+                    raise ValueError(
+                        f"Output indices for feature extraction should be greater than 0 and less \
+                        than the number of blocks in the model ({model_num_blocks}), got {output_index}"
+                    )
+
+            if len(output_indices) != depth:
+                raise ValueError(
+                    f"Length of output indices for feature extraction should be equal to the depth of the encoder \
+                    architecture, got output indices length - {len(output_indices)}, encoder depth - {depth}"
+                )
+
         if depth > model_num_blocks:
             raise ValueError(
                 f"Depth of the encoder cannot exceed the number of blocks in the model \
@@ -87,9 +102,6 @@ def __init__(
             int((model_num_blocks / 4) * index) - 1 for index in range(1, depth + 1)
         ]
 
-        if isinstance(output_indices,int):
-            output_indices = list(output_indices)
-
         common_kwargs["out_indices"] = self.out_indices = output_indices
         feature_info_obj = timm.models.FeatureInfo(
             feature_info=feature_info, out_indices=output_indices
@@ -109,18 +121,16 @@ def __init__(
         self._output_stride = reduction_scales[0]
 
         if (
-            int(self._output_stride).bit_count() != 1
+            bin(self._output_stride).count("1") != 1
             and not allow_output_stride_not_power_of_two
         ):
             raise ValueError(
                 f"Models with stride which is not a power of 2 are not supported, \
                 got output stride {self._output_stride}"
             )
 
-        self.prefix_token_supported = getattr(tmp_model, "has_class_token", False)
+        self.cls_token_supported = getattr(tmp_model, "has_class_token", False)
         self.num_prefix_tokens = getattr(tmp_model, "num_prefix_tokens", 0)
-        if self.prefix_token_supported:
-            common_kwargs["features_only"] = False
 
         self.model = timm.create_model(
             name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs)
@@ -131,47 +141,40 @@ def __init__(
         self._depth = depth
         self._embed_dim = tmp_model.embed_dim
 
-    def forward(self, x: torch.Tensor) -> list[list[torch.Tensor], list[torch.Tensor]]:
+    def forward(self, x: torch.Tensor) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
         """
         Forward pass to extract multi-stage features.
 
         Args:
             x (torch.Tensor): Input tensor of shape (B, C, H, W).
 
         Returns:
-            list[torch.Tensor]: List of feature maps at different scales.
+            tuple[list[torch.Tensor], list[torch.Tensor]]: Tuple of feature maps and cls tokens (if supported) at different scales.
         """
-        if self.prefix_token_supported:
-            intermediate_outputs = self.model.forward_intermediates(
-                x,
-                indices=self.out_indices,
-                return_prefix_tokens=True,
-                intermediates_only=True,
-            )
-            features, cls_tokens = zip(*intermediate_outputs)
-
-            # Convert NHWC to NCHW if needed
-            if self._is_channel_last:
-                features = [
-                    feature.permute(0, 3, 1, 2).contiguous() for feature in features
-                ]
-
-            if self.num_prefix_tokens > 1:
-                cls_tokens = [cls_token[:, 0, :] for cls_token in cls_tokens]
+        intermediate_outputs = self.model.forward_intermediates(
+            x,
+            indices=self.out_indices,
+            return_prefix_tokens=True,
+            intermediates_only=True,
+        )
 
-            return [features, cls_tokens]
+        cls_tokens = [None] * len(self.out_indices)
 
-        features = self.model(x)
+        if self.num_prefix_tokens > 0:
+            features, prefix_tokens = zip(*intermediate_outputs)
+            if self.cls_token_supported:
+                if self.num_prefix_tokens == 1:
+                    cls_tokens = prefix_tokens
 
-        # Convert NHWC to NCHW if needed
-        if self._is_channel_last:
-            features = [
-                feature.permute(0, 3, 1, 2).contiguous() for feature in features
-            ]
+                elif self.num_prefix_tokens > 1:
+                    cls_tokens = [
+                        prefix_token[:, 0, :] for prefix_token in prefix_tokens
+                    ]
 
-        cls_tokens = [None] * len(features)
+        else:
+            features = intermediate_outputs
 
-        return [features, cls_tokens]
+        return features, cls_tokens
 
     @property
     def embed_dim(self) -> int:
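The refactored encoder always routes through timm's forward_intermediates and splits prefix tokens from spatial features itself, which is why features_only is now False in the default model configuration. A small standalone sketch of the underlying timm call; the model name and block indices are arbitrary examples, and the printed shapes are indicative, depending on the chosen ViT.

import timm
import torch

# Arbitrary example model; any ViT exposing a class token behaves similarly.
vit = timm.create_model("vit_small_patch16_224", pretrained=False)

x = torch.rand(1, 3, 224, 224)
outputs = vit.forward_intermediates(
    x,
    indices=[2, 5, 8, 11],          # four intermediate blocks
    return_prefix_tokens=True,      # pair each feature map with its prefix tokens
    intermediates_only=True,        # keep only the intermediate outputs
)

# Each entry is (spatial_features, prefix_tokens), matching the encoder's unpacking.
features, prefix_tokens = zip(*outputs)
print(features[0].shape)            # e.g. torch.Size([1, 384, 14, 14]), NCHW
print(prefix_tokens[0].shape)       # e.g. torch.Size([1, 1, 384]), the cls token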
