File tree Expand file tree Collapse file tree 1 file changed +13
-7
lines changed
segmentation_models_pytorch/decoders/dpt Expand file tree Collapse file tree 1 file changed +13
-7
lines changed Original file line number Diff line number Diff line change @@ -220,7 +220,16 @@ def __init__(
220
220
):
221
221
super ().__init__ ()
222
222
223
- num_blocks = len (encoder_output_strides )
223
+ if not (
224
+ len (encoder_out_channels )
225
+ == len (encoder_output_strides )
226
+ == len (intermediate_channels )
227
+ ):
228
+ raise ValueError (
229
+ "encoder_out_channels, encoder_output_strides and intermediate_channels must have the same length"
230
+ )
231
+
232
+ num_blocks = len (encoder_out_channels )
224
233
225
234
# If encoder has prefix tokens (e.g. cls_token), then we can concat/add/ignore them
226
235
# according to the readout mode
@@ -269,12 +278,9 @@ def forward(
269
278
270
279
# Fusion and progressive upsampling starting from the last processed feature
271
280
processed_features = processed_features [::- 1 ]
272
- for i , fusion_block in enumerate (self .fusion_blocks ):
273
- processed_feature = processed_features [i ]
274
- if i == 0 :
275
- fused_feature = fusion_block (processed_feature )
276
- else :
277
- fused_feature = fusion_block (processed_feature , fused_feature )
281
+ fused_feature = None
282
+ for fusion_block , feature in zip (self .fusion_blocks , processed_features ):
283
+ fused_feature = fusion_block (feature , fused_feature )
278
284
279
285
return fused_feature
280
286
You can’t perform that action at this time.
0 commit comments