Commit 7157501

Huge fix for torch scripting (except Unet++ and UperNet)
1 parent: df2f484

30 files changed: +602 −401 lines
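The point of the change set is that models now compile with torch.jit.script. A minimal smoke test of that goal might look like the sketch below (assumed usage, not part of the commit; the FPN architecture and encoder name are just examples):

import torch
import segmentation_models_pytorch as smp

# After this commit, scripting should succeed for most architectures
# (Unet++ and UperNet are still excluded, per the commit message).
model = smp.FPN(encoder_name="resnet34", classes=2).eval()
scripted = torch.jit.script(model)

x = torch.rand(1, 3, 256, 256)
with torch.inference_mode():
    assert torch.allclose(model(x), scripted(x), atol=1e-6)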

segmentation_models_pytorch/base/hub_mixin.py (+2)
@@ -1,3 +1,4 @@
+import torch
 import json
 from pathlib import Path
 from typing import Optional, Union
@@ -114,6 +115,7 @@ def save_pretrained(
         return result
 
     @property
+    @torch.jit.unused
     def config(self) -> dict:
         return self._hub_mixin_config
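The decorator keeps torch.jit.script from trying to compile the config accessor, whose plain-dict return type TorchScript cannot express. A standalone sketch of the pattern (hypothetical module, not from the repo):

import torch

class WithConfig(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._config = {"classes": 2, "activation": None}  # heterogeneous dict

    @property
    @torch.jit.unused  # compiler skips this accessor instead of failing on its type
    def config(self) -> dict:
        return self._config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

scripted = torch.jit.script(WithConfig())  # scripts cleanly; the accessor is ignored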

segmentation_models_pytorch/base/model.py (+7 −7)
@@ -11,8 +11,7 @@
 class SegmentationModel(torch.nn.Module, SMPHubMixin):
     """Base class for all segmentation models."""
 
-    # if model supports shape not divisible by 2 ^ n
-    # set to False
+    # if model supports shape not divisible by 2 ^ n set to False
     requires_divisible_input_shape = True
 
     # Fix type-hint for models, to avoid HubMixin signature
@@ -30,6 +29,9 @@ def check_input_shape(self, x):
         """Check if the input shape is divisible by the output stride.
         If not, raise a RuntimeError.
         """
+        if not self.requires_divisible_input_shape:
+            return
+
         h, w = x.shape[-2:]
         output_stride = self.encoder.output_stride
         if h % output_stride != 0 or w % output_stride != 0:
@@ -51,15 +53,13 @@ def check_input_shape(self, x):
     def forward(self, x):
         """Sequentially pass `x` trough model`s encoder, decoder and heads"""
 
-        if (
-            not torch.jit.is_tracing()
-            and not is_torch_compiling()
-            and self.requires_divisible_input_shape
+        if not (
+            torch.jit.is_scripting() or torch.jit.is_tracing() or is_torch_compiling()
         ):
             self.check_input_shape(x)
 
         features = self.encoder(x)
-        decoder_output = self.decoder(*features)
+        decoder_output = self.decoder(features)
 
         masks = self.segmentation_head(decoder_output)
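Two things happen in this hunk: the eager-only shape check is now also skipped under scripting (torch.jit.is_scripting() is treated as a compile-time constant inside a scripted graph, so the whole branch is eliminated), and decoders take a single List[Tensor] argument because TorchScript does not support variadic *features. A sketch of the guard pattern (hypothetical module):

import torch

class Guarded(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # False in eager mode; folds to True (and the branch away) under scripting.
        if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
            if x.shape[-2] % 32 != 0 or x.shape[-1] % 32 != 0:
                raise RuntimeError("H and W must be divisible by 32")
        return x

m = torch.jit.script(Guarded())
print(m(torch.rand(1, 3, 31, 31)).shape)  # an eager call would raise; scripted does not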

segmentation_models_pytorch/decoders/deeplabv3/decoder.py (+39 −18)
@@ -31,7 +31,7 @@
 """
 
 from collections.abc import Iterable, Sequence
-from typing import Literal
+from typing import Literal, List
 
 import torch
 from torch import nn
@@ -49,21 +49,42 @@ def __init__(
         aspp_separable: bool,
         aspp_dropout: float,
     ):
-        super().__init__(
-            ASPP(
-                in_channels,
-                out_channels,
-                atrous_rates,
-                separable=aspp_separable,
-                dropout=aspp_dropout,
-            ),
-            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
-            nn.BatchNorm2d(out_channels),
-            nn.ReLU(),
+        super().__init__()
+        self.aspp = ASPP(
+            in_channels,
+            out_channels,
+            atrous_rates,
+            separable=aspp_separable,
+            dropout=aspp_dropout,
         )
-
-    def forward(self, *features):
-        return super().forward(features[-1])
+        self.conv = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
+        self.bn = nn.BatchNorm2d(out_channels)
+        self.relu = nn.ReLU()
+
+    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
+        x = features[-1]
+        x = self.aspp(x)
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.relu(x)
+        return x
+
+    def load_state_dict(self, state_dict, *args, **kwargs):
+        # For backward compatibility, previously this module was Sequential
+        # and was not scriptable.
+        keys = list(state_dict.keys())
+        for key in keys:
+            new_key = key
+            if key.startswith("0."):
+                new_key = "aspp." + key[2:]
+            elif key.startswith("1."):
+                new_key = "conv." + key[2:]
+            elif key.startswith("2."):
+                new_key = "bn." + key[2:]
+            elif key.startswith("3."):
+                new_key = "relu." + key[2:]
+            state_dict[new_key] = state_dict.pop(key)
+        super().load_state_dict(state_dict, *args, **kwargs)
 
 
 class DeepLabV3PlusDecoder(nn.Module):
@@ -124,7 +145,7 @@ def __init__(
             nn.ReLU(),
         )
 
-    def forward(self, *features):
+    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
         aspp_features = self.aspp(features[-1])
         aspp_features = self.up(aspp_features)
         high_res_features = self.block1(features[2])
@@ -174,7 +195,7 @@ def __init__(self, in_channels: int, out_channels: int):
             nn.ReLU(),
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         size = x.shape[-2:]
         for mod in self:
             x = mod(x)
@@ -216,7 +237,7 @@ def __init__(
             nn.Dropout(dropout),
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         res = []
         for conv in self.convs:
             res.append(conv(x))
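The load_state_dict override exists because checkpoints saved while DeepLabV3Decoder was still an nn.Sequential use numeric child names ("0.", "1.", ...). A standalone sketch of the remapping, with hypothetical checkpoint keys:

import torch

old = {
    "0.convs.0.0.weight": torch.zeros(1),  # Sequential child 0 -> self.aspp
    "1.weight": torch.zeros(1),            # Sequential child 1 -> self.conv
    "2.weight": torch.zeros(1),            # Sequential child 2 -> self.bn
}
prefixes = {"0.": "aspp.", "1.": "conv.", "2.": "bn.", "3.": "relu."}
new = {}
for key, value in old.items():
    for old_prefix, new_prefix in prefixes.items():
        if key.startswith(old_prefix):
            key = new_prefix + key[len(old_prefix):]
            break
    new[key] = value
print(sorted(new))  # ['aspp.convs.0.0.weight', 'bn.weight', 'conv.weight']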

segmentation_models_pytorch/decoders/fpn/decoder.py (+27 −21)
@@ -2,9 +2,11 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from typing import List, Literal
+
 
 class Conv3x3GNReLU(nn.Module):
-    def __init__(self, in_channels, out_channels, upsample=False):
+    def __init__(self, in_channels: int, out_channels: int, upsample: bool = False):
         super().__init__()
         self.upsample = upsample
         self.block = nn.Sequential(
@@ -15,27 +17,27 @@ def __init__(self, in_channels, out_channels, upsample=False):
             nn.ReLU(inplace=True),
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.block(x)
         if self.upsample:
-            x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
+            x = F.interpolate(x, scale_factor=2.0, mode="bilinear", align_corners=True)
         return x
 
 
 class FPNBlock(nn.Module):
-    def __init__(self, pyramid_channels, skip_channels):
+    def __init__(self, pyramid_channels: int, skip_channels: int):
         super().__init__()
         self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1)
 
-    def forward(self, x, skip=None):
-        x = F.interpolate(x, scale_factor=2, mode="nearest")
+    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
+        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
         skip = self.skip_conv(skip)
         x = x + skip
         return x
 
 
 class SegmentationBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, n_upsamples=0):
+    def __init__(self, in_channels: int, out_channels: int, n_upsamples: int = 0):
         super().__init__()
 
         blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))]
@@ -51,36 +53,37 @@ def forward(self, x):
 
 
 class MergeBlock(nn.Module):
-    def __init__(self, policy):
+    def __init__(self, policy: Literal["add", "cat"]):
         super().__init__()
         if policy not in ["add", "cat"]:
             raise ValueError(
                 "`merge_policy` must be one of: ['add', 'cat'], got {}".format(policy)
             )
         self.policy = policy
 
-    def forward(self, x):
+    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
         if self.policy == "add":
-            return sum(x)
+            output = torch.stack(x).sum(dim=0)
         elif self.policy == "cat":
-            return torch.cat(x, dim=1)
+            output = torch.cat(x, dim=1)
         else:
             raise ValueError(
                 "`merge_policy` must be one of: ['add', 'cat'], got {}".format(
                     self.policy
                 )
             )
+        return output
 
 
 class FPNDecoder(nn.Module):
     def __init__(
         self,
-        encoder_channels,
-        encoder_depth=5,
-        pyramid_channels=256,
-        segmentation_channels=128,
-        dropout=0.2,
-        merge_policy="add",
+        encoder_channels: List[int],
+        encoder_depth: int = 5,
+        pyramid_channels: int = 256,
+        segmentation_channels: int = 128,
+        dropout: float = 0.2,
+        merge_policy: Literal["add", "cat"] = "add",
     ):
         super().__init__()
 
@@ -116,17 +119,20 @@ def __init__(
         self.merge = MergeBlock(merge_policy)
         self.dropout = nn.Dropout2d(p=dropout, inplace=True)
 
-    def forward(self, *features):
+    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
         c2, c3, c4, c5 = features[-4:]
 
         p5 = self.p5(c5)
         p4 = self.p4(p5, c4)
         p3 = self.p3(p4, c3)
         p2 = self.p2(p3, c2)
 
-        feature_pyramid = [
-            seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])
-        ]
+        s5 = self.seg_blocks[0](p5)
+        s4 = self.seg_blocks[1](p4)
+        s3 = self.seg_blocks[2](p3)
+        s2 = self.seg_blocks[3](p2)
+
+        feature_pyramid = [s5, s4, s3, s2]
         x = self.merge(feature_pyramid)
         x = self.dropout(x)
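Presumably the MergeBlock rewrite is needed because the Python builtin sum() over a List[Tensor] does not type-check under TorchScript, while torch.stack(...).sum(dim=0) compiles; likewise the zip() comprehension over the nn.ModuleList is unrolled into constant indexing, which the compiler supports. A scripted sketch of the merge (assumed minimal repro):

import torch
from typing import List

@torch.jit.script
def merge_add(tensors: List[torch.Tensor]) -> torch.Tensor:
    # Elementwise sum of same-shaped feature maps, expressed scriptably.
    return torch.stack(tensors).sum(dim=0)

pyramid = [torch.ones(1, 8, 4, 4) for _ in range(4)]
print(merge_add(pyramid).shape)  # torch.Size([1, 8, 4, 4])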

segmentation_models_pytorch/decoders/linknet/decoder.py (+13 −5)
@@ -1,10 +1,12 @@
+import torch
 import torch.nn as nn
 
+from typing import List, Optional
 from segmentation_models_pytorch.base import modules
 
 
 class TransposeX2(nn.Sequential):
-    def __init__(self, in_channels, out_channels, use_batchnorm=True):
+    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
         super().__init__()
         layers = [
             nn.ConvTranspose2d(
@@ -20,7 +22,7 @@ def __init__(self, in_channels, out_channels, use_batchnorm=True):
 
 
 class DecoderBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, use_batchnorm=True):
+    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
         super().__init__()
 
         self.block = nn.Sequential(
@@ -41,7 +43,9 @@ def __init__(self, in_channels, out_channels, use_batchnorm=True):
             ),
         )
 
-    def forward(self, x, skip=None):
+    def forward(
+        self, x: torch.Tensor, skip: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         x = self.block(x)
         if skip is not None:
             x = x + skip
@@ -50,7 +54,11 @@ def forward(self, x, skip=None):
 
 class LinknetDecoder(nn.Module):
     def __init__(
-        self, encoder_channels, prefinal_channels=32, n_blocks=5, use_batchnorm=True
+        self,
+        encoder_channels: List[int],
+        prefinal_channels: int = 32,
+        n_blocks: int = 5,
+        use_batchnorm: bool = True,
     ):
         super().__init__()
 
@@ -68,7 +76,7 @@ def __init__(
             ]
         )
 
-    def forward(self, *features):
+    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
         features = features[1:]  # remove first skip
         features = features[::-1]  # reverse channels to start from head of encoder
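The new DecoderBlock signature matters for scripting: TorchScript assumes unannotated parameters are Tensor, so a skip argument that defaults to None must be typed Optional[torch.Tensor], and the `is not None` test narrows it back to Tensor. A minimal repro (hypothetical module):

import torch
from typing import Optional

class AddSkip(torch.nn.Module):
    def forward(
        self, x: torch.Tensor, skip: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        if skip is not None:
            x = x + skip  # `skip` is refined to Tensor inside this branch
        return x

m = torch.jit.script(AddSkip())
print(m(torch.ones(2)), m(torch.ones(2), torch.ones(2)))  # tensor([1., 1.]) tensor([2., 2.])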
