Skip to content

Commit d6cf66b

Browse files
Extend usage of interpolation_mode to MAnet / UnetPlusPlus / FPN and align PAN (#1108)
* Extend usage of interpolation_mode to MAnet / UnetPlusPlus / FPN and align PAN
* Catch upscale_mode usage in PAN and overwrite decoder_interpolation_mode
* Fixup
* Rename for FPN
* Rename for PAN
* Rename for Unet, Unet++, Manet
* Add test for fpn
* Fixup
* Add manet test
* Test for PAN
* Add Unet++ test
* Add test unet
* Refine test for PAN

---------

Co-authored-by: qubvel <[email protected]>
1 parent 44505fd commit d6cf66b

File tree

14 files changed

+198
-20
lines changed

14 files changed

+198
-20
lines changed

segmentation_models_pytorch/decoders/fpn/decoder.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
2525

2626

2727
class FPNBlock(nn.Module):
28-
def __init__(self, pyramid_channels: int, skip_channels: int):
28+
def __init__(
29+
self,
30+
pyramid_channels: int,
31+
skip_channels: int,
32+
interpolation_mode: str = "nearest",
33+
):
2934
super().__init__()
3035
self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1)
36+
self.interpolation_mode = interpolation_mode
3137

3238
def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
33-
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
39+
x = F.interpolate(x, scale_factor=2.0, mode=self.interpolation_mode)
3440
skip = self.skip_conv(skip)
3541
x = x + skip
3642
return x
@@ -84,6 +90,7 @@ def __init__(
8490
segmentation_channels: int = 128,
8591
dropout: float = 0.2,
8692
merge_policy: Literal["add", "cat"] = "add",
93+
interpolation_mode: str = "nearest",
8794
):
8895
super().__init__()
8996

@@ -103,9 +110,9 @@ def __init__(
103110
encoder_channels = encoder_channels[: encoder_depth + 1]
104111

105112
self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1)
106-
self.p4 = FPNBlock(pyramid_channels, encoder_channels[1])
107-
self.p3 = FPNBlock(pyramid_channels, encoder_channels[2])
108-
self.p2 = FPNBlock(pyramid_channels, encoder_channels[3])
113+
self.p4 = FPNBlock(pyramid_channels, encoder_channels[1], interpolation_mode)
114+
self.p3 = FPNBlock(pyramid_channels, encoder_channels[2], interpolation_mode)
115+
self.p2 = FPNBlock(pyramid_channels, encoder_channels[3], interpolation_mode)
109116

110117
self.seg_blocks = nn.ModuleList(
111118
[

segmentation_models_pytorch/decoders/fpn/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ class FPN(SegmentationModel):
2828
decoder_merge_policy: Determines how to merge pyramid features inside FPN. Available options are **add**
2929
and **cat**
3030
decoder_dropout: Spatial dropout rate in range (0, 1) for feature pyramid in FPN_
31+
decoder_interpolation: Interpolation mode used in decoder of the model. Available options are
32+
**"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**.
3133
in_channels: A number of input channels for the model, default is 3 (RGB images)
3234
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
3335
activation: An activation function to apply after the final convolution layer.
@@ -62,6 +64,7 @@ def __init__(
6264
decoder_segmentation_channels: int = 128,
6365
decoder_merge_policy: str = "add",
6466
decoder_dropout: float = 0.2,
67+
decoder_interpolation: str = "nearest",
6568
in_channels: int = 3,
6669
classes: int = 1,
6770
activation: Optional[str] = None,
@@ -92,6 +95,7 @@ def __init__(
9295
segmentation_channels=decoder_segmentation_channels,
9396
dropout=decoder_dropout,
9497
merge_policy=decoder_merge_policy,
98+
interpolation_mode=decoder_interpolation,
9599
)
96100

97101
self.segmentation_head = SegmentationHead(

segmentation_models_pytorch/decoders/manet/decoder.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
in_channels: int,
5050
skip_channels: int,
5151
out_channels: int,
52+
interpolation_mode: str = "nearest",
5253
use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
5354
reduction: int = 16,
5455
):
@@ -99,12 +100,13 @@ def __init__(
99100
padding=1,
100101
use_norm=use_norm,
101102
)
103+
self.interpolation_mode = interpolation_mode
102104

103105
def forward(
104106
self, x: torch.Tensor, skip: Optional[torch.Tensor] = None
105107
) -> torch.Tensor:
106108
x = self.hl_conv(x)
107-
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
109+
x = F.interpolate(x, scale_factor=2.0, mode=self.interpolation_mode)
108110
attention_hl = self.SE_hl(x)
109111
if skip is not None:
110112
attention_ll = self.SE_ll(skip)
@@ -122,6 +124,7 @@ def __init__(
122124
in_channels: int,
123125
skip_channels: int,
124126
out_channels: int,
127+
interpolation_mode: str = "nearest",
125128
use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
126129
):
127130
super().__init__()
@@ -139,11 +142,12 @@ def __init__(
139142
padding=1,
140143
use_norm=use_norm,
141144
)
145+
self.interpolation_mode = interpolation_mode
142146

143147
def forward(
144148
self, x: torch.Tensor, skip: Optional[torch.Tensor] = None
145149
) -> torch.Tensor:
146-
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
150+
x = F.interpolate(x, scale_factor=2.0, mode=self.interpolation_mode)
147151
if skip is not None:
148152
x = torch.cat([x, skip], dim=1)
149153
x = self.conv1(x)
@@ -160,6 +164,7 @@ def __init__(
160164
reduction: int = 16,
161165
use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
162166
pab_channels: int = 64,
167+
interpolation_mode: str = "nearest",
163168
):
164169
super().__init__()
165170

@@ -185,7 +190,9 @@ def __init__(
185190
self.center = PABBlock(head_channels, pab_channels=pab_channels)
186191

187192
# combine decoder keyword arguments
188-
kwargs = dict(use_norm=use_norm) # no attention type here
193+
kwargs = dict(
194+
use_norm=use_norm, interpolation_mode=interpolation_mode
195+
) # no attention type here
189196
blocks = [
190197
MFABBlock(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs)
191198
if skip_ch > 0

segmentation_models_pytorch/decoders/manet/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class MAnet(SegmentationModel):
4848
```
4949
decoder_pab_channels: A number of channels for PAB module in decoder.
5050
Default is 64.
51+
decoder_interpolation: Interpolation mode used in decoder of the model. Available options are
52+
**"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**.
5153
in_channels: A number of input channels for the model, default is 3 (RGB images)
5254
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
5355
activation: An activation function to apply after the final convolution layer.
@@ -80,6 +82,7 @@ def __init__(
8082
decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
8183
decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
8284
decoder_pab_channels: int = 64,
85+
decoder_interpolation: str = "nearest",
8386
in_channels: int = 3,
8487
classes: int = 1,
8588
activation: Optional[Union[str, Callable]] = None,
@@ -111,6 +114,7 @@ def __init__(
111114
n_blocks=encoder_depth,
112115
use_norm=decoder_use_norm,
113116
pab_channels=decoder_pab_channels,
117+
interpolation_mode=decoder_interpolation,
114118
)
115119

116120
self.segmentation_head = SegmentationHead(

segmentation_models_pytorch/decoders/pan/decoder.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -165,12 +165,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
165165

166166
class GAUBlock(nn.Module):
167167
def __init__(
168-
self, in_channels: int, out_channels: int, upscale_mode: str = "bilinear"
168+
self,
169+
in_channels: int,
170+
out_channels: int,
171+
interpolation_mode: str = "bilinear",
169172
):
170173
super(GAUBlock, self).__init__()
171174

172-
self.upscale_mode = upscale_mode
173-
self.align_corners = True if upscale_mode == "bilinear" else None
175+
self.interpolation_mode = interpolation_mode
176+
self.align_corners = True if interpolation_mode == "bilinear" else None
174177

175178
self.conv1 = nn.Sequential(
176179
nn.AdaptiveAvgPool2d(1),
@@ -196,7 +199,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
196199
y_up = F.interpolate(
197200
y,
198201
size=(height, width),
199-
mode=self.upscale_mode,
202+
mode=self.interpolation_mode,
200203
align_corners=self.align_corners,
201204
)
202205
x = self.conv2(x)
@@ -211,7 +214,7 @@ def __init__(
211214
encoder_channels: Sequence[int],
212215
encoder_depth: Literal[3, 4, 5],
213216
decoder_channels: int,
214-
upscale_mode: str = "bilinear",
217+
interpolation_mode: str = "bilinear",
215218
):
216219
super().__init__()
217220

@@ -232,19 +235,19 @@ def __init__(
232235
self.gau3 = GAUBlock(
233236
in_channels=encoder_channels[2],
234237
out_channels=decoder_channels,
235-
upscale_mode=upscale_mode,
238+
interpolation_mode=interpolation_mode,
236239
)
237240
if encoder_depth >= 4:
238241
self.gau2 = GAUBlock(
239242
in_channels=encoder_channels[1],
240243
out_channels=decoder_channels,
241-
upscale_mode=upscale_mode,
244+
interpolation_mode=interpolation_mode,
242245
)
243246
if encoder_depth >= 3:
244247
self.gau1 = GAUBlock(
245248
in_channels=encoder_channels[0],
246249
out_channels=decoder_channels,
247-
upscale_mode=upscale_mode,
250+
interpolation_mode=interpolation_mode,
248251
)
249252

250253
def forward(self, features: List[torch.Tensor]) -> torch.Tensor:

segmentation_models_pytorch/decoders/pan/model.py

+14
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Any, Callable, Literal, Optional, Union
2+
import warnings
23

34
from segmentation_models_pytorch.base import (
45
ClassificationHead,
@@ -30,6 +31,8 @@ class PAN(SegmentationModel):
3031
encoder_output_stride: 16 or 32, if 16 use dilation in encoder last layer.
3132
Doesn't work with ***ception***, **vgg***, **densenet*** backbones. Default is 16.
3233
decoder_channels: A number of convolution layer filters in decoder blocks
34+
decoder_interpolation: Interpolation mode used in decoder of the model. Available options are
35+
**"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"bilinear"**.
3336
in_channels: A number of input channels for the model, default is 3 (RGB images)
3437
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
3538
activation: An activation function to apply after the final convolution layer.
@@ -62,6 +65,7 @@ def __init__(
6265
encoder_weights: Optional[str] = "imagenet",
6366
encoder_output_stride: Literal[16, 32] = 16,
6467
decoder_channels: int = 32,
68+
decoder_interpolation: str = "bilinear",
6569
in_channels: int = 3,
6670
classes: int = 1,
6771
activation: Optional[Union[str, Callable]] = None,
@@ -78,6 +82,15 @@ def __init__(
7882
)
7983
)
8084

85+
upscale_mode = kwargs.pop("upscale_mode", None)
86+
if upscale_mode is not None:
87+
warnings.warn(
88+
"The usage of upscale_mode is deprecated. Please modify your code for decoder_interpolation",
89+
DeprecationWarning,
90+
stacklevel=2,
91+
)
92+
decoder_interpolation = upscale_mode
93+
8194
self.encoder = get_encoder(
8295
encoder_name,
8396
in_channels=in_channels,
@@ -91,6 +104,7 @@ def __init__(
91104
encoder_channels=self.encoder.out_channels,
92105
encoder_depth=encoder_depth,
93106
decoder_channels=decoder_channels,
107+
interpolation_mode=decoder_interpolation,
94108
)
95109

96110
self.segmentation_head = SegmentationHead(

segmentation_models_pytorch/decoders/unet/model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class Unet(SegmentationModel):
5858
```
5959
decoder_attention_type: Attention module used in decoder of the model. Available options are
6060
**None** and **scse** (https://arxiv.org/abs/1808.08127).
61-
decoder_interpolation_mode: Interpolation mode used in decoder of the model. Available options are
61+
decoder_interpolation: Interpolation mode used in decoder of the model. Available options are
6262
**"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**.
6363
in_channels: A number of input channels for the model, default is 3 (RGB images)
6464
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
@@ -112,7 +112,7 @@ def __init__(
112112
decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
113113
decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
114114
decoder_attention_type: Optional[str] = None,
115-
decoder_interpolation_mode: str = "nearest",
115+
decoder_interpolation: str = "nearest",
116116
in_channels: int = 3,
117117
classes: int = 1,
118118
activation: Optional[Union[str, Callable]] = None,
@@ -147,7 +147,7 @@ def __init__(
147147
use_norm=decoder_use_norm,
148148
add_center_block=add_center_block,
149149
attention_type=decoder_attention_type,
150-
interpolation_mode=decoder_interpolation_mode,
150+
interpolation_mode=decoder_interpolation,
151151
)
152152

153153
self.segmentation_head = SegmentationHead(

segmentation_models_pytorch/decoders/unetplusplus/decoder.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def __init__(
1515
out_channels: int,
1616
use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
1717
attention_type: Optional[str] = None,
18+
interpolation_mode: str = "nearest",
1819
):
1920
super().__init__()
2021
self.conv1 = md.Conv2dReLU(
@@ -35,11 +36,12 @@ def __init__(
3536
use_norm=use_norm,
3637
)
3738
self.attention2 = md.Attention(attention_type, in_channels=out_channels)
39+
self.interpolation_mode = interpolation_mode
3840

3941
def forward(
4042
self, x: torch.Tensor, skip: Optional[torch.Tensor] = None
4143
) -> torch.Tensor:
42-
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
44+
x = F.interpolate(x, scale_factor=2.0, mode=self.interpolation_mode)
4345
if skip is not None:
4446
x = torch.cat([x, skip], dim=1)
4547
x = self.attention1(x)
@@ -81,6 +83,7 @@ def __init__(
8183
n_blocks: int = 5,
8284
use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
8385
attention_type: Optional[str] = None,
86+
interpolation_mode: str = "nearest",
8487
center: bool = False,
8588
):
8689
super().__init__()
@@ -113,6 +116,7 @@ def __init__(
113116
kwargs = dict(
114117
use_norm=use_norm,
115118
attention_type=attention_type,
119+
interpolation_mode=interpolation_mode,
116120
)
117121

118122
blocks = {}

segmentation_models_pytorch/decoders/unetplusplus/model.py

+4
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class UnetPlusPlus(SegmentationModel):
4747
```
4848
decoder_attention_type: Attention module used in decoder of the model.
4949
Available options are **None** and **scse** (https://arxiv.org/abs/1808.08127).
50+
decoder_interpolation: Interpolation mode used in decoder of the model. Available options are
51+
**"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**.
5052
in_channels: A number of input channels for the model, default is 3 (RGB images)
5153
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
5254
activation: An activation function to apply after the final convolution layer.
@@ -81,6 +83,7 @@ def __init__(
8183
decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
8284
decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
8385
decoder_attention_type: Optional[str] = None,
86+
decoder_interpolation: str = "nearest",
8487
in_channels: int = 3,
8588
classes: int = 1,
8689
activation: Optional[Union[str, Callable]] = None,
@@ -118,6 +121,7 @@ def __init__(
118121
use_norm=decoder_use_norm,
119122
center=True if encoder_name.startswith("vgg") else False,
120123
attention_type=decoder_attention_type,
124+
interpolation_mode=decoder_interpolation,
121125
)
122126

123127
self.segmentation_head = SegmentationHead(

tests/models/test_fpn.py

+23
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,29 @@
1+
import segmentation_models_pytorch as smp
2+
13
from tests.models import base
24

35

46
class TestFpnModel(base.BaseModelTester):
57
test_model_type = "fpn"
68
files_for_diff = [r"decoders/fpn/", r"base/"]
9+
10+
def test_interpolation(self):
11+
# test bilinear
12+
model_1 = smp.create_model(
13+
self.test_model_type,
14+
self.test_encoder_name,
15+
decoder_interpolation="bilinear",
16+
)
17+
assert model_1.decoder.p2.interpolation_mode == "bilinear"
18+
assert model_1.decoder.p3.interpolation_mode == "bilinear"
19+
assert model_1.decoder.p4.interpolation_mode == "bilinear"
20+
21+
# test bicubic
22+
model_2 = smp.create_model(
23+
self.test_model_type,
24+
self.test_encoder_name,
25+
decoder_interpolation="bicubic",
26+
)
27+
assert model_2.decoder.p2.interpolation_mode == "bicubic"
28+
assert model_2.decoder.p3.interpolation_mode == "bicubic"
29+
assert model_2.decoder.p4.interpolation_mode == "bicubic"

0 commit comments

Comments
 (0)