From 0cab9893c435bb7b7444419cf7549ac5731756ab Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sat, 11 Jan 2025 20:14:28 +0000 Subject: [PATCH 1/7] Fix type hint for models --- segmentation_models_pytorch/base/model.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py index 6d7bf643..dd69eef3 100644 --- a/segmentation_models_pytorch/base/model.py +++ b/segmentation_models_pytorch/base/model.py @@ -1,8 +1,11 @@ import torch +from typing import TypeVar, Type from . import initialization as init from .hub_mixin import SMPHubMixin +T = TypeVar("T", bound="SegmentationModel") + class SegmentationModel(torch.nn.Module, SMPHubMixin): """Base class for all segmentation models.""" @@ -11,6 +14,11 @@ class SegmentationModel(torch.nn.Module, SMPHubMixin): # set to False requires_divisible_input_shape = True + # Fix type-hint for models, to avoid HubMixin signature + def __new__(cls: Type[T], *args, **kwargs) -> T: + instance = super().__new__(cls, *args, **kwargs) + return instance + def initialize(self): init.initialize_decoder(self.decoder) init.initialize_head(self.segmentation_head) @@ -42,7 +50,7 @@ def check_input_shape(self, x): def forward(self, x): """Sequentially pass `x` trough model`s encoder, decoder and heads""" - if not torch.jit.is_tracing() or self.requires_divisible_input_shape: + if torch.jit.is_tracing() or self.requires_divisible_input_shape: self.check_input_shape(x) features = self.encoder(x) From b2166ea86b6897aedcd3a84e97daf8636e6dd6fe Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sat, 11 Jan 2025 20:15:13 +0000 Subject: [PATCH 2/7] Use inference mode in tests --- tests/models/test_segformer.py | 2 +- tests/test_losses.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/models/test_segformer.py b/tests/models/test_segformer.py index 3ca5016c..5195f050 100644 --- a/tests/models/test_segformer.py +++ b/tests/models/test_segformer.py @@ -21,7 +21,7 @@ def test_load_pretrained(self): sample = torch.ones([1, 3, 512, 512]).to(default_device) - with torch.no_grad(): + with torch.inference_mode(): output = model(sample) self.assertEqual(output.shape, (1, 150, 512, 512)) diff --git a/tests/test_losses.py b/tests/test_losses.py index 5c3ad75a..94d85d5c 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -93,7 +93,7 @@ def test_soft_tversky_score(y_true, y_pred, expected, eps, alpha, beta): assert float(actual) == pytest.approx(expected, eps) -@torch.no_grad() +@torch.inference_mode() def test_dice_loss_binary(): eps = 1e-5 criterion = DiceLoss(mode=smp.losses.BINARY_MODE, from_logits=False) @@ -131,7 +131,7 @@ def test_dice_loss_binary(): assert float(loss) == pytest.approx(1.0, abs=eps) -@torch.no_grad() +@torch.inference_mode() def test_tversky_loss_binary(): eps = 1e-5 # with alpha=0.5; beta=0.5 it is equal to DiceLoss @@ -172,7 +172,7 @@ def test_tversky_loss_binary(): assert float(loss) == pytest.approx(1.0, abs=eps) -@torch.no_grad() +@torch.inference_mode() def test_binary_jaccard_loss(): eps = 1e-5 criterion = JaccardLoss(mode=smp.losses.BINARY_MODE, from_logits=False) @@ -210,7 +210,7 @@ def test_binary_jaccard_loss(): assert float(loss) == pytest.approx(1.0, eps) -@torch.no_grad() +@torch.inference_mode() def test_multiclass_jaccard_loss(): eps = 1e-5 criterion = JaccardLoss(mode=smp.losses.MULTICLASS_MODE, from_logits=False) @@ -237,7 +237,7 @@ def test_multiclass_jaccard_loss(): assert float(loss) 
== pytest.approx(1.0 - 1.0 / 3.0, abs=eps) -@torch.no_grad() +@torch.inference_mode() def test_multilabel_jaccard_loss(): eps = 1e-5 criterion = JaccardLoss(mode=smp.losses.MULTILABEL_MODE, from_logits=False) @@ -263,7 +263,7 @@ def test_multilabel_jaccard_loss(): assert float(loss) == pytest.approx(1.0 - 1.0 / 3.0, abs=eps) -@torch.no_grad() +@torch.inference_mode() def test_soft_ce_loss(): criterion = SoftCrossEntropyLoss(smooth_factor=0.1, ignore_index=-100) @@ -276,7 +276,7 @@ def test_soft_ce_loss(): assert float(loss) == pytest.approx(1.0125, abs=0.0001) -@torch.no_grad() +@torch.inference_mode() def test_soft_bce_loss(): criterion = SoftBCEWithLogitsLoss(smooth_factor=0.1, ignore_index=-100) @@ -287,7 +287,7 @@ def test_soft_bce_loss(): assert float(loss) == pytest.approx(0.7201, abs=0.0001) -@torch.no_grad() +@torch.inference_mode() def test_binary_mcc_loss(): eps = 1e-5 criterion = MCCLoss(eps=eps) From 26725de10a494d99aabd907a33fd58a56f9818a8 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sat, 11 Jan 2025 20:16:26 +0000 Subject: [PATCH 3/7] Add test for any resolution (not divisible by 32) --- tests/models/base.py | 68 +++++++++++++++++++++++++++++++------------- 1 file changed, 49 insertions(+), 19 deletions(-) diff --git a/tests/models/base.py b/tests/models/base.py index 02e17303..93aac7b0 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -57,6 +57,12 @@ def decoder_channels(self): def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32): return torch.rand(batch_size, num_channels, height, width) + @lru_cache + def get_default_model(self): + model = smp.create_model(self.model_type, self.test_encoder_name) + model = model.to(default_device) + return model + def test_forward_backward(self): sample = self._get_sample( batch_size=self.default_batch_size, @@ -64,9 +70,8 @@ def test_forward_backward(self): height=self.default_height, width=self.default_width, ).to(default_device) - model = smp.create_model( - arch=self.model_type, encoder_name=self.test_encoder_name - ).to(default_device) + + model = self.get_default_model() # check default in_channels=3 output = model(sample) @@ -91,14 +96,19 @@ def test_in_channels_and_depth_and_out_classes( if self.model_type in ["unet", "unetplusplus", "manet"]: kwargs = {"decoder_channels": self.decoder_channels[:depth]} - model = smp.create_model( - arch=self.model_type, - encoder_name=self.test_encoder_name, - encoder_depth=depth, - in_channels=in_channels, - classes=classes, - **kwargs, - ).to(default_device) + model = ( + smp.create_model( + arch=self.model_type, + encoder_name=self.test_encoder_name, + encoder_depth=depth, + in_channels=in_channels, + classes=classes, + **kwargs, + ) + .to(default_device) + .eval() + ) + sample = self._get_sample( batch_size=self.default_batch_size, num_channels=in_channels, @@ -107,7 +117,7 @@ def test_in_channels_and_depth_and_out_classes( ).to(default_device) # check in channels correctly set - with torch.no_grad(): + with torch.inference_mode(): output = model(sample) self.assertEqual(output.shape[1], classes) @@ -122,7 +132,8 @@ def test_classification_head(self): "dropout": 0.5, "activation": "sigmoid", }, - ).to(default_device) + ) + model = model.to(default_device).eval() self.assertIsNotNone(model.classification_head) self.assertIsInstance(model.classification_head[0], torch.nn.AdaptiveAvgPool2d) @@ -139,17 +150,34 @@ def test_classification_head(self): width=self.default_width, ).to(default_device) - with torch.no_grad(): + with torch.inference_mode(): _, 
cls_probs = model(sample) self.assertEqual(cls_probs.shape[1], 10) + def test_any_resolution(self): + model = self.get_default_model() + if model.requires_divisible_input_shape: + self.skipTest("Model requires divisible input shape") + + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=self.default_num_channels, + height=self.default_height + 3, + width=self.default_width + 7, + ).to(default_device) + + with torch.inference_mode(): + output = model(sample) + + self.assertEqual(output.shape[2], self.default_height + 3) + self.assertEqual(output.shape[3], self.default_width + 7) + @requires_torch_greater_or_equal("2.0.1") def test_save_load_with_hub_mixin(self): # instantiate model - model = smp.create_model( - arch=self.model_type, encoder_name=self.test_encoder_name - ).to(default_device) + model = self.get_default_model() + model.eval() # save model with tempfile.TemporaryDirectory() as tmpdir: @@ -157,6 +185,8 @@ def test_save_load_with_hub_mixin(self): tmpdir, dataset="test_dataset", metrics={"my_awesome_metric": 0.99} ) restored_model = smp.from_pretrained(tmpdir).to(default_device) + restored_model.eval() + with open(os.path.join(tmpdir, "README.md"), "r") as f: readme = f.read() @@ -168,7 +198,7 @@ def test_save_load_with_hub_mixin(self): width=self.default_width, ).to(default_device) - with torch.no_grad(): + with torch.inference_mode(): output = model(sample) restored_output = restored_model(sample) @@ -197,7 +227,7 @@ def test_preserve_forward_output(self): output_tensor = torch.load(output_tensor_path, weights_only=True) output_tensor = output_tensor.to(default_device) - with torch.no_grad(): + with torch.inference_mode(): output = model(input_tensor) self.assertEqual(output.shape, output_tensor.shape) From 5b75f1152990342a01cbb518efe75280458b4a2c Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sat, 11 Jan 2025 20:16:50 +0000 Subject: [PATCH 4/7] Use inference mode in tests --- tests/encoders/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/encoders/base.py b/tests/encoders/base.py index 0f762cf4..1391aaa3 100644 --- a/tests/encoders/base.py +++ b/tests/encoders/base.py @@ -82,7 +82,7 @@ def test_in_channels(self): encoder.eval() # forward - with torch.no_grad(): + with torch.inference_mode(): encoder.forward(sample) def test_depth(self): @@ -110,7 +110,7 @@ def test_depth(self): encoder.eval() # forward - with torch.no_grad(): + with torch.inference_mode(): features = encoder.forward(sample) # check number of features @@ -187,7 +187,7 @@ def test_dilated(self): encoder.eval() # forward - with torch.no_grad(): + with torch.inference_mode(): features = encoder.forward(sample) height_strides, width_strides = self.get_features_output_strides( From 6b2ca909a774ec01e3a1c64dd7b7804aecfb3f9d Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sat, 11 Jan 2025 20:17:32 +0000 Subject: [PATCH 5/7] Enable any res for Unet and better docs --- .../decoders/unet/decoder.py | 95 ++++++++++++------- .../decoders/unet/model.py | 46 +++++++-- 2 files changed, 101 insertions(+), 40 deletions(-) diff --git a/segmentation_models_pytorch/decoders/unet/decoder.py b/segmentation_models_pytorch/decoders/unet/decoder.py index 33061542..4c2a6711 100644 --- a/segmentation_models_pytorch/decoders/unet/decoder.py +++ b/segmentation_models_pytorch/decoders/unet/decoder.py @@ -2,19 +2,22 @@ import torch.nn as nn import torch.nn.functional as F +from typing import Optional, Sequence from segmentation_models_pytorch.base import modules 
as md class DecoderBlock(nn.Module): def __init__( self, - in_channels, - skip_channels, - out_channels, - use_batchnorm=True, - attention_type=None, + in_channels: int, + skip_channels: int, + out_channels: int, + use_batchnorm: bool = True, + attention_type: Optional[str] = None, + interpolation_mode: str = "nearest", ): super().__init__() + self.interpolate_mode = interpolation_mode self.conv1 = md.Conv2dReLU( in_channels + skip_channels, out_channels, @@ -34,19 +37,32 @@ def __init__( ) self.attention2 = md.Attention(attention_type, in_channels=out_channels) - def forward(self, x, skip=None): - x = F.interpolate(x, scale_factor=2, mode="nearest") - if skip is not None: - x = torch.cat([x, skip], dim=1) - x = self.attention1(x) - x = self.conv1(x) - x = self.conv2(x) - x = self.attention2(x) - return x + def forward( + self, + feature_map: torch.Tensor, + target_height: int, + target_width: int, + skip_connection: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Upsample feature map to the given spatial shape, concatenate with skip connection, + apply attention block (if specified) and then apply two convolutions. + """ + feature_map = F.interpolate( + feature_map, size=(target_height, target_width), mode=self.interpolate_mode + ) + if skip_connection is not None: + feature_map = torch.cat([feature_map, skip_connection], dim=1) + feature_map = self.attention1(feature_map) + feature_map = self.conv1(feature_map) + feature_map = self.conv2(feature_map) + feature_map = self.attention2(feature_map) + return feature_map class CenterBlock(nn.Sequential): - def __init__(self, in_channels, out_channels, use_batchnorm=True): + """Center block of the Unet decoder. Applied to the last feature map of the encoder.""" + + def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True): conv1 = md.Conv2dReLU( in_channels, out_channels, @@ -67,12 +83,12 @@ def __init__(self, in_channels, out_channels, use_batchnorm=True): class UnetDecoder(nn.Module): def __init__( self, - encoder_channels, - decoder_channels, - n_blocks=5, - use_batchnorm=True, - attention_type=None, - center=False, + encoder_channels: Sequence[int], + decoder_channels: Sequence[int], + n_blocks: int = 5, + use_batchnorm: bool = True, + attention_type: Optional[str] = None, + add_center_block: bool = False, ): super().__init__() @@ -94,7 +110,7 @@ def __init__( skip_channels = list(encoder_channels[1:]) + [0] out_channels = decoder_channels - if center: + if add_center_block: self.center = CenterBlock( head_channels, head_channels, use_batchnorm=use_batchnorm ) @@ -102,23 +118,36 @@ def __init__( self.center = nn.Identity() # combine decoder keyword arguments - kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) - blocks = [ - DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) - for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) - ] - self.blocks = nn.ModuleList(blocks) - - def forward(self, *features): + self.blocks = nn.ModuleList() + for block_in_channels, block_skip_channels, block_out_channels in zip( + in_channels, skip_channels, out_channels + ): + block = DecoderBlock( + block_in_channels, + block_skip_channels, + block_out_channels, + use_batchnorm=use_batchnorm, + attention_type=attention_type, + ) + self.blocks.append(block) + + def forward(self, *features: torch.Tensor) -> torch.Tensor: + # spatial shapes of features: [hw, hw/2, hw/4, hw/8, ...] 
+ spatial_shapes = [feature.shape[2:] for feature in features] + spatial_shapes = spatial_shapes[::-1] + features = features[1:] # remove first skip with same spatial resolution features = features[::-1] # reverse channels to start from head of encoder head = features[0] - skips = features[1:] + skip_connections = features[1:] x = self.center(head) + for i, decoder_block in enumerate(self.blocks): - skip = skips[i] if i < len(skips) else None - x = decoder_block(x, skip) + # upsample to the next spatial shape + height, width = spatial_shapes[i + 1] + skip_connection = skip_connections[i] if i < len(skip_connections) else None + x = decoder_block(x, height, width, skip_connection=skip_connection) return x diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py index 547581eb..660eb21d 100644 --- a/segmentation_models_pytorch/decoders/unet/model.py +++ b/segmentation_models_pytorch/decoders/unet/model.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union, Tuple, Callable +from typing import Any, Optional, Union, Callable, Sequence from segmentation_models_pytorch.base import ( ClassificationHead, @@ -12,10 +12,21 @@ class Unet(SegmentationModel): - """Unet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder* - and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial - resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *concatenation* - for fusing decoder blocks with skip connections. + """ + U-Net is a fully convolutional neural network architecture designed for semantic image segmentation. + + It consists of two main parts: + + 1. An encoder (downsampling path) that extracts increasingly abstract features + 2. A decoder (upsampling path) that gradually recovers spatial details + + The key is the use of skip connections between corresponding encoder and decoder layers. + These connections allow the decoder to access fine-grained details from earlier encoder layers, + which helps produce more precise segmentation masks. + + The skip connections work by concatenating feature maps from the encoder directly into the decoder + at corresponding resolutions. This helps preserve important spatial information that would + otherwise be lost during the encoding process. Args: encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) @@ -51,11 +62,31 @@ class Unet(SegmentationModel): Returns: ``torch.nn.Module``: Unet + Example: + .. code-block:: python + + import torch + import segmentation_models_pytorch as smp + + model = smp.Unet("resnet18", encoder_weights="imagenet", classes=5) + model.eval() + + # generate random images + images = torch.rand(2, 3, 256, 256) + + with torch.inference_mode(): + mask = model(images) + + print(mask.shape) + # torch.Size([2, 5, 256, 256]) + .. _Unet: https://arxiv.org/abs/1505.04597 """ + requires_divisible_input_shape = False + @supports_config_loading def __init__( self, @@ -63,7 +94,7 @@ def __init__( encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", decoder_use_batchnorm: bool = True, - decoder_channels: Tuple[int, ...] 
= (256, 128, 64, 32, 16), + decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, in_channels: int = 3, classes: int = 1, @@ -81,12 +112,13 @@ def __init__( **kwargs, ) + add_center_block = encoder_name.startswith("vgg") self.decoder = UnetDecoder( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, n_blocks=encoder_depth, use_batchnorm=decoder_use_batchnorm, - center=True if encoder_name.startswith("vgg") else False, + add_center_block=add_center_block, attention_type=decoder_attention_type, ) From eb81c1f34b81ffdd492cb2be58cbfbfe189f71c5 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sat, 11 Jan 2025 20:40:59 +0000 Subject: [PATCH 6/7] Fix check_input_shape condition --- segmentation_models_pytorch/base/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py index dd69eef3..a25ed30a 100644 --- a/segmentation_models_pytorch/base/model.py +++ b/segmentation_models_pytorch/base/model.py @@ -50,7 +50,7 @@ def check_input_shape(self, x): def forward(self, x): """Sequentially pass `x` trough model`s encoder, decoder and heads""" - if torch.jit.is_tracing() or self.requires_divisible_input_shape: + if not torch.jit.is_tracing() and self.requires_divisible_input_shape: self.check_input_shape(x) features = self.encoder(x) From d5a80df0bf1f19892e79af4907ee8f556d603403 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 13 Jan 2025 16:42:54 +0000 Subject: [PATCH 7/7] Interpolation for unet --- .../decoders/unet/decoder.py | 27 ++++++++++++------- .../decoders/unet/model.py | 4 +++ 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/segmentation_models_pytorch/decoders/unet/decoder.py b/segmentation_models_pytorch/decoders/unet/decoder.py index 4c2a6711..e6bf4d16 100644 --- a/segmentation_models_pytorch/decoders/unet/decoder.py +++ b/segmentation_models_pytorch/decoders/unet/decoder.py @@ -6,7 +6,9 @@ from segmentation_models_pytorch.base import modules as md -class DecoderBlock(nn.Module): +class UnetDecoderBlock(nn.Module): + """A decoder block in the U-Net architecture that performs upsampling and feature fusion.""" + def __init__( self, in_channels: int, @@ -17,7 +19,7 @@ def __init__( interpolation_mode: str = "nearest", ): super().__init__() - self.interpolate_mode = interpolation_mode + self.interpolation_mode = interpolation_mode self.conv1 = md.Conv2dReLU( in_channels + skip_channels, out_channels, @@ -44,11 +46,10 @@ def forward( target_width: int, skip_connection: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """Upsample feature map to the given spatial shape, concatenate with skip connection, - apply attention block (if specified) and then apply two convolutions. - """ feature_map = F.interpolate( - feature_map, size=(target_height, target_width), mode=self.interpolate_mode + feature_map, + size=(target_height, target_width), + mode=self.interpolation_mode, ) if skip_connection is not None: feature_map = torch.cat([feature_map, skip_connection], dim=1) @@ -59,7 +60,7 @@ def forward( return feature_map -class CenterBlock(nn.Sequential): +class UnetCenterBlock(nn.Sequential): """Center block of the Unet decoder. 
Applied to the last feature map of the encoder.""" def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True): @@ -81,6 +82,12 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = Tr class UnetDecoder(nn.Module): + """The decoder part of the U-Net architecture. + + Takes encoded features from different stages of the encoder and progressively upsamples them while + combining with skip connections. This helps preserve fine-grained details in the final segmentation. + """ + def __init__( self, encoder_channels: Sequence[int], @@ -89,6 +96,7 @@ def __init__( use_batchnorm: bool = True, attention_type: Optional[str] = None, add_center_block: bool = False, + interpolation_mode: str = "nearest", ): super().__init__() @@ -111,7 +119,7 @@ def __init__( out_channels = decoder_channels if add_center_block: - self.center = CenterBlock( + self.center = UnetCenterBlock( head_channels, head_channels, use_batchnorm=use_batchnorm ) else: @@ -122,12 +130,13 @@ def __init__( for block_in_channels, block_skip_channels, block_out_channels in zip( in_channels, skip_channels, out_channels ): - block = DecoderBlock( + block = UnetDecoderBlock( block_in_channels, block_skip_channels, block_out_channels, use_batchnorm=use_batchnorm, attention_type=attention_type, + interpolation_mode=interpolation_mode, ) self.blocks.append(block) diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py index 660eb21d..4b30527d 100644 --- a/segmentation_models_pytorch/decoders/unet/model.py +++ b/segmentation_models_pytorch/decoders/unet/model.py @@ -44,6 +44,8 @@ class Unet(SegmentationModel): Available options are **True, False, "inplace"** decoder_attention_type: Attention module used in decoder of the model. Available options are **None** and **scse** (https://arxiv.org/abs/1808.08127). + decoder_interpolation_mode: Interpolation mode used in decoder of the model. Available options are + **"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**. in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. @@ -96,6 +98,7 @@ def __init__( decoder_use_batchnorm: bool = True, decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, + decoder_interpolation_mode: str = "nearest", in_channels: int = 3, classes: int = 1, activation: Optional[Union[str, Callable]] = None, @@ -120,6 +123,7 @@ def __init__( use_batchnorm=decoder_use_batchnorm, add_center_block=add_center_block, attention_type=decoder_attention_type, + interpolation_mode=decoder_interpolation_mode, ) self.segmentation_head = SegmentationHead(
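
---

A few editorial notes on the series follow; these are annotations with illustrative sketches, not part of the patches themselves.

Note on PATCH 1/7: the `__new__` override exists purely for static typing. As the commit says, it avoids the `SMPHubMixin` signature leaking into the model classes, so IDEs and type checkers infer `smp.Unet(...) -> Unet` again. Below is a minimal self-contained sketch of the pattern; `HubMixinStandIn` is a hypothetical stand-in for illustration, not the real `SMPHubMixin`:

```python
from typing import Type, TypeVar

T = TypeVar("T", bound="BaseModel")


class HubMixinStandIn:
    # stand-in for a mixin whose __new__ hides the subclass type
    def __new__(cls, *args, **kwargs) -> "HubMixinStandIn":
        return super().__new__(cls)


class BaseModel(HubMixinStandIn):
    # binding the TypeVar to cls makes type checkers infer the concrete
    # subclass: UnetLike() is typed as UnetLike, not HubMixinStandIn
    def __new__(cls: Type[T], *args, **kwargs) -> T:
        instance = super().__new__(cls, *args, **kwargs)
        return instance


class UnetLike(BaseModel):
    pass


model = UnetLike()  # inferred type: UnetLike
```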
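Note on PATCHes 2/7 and 4/7: `torch.inference_mode()` is a stricter, lower-overhead variant of `torch.no_grad()` — besides disabling gradient tracking it also skips view and version-counter bookkeeping. The trade-off is that tensors created inside the context are "inference tensors" and cannot be used in autograd afterwards, which is exactly what test code wants. A small comparison, runnable as-is:

```python
import torch

model = torch.nn.Linear(4, 2)
x = torch.rand(1, 4)

with torch.no_grad():
    y_no_grad = model(x)  # grad disabled; result may still enter autograd later

with torch.inference_mode():
    y_inference = model(x)  # lower overhead; result is an "inference tensor"

print(y_no_grad.requires_grad, y_inference.requires_grad)  # False False
# y_inference.requires_grad_(True) would raise a RuntimeError, so reserve
# inference_mode for code that will never need gradients, e.g. tests.
```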
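An aside on PATCH 3/7: `get_default_model` is memoized with `@lru_cache`, so each tester builds its default model once and reuses it across `test_forward_backward`, `test_any_resolution`, and the hub-mixin round-trip. The hunk does not show a `functools` import being added, so presumably it already exists in `tests/models/base.py`. Note that on a method the bound `self` is part of the cache key, so the cached model lives as long as the test-case instance. A toy illustration of that behaviour:

```python
from functools import lru_cache


class ModelTester:
    model_type = "unet"

    @lru_cache
    def get_default_model(self):
        # `self` is part of the cache key: one model per tester instance
        print(f"building {self.model_type}")
        return object()  # stand-in for smp.create_model(...)


tester = ModelTester()
first = tester.get_default_model()
second = tester.get_default_model()
assert first is second  # "building unet" was printed only once
```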
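Finally, tying PATCHes 3/7, 5/7, 6/7 and 7/7 together: `Unet` now sets `requires_divisible_input_shape = False`, the corrected guard `if not torch.jit.is_tracing() and self.requires_divisible_input_shape:` enforces the divisible-by-32 check only in eager mode and only for models that still need it, and the decoder upsamples each stage to the exact spatial shape recorded from the encoder instead of a fixed `scale_factor=2`. The net effect is sketched below using the new `decoder_interpolation_mode` argument (weights left random to keep the example download-free):

```python
import torch
import segmentation_models_pytorch as smp

model = smp.Unet(
    "resnet18",
    encoder_weights=None,
    classes=3,
    decoder_interpolation_mode="bilinear",  # added in PATCH 7/7; default "nearest"
)
model.eval()

# 35 and 39 are not divisible by 32 -- previously this raised a RuntimeError
x = torch.rand(1, 3, 35, 39)

with torch.inference_mode():
    mask = model(x)

print(mask.shape)  # torch.Size([1, 3, 35, 39]) -- output matches the input size
```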