Skip to content

Commit 9518964

Browse files
committed
Refactor a bit
1 parent 5603707 commit 9518964

File tree

1 file changed

+13
-7
lines changed
  • segmentation_models_pytorch/decoders/dpt

1 file changed

+13
-7
lines changed

segmentation_models_pytorch/decoders/dpt/decoder.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,16 @@ def __init__(
220220
):
221221
super().__init__()
222222

223-
num_blocks = len(encoder_output_strides)
223+
if not (
224+
len(encoder_out_channels)
225+
== len(encoder_output_strides)
226+
== len(intermediate_channels)
227+
):
228+
raise ValueError(
229+
"encoder_out_channels, encoder_output_strides and intermediate_channels must have the same length"
230+
)
231+
232+
num_blocks = len(encoder_out_channels)
224233

225234
# If encoder has prefix tokens (e.g. cls_token), then we can concat/add/ignore them
226235
# according to the readout mode
@@ -269,12 +278,9 @@ def forward(
269278

270279
# Fusion and progressive upsampling starting from the last processed feature
271280
processed_features = processed_features[::-1]
272-
for i, fusion_block in enumerate(self.fusion_blocks):
273-
processed_feature = processed_features[i]
274-
if i == 0:
275-
fused_feature = fusion_block(processed_feature)
276-
else:
277-
fused_feature = fusion_block(processed_feature, fused_feature)
281+
fused_feature = None
282+
for fusion_block, feature in zip(self.fusion_blocks, processed_features):
283+
fused_feature = fusion_block(feature, fused_feature)
278284

279285
return fused_feature
280286

0 commit comments

Comments
 (0)