Skip to content

Commit 5603707

Browse files
committed
Fix DPT tests
1 parent 165b9c0 commit 5603707

File tree

3 files changed

+14
-11
lines changed

3 files changed

+14
-11
lines changed

segmentation_models_pytorch/decoders/dpt/decoder.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -258,21 +258,23 @@ def __init__(
258258
self.fusion_blocks = nn.ModuleList(fusion_blocks)
259259

260260
def forward(
261-
self, features: list[torch.Tensor], cls_tokens: list[Optional[torch.Tensor]]
261+
self, features: list[torch.Tensor], prefix_tokens: list[Optional[torch.Tensor]]
262262
) -> torch.Tensor:
263263
# Process the encoder features to scale of [1/4, 1/8, 1/16, 1/32, ...]
264264
processed_features = []
265-
for i, (feature, cls_token) in enumerate(zip(features, cls_tokens)):
266-
projected_feature = self.projection_blocks[i](feature, cls_token)
265+
for i, (feature, prefix_tokens_i) in enumerate(zip(features, prefix_tokens)):
266+
projected_feature = self.projection_blocks[i](feature, prefix_tokens_i)
267267
processed_feature = self.reassemble_blocks[i](projected_feature)
268268
processed_features.append(processed_feature)
269269

270270
# Fusion and progressive upsampling starting from the last processed feature
271-
previous_feature = None
272271
processed_features = processed_features[::-1]
273-
for fusion_block, feature in zip(self.fusion_blocks, processed_features):
274-
fused_feature = fusion_block(feature, previous_feature)
275-
previous_feature = fused_feature
272+
for i, fusion_block in enumerate(self.fusion_blocks):
273+
processed_feature = processed_features[i]
274+
if i == 0:
275+
fused_feature = fusion_block(processed_feature)
276+
else:
277+
fused_feature = fusion_block(processed_feature, fused_feature)
276278

277279
return fused_feature
278280

segmentation_models_pytorch/decoders/dpt/model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ class DPT(SegmentationModel):
7171
7272
"""
7373

74-
_is_torch_scriptable = True
74+
# fails for encoders with prefix tokens
75+
_is_torch_scriptable = False
7576
_is_torch_compilable = True
7677
requires_divisible_input_shape = True
7778

@@ -155,8 +156,8 @@ def forward(self, x):
155156
):
156157
self.check_input_shape(x)
157158

158-
features, cls_tokens = self.encoder(x)
159-
decoder_output = self.decoder(features, cls_tokens)
159+
features, prefix_tokens = self.encoder(x)
160+
decoder_output = self.decoder(features, prefix_tokens)
160161
masks = self.segmentation_head(decoder_output)
161162

162163
if self.classification_head is not None:

tests/models/test_dpt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class TestDPTModel(base.BaseModelTester):
2222

2323
@property
2424
def hub_checkpoint(self):
25-
return "vedantdalimkar/DPT"
25+
return "smp-hub/dpt-large-ade20k"
2626

2727
@slow_test
2828
@requires_torch_greater_or_equal("2.0.1")

0 commit comments

Comments
 (0)