Skip to content

Commit 28ea9f8

Browse files
authored
Fix (#1121)
1 parent a8c09f5 commit 28ea9f8

File tree

1 file changed

+2
-2
lines changed
  • segmentation_models_pytorch/decoders/dpt

1 file changed

+2
-2
lines changed

segmentation_models_pytorch/decoders/dpt/decoder.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def forward(
3232
features = features.transpose(1, 2).contiguous()
3333

3434
if prefix_tokens is not None:
35-
# (batch_size, num_tokens, embed_dim) -> (batch_size, embed_dim)
36-
prefix_tokens = prefix_tokens[:, 0].expand_as(features)
35+
# (batch_size, num_prefix_tokens, embed_dim) -> (batch_size, 1, embed_dim)
36+
prefix_tokens = prefix_tokens[:, :1].expand_as(features)
3737
features = torch.cat([features, prefix_tokens], dim=2)
3838

3939
# Project to embedding dimension

0 commit comments

Comments
 (0)