Skip to content

Commit 8c91c09

Browse files
committed
add SegFormer
1 parent 8380b15 commit 8c91c09

File tree

4 files changed

+167
-0
lines changed

4 files changed

+167
-0
lines changed

segmentation_models_pytorch/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus
1616
from .decoders.pan import PAN
1717
from .decoders.upernet import UPerNet
18+
from .decoders.segformer import Segformer
1819
from .base.hub_mixin import from_pretrained
1920

2021
from .__version__ import __version__
@@ -50,6 +51,7 @@ def create_model(
5051
DeepLabV3Plus,
5152
PAN,
5253
UPerNet,
54+
Segformer,
5355
]
5456
archs_dict = {a.__name__.lower(): a for a in archs}
5557
try:
@@ -85,6 +87,7 @@ def create_model(
8587
"DeepLabV3Plus",
8688
"PAN",
8789
"UPerNet",
90+
"Segformer",
8891
"from_pretrained",
8992
"create_model",
9093
"__version__",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .model import Segformer
2+
3+
__all__ = ["Segformer"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
from segmentation_models_pytorch.base import modules as md
6+
7+
8+
class MLP(nn.Module):
    """Per-pixel linear projection of a feature map's channel dimension.

    Args:
        skip_channels: Number of channels of the incoming feature map.
        segmentation_channels: Number of channels produced by the projection.
    """

    def __init__(self, skip_channels, segmentation_channels):
        super().__init__()

        self.linear = nn.Linear(skip_channels, segmentation_channels)

    def forward(self, x: torch.Tensor):
        """Project a 4D ``(N, C, H, W)`` tensor to ``(N, C_out, H, W)``.

        The spatial axes are flattened into a token axis so that the linear
        layer acts on the channel axis, then the spatial layout is restored.
        """
        batch, _, height, width = x.shape

        # (N, C, H, W) -> (N, H*W, C): channels last for nn.Linear
        tokens = torch.transpose(torch.flatten(x, 2), 1, 2)
        projected = self.linear(tokens)

        # (N, H*W, C_out) -> (N, C_out, H, W)
        restored = torch.transpose(projected, 1, 2).reshape(batch, -1, height, width)
        return restored.contiguous()
20+
21+
22+
class SegformerDecoder(nn.Module):
    """Decoder that projects each encoder stage with an MLP and fuses them.

    Every encoder feature map (except the raw input) is projected to
    ``segmentation_channels`` channels, resized to 1/4 of the input's spatial
    size, concatenated along the channel axis and fused by a 1x1
    convolution block.

    Args:
        encoder_channels: Channel counts of the encoder outputs, input first.
            If the entry at index 1 is ``0`` it is dropped.
        encoder_depth: Encoder depth; only validated here (must be >= 3).
        segmentation_channels: Channels of every projected feature map and
            of the fused output.

    Raises:
        ValueError: If ``encoder_depth`` is less than 3.
    """

    def __init__(
        self,
        encoder_channels,
        encoder_depth=5,
        segmentation_channels=256,
    ):
        super().__init__()

        if encoder_depth < 3:
            raise ValueError(
                f"Encoder depth for Segformer decoder cannot be less than 3, got {encoder_depth}."
            )

        # Drop the zero-channel placeholder stage at index 1, if present.
        if encoder_channels[1] == 0:
            encoder_channels = tuple(encoder_channels[:1]) + tuple(encoder_channels[2:])

        # Deepest stage first; the trailing entry is the input itself,
        # which gets no projection.
        deep_to_shallow = encoder_channels[::-1]
        projected_channels = deep_to_shallow[:-1]

        self.mlp_stage = nn.ModuleList(
            MLP(in_channels, segmentation_channels) for in_channels in projected_channels
        )

        self.fuse_stage = md.Conv2dReLU(
            in_channels=len(projected_channels) * segmentation_channels,
            out_channels=segmentation_channels,
            kernel_size=1,
            use_batchnorm=True,
        )

    def forward(self, *features):
        """Fuse encoder features into one map at 1/4 of the input's size.

        ``features[0]`` is used only for its spatial size. ``features[1]`` is
        skipped when it has zero channels.
        """
        # Every projected feature is resized to a quarter of the input size.
        target_size = [extent // 4 for extent in features[0].shape[2:]]

        first_stage = 2 if features[1].size(1) == 0 else 1
        # Reverse so the order matches ``self.mlp_stage`` (deepest first).
        deep_to_shallow = features[first_stage:][::-1]

        resized = [
            F.interpolate(
                stage(feature), size=target_size, mode="bilinear", align_corners=False
            )
            for feature, stage in zip(deep_to_shallow, self.mlp_stage)
        ]

        return self.fuse_stage(torch.cat(resized, dim=1))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from typing import Optional, Union
2+
3+
from segmentation_models_pytorch.encoders import get_encoder
4+
from segmentation_models_pytorch.base import (
5+
SegmentationModel,
6+
SegmentationHead,
7+
ClassificationHead,
8+
)
9+
from .decoder import SegformerDecoder
10+
11+
12+
class Segformer(SegmentationModel):
    """Segformer is a simple and efficient design for semantic segmentation with Transformers.

    Args:
        encoder_name: Name of the classification model that will be used as an encoder (a.k.a. backbone)
            to extract features of different spatial resolution
        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generates features
            two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
            Default is 5
        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
            other pretrained weights (see table with available weights for each encoder_name)
        decoder_segmentation_channels: A number of convolution filters in segmentation blocks, default is 256
        in_channels: A number of input channels for the model, default is 3 (RGB images)
        classes: A number of classes for output mask (or you can think of it as a number of channels of output mask)
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
            **callable** and **None**.
            Default is **None**
        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
            on top of encoder if **aux_params** is not **None** (default). Supported params:
                - classes (int): A number of classes
                - pooling (str): One of "max", "avg". Default is "avg"
                - dropout (float): Dropout factor in [0, 1)
                - activation (str): An activation function to apply "sigmoid"/"softmax"
                    (could be **None** to return logits)

    Returns:
        ``torch.nn.Module``: **Segformer**

    .. _Segformer:
        https://arxiv.org/abs/2105.15203

    """

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_segmentation_channels: int = 256,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, callable]] = None,
        aux_params: Optional[dict] = None,
    ):
        super().__init__()

        # Sub-modules are created in a fixed order: encoder, decoder, heads.
        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )
        encoder_out_channels = self.encoder.out_channels

        self.decoder = SegformerDecoder(
            encoder_channels=encoder_out_channels,
            encoder_depth=encoder_depth,
            segmentation_channels=decoder_segmentation_channels,
        )

        # upsampling=4 pairs with the decoder's output at 1/4 of the input size.
        self.segmentation_head = SegmentationHead(
            in_channels=decoder_segmentation_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=3,
            upsampling=4,
        )

        # The optional classification head consumes the deepest encoder stage.
        self.classification_head = (
            None
            if aux_params is None
            else ClassificationHead(in_channels=encoder_out_channels[-1], **aux_params)
        )

        self.name = f"segformer-{encoder_name}"
        self.initialize()

0 commit comments

Comments
 (0)