From d08e816e24a1f0f133632eb1868dce84aad53830 Mon Sep 17 00:00:00 2001 From: qubvel Date: Wed, 9 Apr 2025 09:12:02 +0000 Subject: [PATCH] Fix --- segmentation_models_pytorch/decoders/dpt/decoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/segmentation_models_pytorch/decoders/dpt/decoder.py b/segmentation_models_pytorch/decoders/dpt/decoder.py index c0b4634b..345ecca1 100644 --- a/segmentation_models_pytorch/decoders/dpt/decoder.py +++ b/segmentation_models_pytorch/decoders/dpt/decoder.py @@ -32,8 +32,8 @@ def forward( features = features.transpose(1, 2).contiguous() if prefix_tokens is not None: - # (batch_size, num_tokens, embed_dim) -> (batch_size, embed_dim) - prefix_tokens = prefix_tokens[:, 0].expand_as(features) + # (batch_size, num_prefix_tokens, embed_dim) -> (batch_size, 1, embed_dim) + prefix_tokens = prefix_tokens[:, :1].expand_as(features) features = torch.cat([features, prefix_tokens], dim=2) # Project to embedding dimension