[feat] Adding SegFormer #944


Merged · 10 commits · Nov 29, 2024

3 changes: 2 additions & 1 deletion README.md
@@ -19,7 +19,7 @@ Segmentation based on [PyTorch](https://pytorch.org/).**
The main features of this library are:

- High-level API (just two lines to create a neural network)
- 10 model architectures for binary and multi-class segmentation (including legendary Unet)
- 11 model architectures for binary and multi-class segmentation (including legendary Unet)
- 124 available encoders (and 500+ encoders from [timm](https://github.com/rwightman/pytorch-image-models))
- All encoders have pre-trained weights for faster and better convergence
- Popular metrics and losses for training routines
@@ -95,6 +95,7 @@ Congratulations! You are done! Now you can train your model with your favorite framework!
- DeepLabV3 [[paper](https://arxiv.org/abs/1706.05587)] [[docs](https://smp.readthedocs.io/en/latest/models.html#deeplabv3)]
- DeepLabV3+ [[paper](https://arxiv.org/abs/1802.02611)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id9)]
- UPerNet [[paper](https://arxiv.org/abs/1807.10221)] [[docs](https://smp.readthedocs.io/en/latest/models.html#upernet)]
- Segformer [[paper](https://arxiv.org/abs/2105.15203)] [[docs](https://smp.readthedocs.io/en/latest/models.html#segformer)]

#### Encoders <a name="encoders"></a>

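To make the README's "just two lines" claim concrete for the new entry, here is a minimal sketch; the `mit_b0` encoder and its `imagenet` weights are taken from the library's existing encoder tables, not from this diff:

```python
import torch
import segmentation_models_pytorch as smp

# Two-line usage of the newly added architecture; encoder name and weights
# are assumed from the library's encoder tables rather than from this PR.
model = smp.Segformer(encoder_name="mit_b0", encoder_weights="imagenet", classes=3)
mask = model(torch.rand(1, 3, 256, 256))  # -> (1, 3, 256, 256) logits
```
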
8 changes: 8 additions & 0 deletions docs/models.rst
@@ -73,3 +73,11 @@ PAN
UPerNet
~~~~~~~
.. autoclass:: segmentation_models_pytorch.UPerNet


.. _segformer:

Segformer
~~~~~~~~~
.. autoclass:: segmentation_models_pytorch.Segformer

131 changes: 131 additions & 0 deletions examples/segformer_inference_pretrained.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions segmentation_models_pytorch/__init__.py
@@ -15,6 +15,7 @@
from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus
from .decoders.pan import PAN
from .decoders.upernet import UPerNet
from .decoders.segformer import Segformer
from .base.hub_mixin import from_pretrained

from .__version__ import __version__
@@ -50,6 +51,7 @@ def create_model(
        DeepLabV3Plus,
        PAN,
        UPerNet,
        Segformer,
    ]
    archs_dict = {a.__name__.lower(): a for a in archs}
    try:
@@ -85,6 +87,7 @@ def create_model(
"DeepLabV3Plus",
"PAN",
"UPerNet",
"Segformer",
"from_pretrained",
"create_model",
"__version__",
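Since `create_model` resolves architectures via `archs_dict = {a.__name__.lower(): a for a in archs}`, registering `Segformer` above also enables the string-based factory route. A small sketch, assuming the factory's usual keyword arguments:

```python
import segmentation_models_pytorch as smp

# "segformer" is matched against the lower-cased class name registered above.
model = smp.create_model(
    arch="segformer",
    encoder_name="resnet34",  # any supported encoder should work here
    encoder_weights=None,     # skip the weight download for a quick smoke test
    in_channels=3,
    classes=2,
)
```
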
3 changes: 3 additions & 0 deletions segmentation_models_pytorch/decoders/segformer/__init__.py
@@ -0,0 +1,3 @@
from .model import Segformer

__all__ = ["Segformer"]
@@ -0,0 +1,148 @@
import torch
import argparse
import requests
import numpy as np
import huggingface_hub
import albumentations as A
import matplotlib.pyplot as plt

from PIL import Image
import segmentation_models_pytorch as smp


def convert_state_dict_to_smp(state_dict: dict):
    # fmt: off

    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    new_state_dict = {}

    # Map the backbone components to the encoder
    keys = list(state_dict.keys())
    for key in keys:
        if key.startswith("backbone"):
            new_key = key.replace("backbone", "encoder")
            new_state_dict[new_key] = state_dict.pop(key)

    # Map the linear_cX layers to MLP stages
    for i in range(4):
        base = f"decode_head.linear_c{i+1}.proj"
        new_state_dict[f"decoder.mlp_stage.{3-i}.linear.weight"] = state_dict.pop(f"{base}.weight")
        new_state_dict[f"decoder.mlp_stage.{3-i}.linear.bias"] = state_dict.pop(f"{base}.bias")

    # Map fuse_stage components
    fuse_base = "decode_head.linear_fuse"
    fuse_weights = {
        "decoder.fuse_stage.0.weight": state_dict.pop(f"{fuse_base}.conv.weight"),
        "decoder.fuse_stage.1.weight": state_dict.pop(f"{fuse_base}.bn.weight"),
        "decoder.fuse_stage.1.bias": state_dict.pop(f"{fuse_base}.bn.bias"),
        "decoder.fuse_stage.1.running_mean": state_dict.pop(f"{fuse_base}.bn.running_mean"),
        "decoder.fuse_stage.1.running_var": state_dict.pop(f"{fuse_base}.bn.running_var"),
        "decoder.fuse_stage.1.num_batches_tracked": state_dict.pop(f"{fuse_base}.bn.num_batches_tracked"),
    }
    new_state_dict.update(fuse_weights)

    # Map final layer components
    new_state_dict["segmentation_head.0.weight"] = state_dict.pop("decode_head.linear_pred.weight")
    new_state_dict["segmentation_head.0.bias"] = state_dict.pop("decode_head.linear_pred.bias")

    # conv_seg weights have no SMP counterpart; drop them so the check below passes
    del state_dict["decode_head.conv_seg.weight"]
    del state_dict["decode_head.conv_seg.bias"]

    assert len(state_dict) == 0, f"Unmapped keys: {state_dict.keys()}"

    # fmt: on
    return new_state_dict


def get_np_image():
    url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    return np.array(image)


def main(args):
    original_checkpoint = torch.load(args.path, map_location="cpu", weights_only=True)
    smp_state_dict = convert_state_dict_to_smp(original_checkpoint)

    config = original_checkpoint["meta"]["config"]
    num_classes = int(config.split("num_classes=")[1].split(",\n")[0])
    decoder_dims = int(config.split("embed_dim=")[1].split(",\n")[0])
    height, width = [
        int(x) for x in config.split("crop_size=(")[1].split("), ")[0].split(",")
    ]
    model_size = args.path.split("segformer.")[1][:2]

    # Create the model
    model = smp.create_model(
        in_channels=3,
        classes=num_classes,
        arch="segformer",
        encoder_name=f"mit_{model_size}",
        encoder_weights=None,
        decoder_segmentation_channels=decoder_dims,
    ).eval()

    # Load the converted state dict
    model.load_state_dict(smp_state_dict, strict=True)

    # Preprocessing params
    preprocessing = A.Compose(
        [
            A.Resize(height, width, p=1),
            A.Normalize(
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                max_pixel_value=1.0,
                p=1,
            ),
        ]
    )

    # Prepare the input
    image = get_np_image()
    normalized_image = preprocessing(image=image)["image"]
    tensor = torch.tensor(normalized_image).permute(2, 0, 1).unsqueeze(0).float()

    # Forward pass
    with torch.no_grad():
        mask = model(tensor)

    # Postprocessing
    mask = torch.nn.functional.interpolate(
        mask, size=(image.shape[0], image.shape[1]), mode="bilinear"
    )
    mask = torch.argmax(mask, dim=1)
    mask = mask.squeeze().cpu().numpy()

    model_name = args.path.split("/")[-1].replace(".pth", "").replace(".", "-")

    model.save_pretrained(model_name)
    preprocessing.save_pretrained(model_name)

    # fmt: off
    plt.subplot(121), plt.axis('off'), plt.imshow(image), plt.title('Input Image')
    plt.subplot(122), plt.axis('off'), plt.imshow(mask), plt.title('Output Mask')
    plt.savefig(f"{model_name}/example_mask.png")
    # fmt: on

    if args.push_to_hub:
        repo_id = f"smp-hub/{model_name}"
        api = huggingface_hub.HfApi()
        api.create_repo(repo_id=repo_id, repo_type="model")
        api.upload_folder(folder_path=model_name, repo_id=repo_id)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--path",
        type=str,
        default="weights/trained_models/segformer.b2.512x512.ade.160k.pth",
    )
    parser.add_argument("--push_to_hub", action="store_true")
    args = parser.parse_args()

    main(args)
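After a converted checkpoint has been pushed by this script, restoring it should only need the `from_pretrained` hub mixin imported in `__init__.py` above. A sketch; the repo id below merely mirrors the script's `model_name`/`repo_id` construction for the default `--path` and is illustrative rather than guaranteed to exist:

```python
import segmentation_models_pytorch as smp

# Repo id mirrors the script's naming for its default checkpoint:
# "segformer.b2.512x512.ade.160k.pth" -> "smp-hub/segformer-b2-512x512-ade-160k".
model = smp.from_pretrained("smp-hub/segformer-b2-512x512-ade-160k").eval()
```
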
72 changes: 72 additions & 0 deletions segmentation_models_pytorch/decoders/segformer/decoder.py
@@ -0,0 +1,72 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from segmentation_models_pytorch.base import modules as md


class MLP(nn.Module):
    def __init__(self, skip_channels, segmentation_channels):
        super().__init__()

        self.linear = nn.Linear(skip_channels, segmentation_channels)

    def forward(self, x: torch.Tensor):
        batch, _, height, width = x.shape
        x = x.flatten(2).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
        x = self.linear(x)
        x = x.transpose(1, 2).reshape(batch, -1, height, width).contiguous()
        return x


class SegformerDecoder(nn.Module):
    def __init__(
        self,
        encoder_channels,
        encoder_depth=5,
        segmentation_channels=256,
    ):
        super().__init__()

        if encoder_depth < 3:
            raise ValueError(
                "Encoder depth for Segformer decoder cannot be less than 3, got {}.".format(
                    encoder_depth
                )
            )

        # Some encoders (e.g. mit) expose no stride-2 feature and report
        # 0 channels at index 1; drop that slot
        if encoder_channels[1] == 0:
            encoder_channels = tuple(
                channel for index, channel in enumerate(encoder_channels) if index != 1
            )
        encoder_channels = encoder_channels[::-1]

        self.mlp_stage = nn.ModuleList(
            [MLP(channel, segmentation_channels) for channel in encoder_channels[:-1]]
        )

        self.fuse_stage = md.Conv2dReLU(
            in_channels=(len(encoder_channels) - 1) * segmentation_channels,
            out_channels=segmentation_channels,
            kernel_size=1,
            use_batchnorm=True,
        )

    def forward(self, *features):
        # Resize all features to 1/4 of the input resolution
        # (the size of the largest encoder feature)
        target_size = [dim // 4 for dim in features[0].shape[2:]]

        # Drop the input image and, if present, the zero-channel dummy feature
        features = features[2:] if features[1].size(1) == 0 else features[1:]
        features = features[::-1]  # reverse channels to start from head of encoder

        resized_features = []
        for feature, stage in zip(features, self.mlp_stage):
            feature = stage(feature)
            resized_feature = F.interpolate(
                feature, size=target_size, mode="bilinear", align_corners=False
            )
            resized_features.append(resized_feature)

        output = self.fuse_stage(torch.cat(resized_features, dim=1))

        return output
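A quick shape check of the decoder contract above, feeding it feature maps shaped as an `smp` encoder with `depth=5` would emit them; the channel tuple below matches a `resnet50` encoder and is assumed for illustration:

```python
import torch
from segmentation_models_pytorch.decoders.segformer.decoder import SegformerDecoder

# Six feature maps: the identity (input image) plus five strided stages,
# with resnet50-like channel counts (assumed for illustration).
channels = (3, 64, 256, 512, 1024, 2048)
strides = (1, 2, 4, 8, 16, 32)
features = [torch.rand(1, c, 64 // s, 64 // s) for c, s in zip(channels, strides)]

decoder = SegformerDecoder(encoder_channels=channels, encoder_depth=5)
out = decoder(*features)
print(out.shape)  # torch.Size([1, 256, 16, 16]) -> input resolution / 4
```
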
93 changes: 93 additions & 0 deletions segmentation_models_pytorch/decoders/segformer/model.py
@@ -0,0 +1,93 @@
from typing import Any, Optional, Union, Callable

from segmentation_models_pytorch.base import (
    ClassificationHead,
    SegmentationHead,
    SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder

from .decoder import SegformerDecoder


class Segformer(SegmentationModel):
    """Segformer is a simple and efficient design for semantic segmentation with Transformers

    Args:
        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
            to extract features of different spatial resolution
        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generates features
            two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
            Default is 5
        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
            other pretrained weights (see table with available weights for each encoder_name)
        decoder_segmentation_channels: A number of convolution filters in segmentation blocks, default is 256
        in_channels: A number of input channels for the model, default is 3 (RGB images)
        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
            **callable** and **None**.
            Default is **None**
        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is
            built on top of the encoder if **aux_params** is not **None** (default). Supported params:
                - classes (int): A number of classes
                - pooling (str): One of "max", "avg". Default is "avg"
                - dropout (float): Dropout factor in [0, 1)
                - activation (str): An activation function to apply "sigmoid"/"softmax"
                    (could be **None** to return logits)
        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.

    Returns:
        ``torch.nn.Module``: **Segformer**

    .. _Segformer:
        https://arxiv.org/abs/2105.15203

    """

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_segmentation_channels: int = 256,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, Callable]] = None,
        aux_params: Optional[dict] = None,
        **kwargs: dict[str, Any],
    ):
        super().__init__()

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
            **kwargs,
        )

        self.decoder = SegformerDecoder(
            encoder_channels=self.encoder.out_channels,
            encoder_depth=encoder_depth,
            segmentation_channels=decoder_segmentation_channels,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=decoder_segmentation_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=1,
            upsampling=4,
        )

        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None

        self.name = "segformer-{}".format(encoder_name)
        self.initialize()
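Finally, an end-to-end smoke test of the pieces above: the decoder works at 1/4 of the input resolution and the head's `upsampling=4` restores it, so the logits should match the input spatially. A sketch, assuming the `mit_b0` encoder from the library's tables:

```python
import torch
import segmentation_models_pytorch as smp

# Decoder output is H/4 x W/4; SegmentationHead(upsampling=4) brings the
# logits back to the input resolution.
model = smp.Segformer(encoder_name="mit_b0", encoder_weights=None, classes=5).eval()
with torch.no_grad():
    logits = model(torch.rand(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 5, 224, 224])
```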