diff --git a/segmentation_models_pytorch/decoders/dpt/decoder.py b/segmentation_models_pytorch/decoders/dpt/decoder.py index c0b4634b..345ecca1 100644 --- a/segmentation_models_pytorch/decoders/dpt/decoder.py +++ b/segmentation_models_pytorch/decoders/dpt/decoder.py @@ -32,8 +32,8 @@ def forward( features = features.transpose(1, 2).contiguous() if prefix_tokens is not None: - # (batch_size, num_tokens, embed_dim) -> (batch_size, embed_dim) - prefix_tokens = prefix_tokens[:, 0].expand_as(features) + # (batch_size, num_prefix_tokens, embed_dim) -> (batch_size, 1, embed_dim) + prefix_tokens = prefix_tokens[:, :1].expand_as(features) features = torch.cat([features, prefix_tokens], dim=2) # Project to embedding dimension