
[feat] Adding SegFormer #944


Merged
10 commits merged into qubvel-org:main on Nov 29, 2024

Conversation

brianhou0208
Contributor

@brianhou0208 brianhou0208 commented Oct 12, 2024

Hi @qubvel ,
I have completed SegFormer, and it passes tests/test_models.py locally. Please take a look.

Issue

resolve #941

Result

I’ve trained the models using binary_segmentation_intro.ipynb, and the results are shown below (a model-construction sketch follows the tables).

  • Valid Dataset

| Arch | Backbone | IoU (Per Image / Dataset) |
| --- | --- | --- |
| FPN | ResNet34 | 90.56 / 91.39 |
| DeepLabV3Plus | ResNet34 | 90.26 / 91.07 |
| UNet | ResNet34 | 90.50 / 91.31 |
| SegFormer | ResNet34 | 90.77 / 91.58 |

  • Test Dataset

| Arch | Backbone | IoU (Per Image / Dataset) |
| --- | --- | --- |
| FPN | ResNet34 | 90.97 / 91.64 |
| DeepLabV3Plus | ResNet34 | 90.48 / 91.22 |
| UNet | ResNet34 | 91.15 / 91.76 |
| SegFormer | ResNet34 | 91.02 / 91.67 |
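
For reference, a minimal sketch of how the SegFormer model from this PR can be constructed for such an experiment (hyperparameters are assumed, not copied from the notebook):

import segmentation_models_pytorch as smp

# Sketch only: encoder and hyperparameters assumed to mirror the tables above.
model = smp.Segformer(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,  # binary segmentation
    decoder_segmentation_channels=256,
)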

Reference

@brianhou0208 brianhou0208 marked this pull request as ready for review October 12, 2024 18:40
@brianhou0208 brianhou0208 mentioned this pull request Oct 12, 2024
@JulienMaille
Contributor

Cool! May I ask how inference speed compares to UNet?

@brianhou0208
Contributor Author

Hi @JulienMaille ,

Here are the results for model inference speed and FLOPs/MACs/Params.

Result

  • Model Backbone: ResNet34
  • Test Image Resolution (B, C, H, W): (1, 3, 128, 128)
  • Throughput Batch Size: 100

| Model | FLOPs (G) | MACs (G) | Params (M) | Throughput (images/s) |
| --- | --- | --- | --- | --- |
| PAN | 3.7249 | 1.8601 | 21.4758 | 1278 |
| FPN | 3.4334 | 1.7138 | 23.1554 | 1027 |
| LinkNet | 2.5231 | 1.2575 | 21.7719 | 1313 |
| PSPNet | 1.1740 | 0.5851 | 21.438 | 3017 |
| MANet | 4.1655 | 2.0784 | 31.7836 | 921 |
| DeepLabV3 | 13.6597 | 6.8251 | 26.0071 | 314 |
| DeepLabV3+ | 3.9478 | 1.9711 | 22.4375 | 1004 |
| UNet | 3.9142 | 1.9535 | 24.4364 | 1021 |
| UNet++ | 9.2050 | 4.5959 | 26.0786 | 446 |
| UPerNet | 4.1186 | 2.0551 | 22.6824 | 522 |
| SegFormer | 3.2657 | 1.6305 | 21.8765 | 727 |

FLOPs/MACs/Params were calculated using calflops; a minimal usage sketch follows.
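
A rough sketch of how one of the rows above could be reproduced with calflops (calflops 0.3.x API assumed; the input shape matches the resolution listed above):

import segmentation_models_pytorch as smp
from calflops import calculate_flops

# Compute FLOPs/MACs/Params for one architecture at 1x3x128x128.
model = smp.Segformer(encoder_name="resnet34", encoder_weights=None)
flops, macs, params = calculate_flops(model=model, input_shape=(1, 3, 128, 128))
print(f"FLOPs: {flops}, MACs: {macs}, Params: {params}")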

Code

import torch
import segmentation_models_pytorch as smp
import time

T0 = 5
T1 = 10
def get_throughput(model, batch_size=100, resolution=128):
    model.eval()
    model.to('cuda')
    inputs = torch.randn(batch_size, 3, resolution, resolution, device='cuda')
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    start = time.time()
    while time.time() - start < T0:
        model(inputs)
    timing = []
    torch.cuda.synchronize()
    while sum(timing) < T1:
        start = time.time()
        model(inputs)
        torch.cuda.synchronize()
        timing.append(time.time() - start)
    timing = torch.as_tensor(timing, dtype=torch.float32)
    print(batch_size / timing.mean().item(), 'images/s @ batch size', batch_size)

model = smp.Unet(encoder_weights=None)
get_throughput(model)

Throughput measurement code adapted from RepViT.

Environment

  • Platform: Google Colab
  • Device: Tesla T4
  • Version:
    • torch: 2.4.1+cu121
    • timm: 1.0.9
    • calflops: 0.3.2

@qubvel
Collaborator

qubvel commented Oct 19, 2024

Hi @brianhou0208, thanks a lot for working on this model! Were you able to load the original pretrained weights for the whole model and check the inference?

@brianhou0208
Contributor Author

Hi @qubvel, I just tried it. It seems okay to load the original pretrained weights for the entire model.

Here are a few key points to note when loading the original pretrained weights:

  • encoder & decoder can be fully mapped without issues, only need to remap the names.
  • classes in the segmentation head must match that of the original pretrained model.
  • decoder_segmentation_channels are 768 for the B5 architecture, while they are 256 for other architectures.

Due to hardware limitations, I did not test whether the model outputs are exactly the same; I only checked if the weights were mapped correctly.

Loading the original pretrained weights

import torch
import segmentation_models_pytorch as smp

original_checkpoint = torch.load(
    "./segformer.b0.512x1024.city.160k.pth", map_location="cpu", weights_only=False
)

num_classes = int(
    original_checkpoint["meta"]["config"].split("num_classes=")[1].split(",\n")[0]
)
decoder_dims = int(original_checkpoint["meta"]["config"].split("embed_dim=")[1][:3])
smp_state_dict = map_weights(original_checkpoint)
print(
    f"Pretrain Weight setting: num_classes={num_classes}, decoder_embed_dim={decoder_dims}"
)

model = smp.create_model(
    in_channels=3,
    classes=num_classes,
    arch="segformer",
    encoder_name="mit_b0",
    encoder_weights=None,
    decoder_segmentation_channels=decoder_dims,
).eval()
model.load_state_dict(smp_state_dict, strict=False)

x = torch.rand(1, 3, 512, 1024)
with torch.no_grad():
    y = model(x)

print(y.shape)
>> torch.Size([1, 19, 512, 1024])

Mapping weights

def map_weights(state_dict: dict):
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    mapped_state_dict = {
        # Map backbone to encoder
        k.replace("backbone", "encoder"): v
        for k, v in state_dict.items()
        if k.startswith("backbone")
    }

    # Map the linear_cX layers to MLP stages
    for i in range(4):
        base = f"decode_head.linear_c{i+1}.proj"
        mapped_state_dict[f"decoder.mlp_stage.{3-i}.linear.weight"] = state_dict[
            f"{base}.weight"
        ]
        mapped_state_dict[f"decoder.mlp_stage.{3-i}.linear.bias"] = state_dict[
            f"{base}.bias"
        ]

    # Map fuse_stage components
    fuse_base = "decode_head.linear_fuse"
    mapped_state_dict.update(
        {
            "decoder.fuse_stage.0.weight": state_dict[f"{fuse_base}.conv.weight"],
            "decoder.fuse_stage.1.weight": state_dict[f"{fuse_base}.bn.weight"],
            "decoder.fuse_stage.1.bias": state_dict[f"{fuse_base}.bn.bias"],
            "decoder.fuse_stage.1.running_mean": state_dict[
                f"{fuse_base}.bn.running_mean"
            ],
            "decoder.fuse_stage.1.running_var": state_dict[
                f"{fuse_base}.bn.running_var"
            ],
            "decoder.fuse_stage.1.num_batches_tracked": state_dict[
                f"{fuse_base}.bn.num_batches_tracked"
            ],
        }
    )

    # Map final layer components
    mapped_state_dict["segmentation_head.0.weight"] = state_dict[
        "decode_head.linear_pred.weight"
    ]
    mapped_state_dict["segmentation_head.0.bias"] = state_dict[
        "decode_head.linear_pred.bias"
    ]

    return mapped_state_dict

@qubvel
Collaborator

qubvel commented Oct 21, 2024

Thanks! It would be great to have it tested against the original checkpoint or the checkpoint on Hugging Face, or at least to run some inference and make sure it outputs meaningful masks 😄

@brianhou0208
Contributor Author

Here is example code comparing the Hugging Face Transformers and SMP outputs.

Example Code

import torch
import requests
from PIL import Image
import segmentation_models_pytorch as smp
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

def segformer_smp(inputs, path="./segformer.b0.512x1024.city.160k.pth"):
    original_checkpoint = torch.load(
        path, map_location="cpu", weights_only=False
    )

    num_classes = int(
        original_checkpoint["meta"]["config"].split("num_classes=")[1].split(",\n")[0]
    )
    decoder_dims = int(original_checkpoint["meta"]["config"].split("embed_dim=")[1][:3])
    smp_state_dict = map_weights(original_checkpoint)
    print(
        f"Pretrain Weight setting: num_classes={num_classes}, decoder_embed_dim={decoder_dims}"
    )

    model = smp.create_model(
        in_channels=3,
        classes=num_classes,
        arch="segformer",
        encoder_name="mit_b0",
        encoder_weights=None,
        decoder_segmentation_channels=decoder_dims,
    ).eval()
    model.load_state_dict(smp_state_dict, strict=False)
    with torch.no_grad():
        output = model(inputs['pixel_values'])
    output = torch.softmax(output, dim=1)
    output = torch.argmax(output, dim=1)
    return output
    
def segformer_tf(inputs, path):
    model = SegformerForSemanticSegmentation.from_pretrained(path)
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
    logits = torch.nn.functional.interpolate(
        logits, scale_factor=4, mode='bilinear', align_corners=True)
    logits = torch.softmax(logits, dim=1)
    logits = torch.argmax(logits, dim=1)
    return logits
    
if __name__ == "__main__":
    model_path_hugging = "nvidia/segformer-b0-finetuned-cityscapes-512-1024"
    url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
    image = Image.open(requests.get(url, stream=True).raw).resize((512, 512))
    processor = SegformerImageProcessor.from_pretrained(model_path_hugging)
    inputs = processor(images=image, return_tensors="pt")
    output_tf = segformer_tf(inputs, model_path_hugging).squeeze()
    output_smp = segformer_smp(inputs).squeeze()

    plt.subplot(131), plt.axis('off'), plt.imshow(image), plt.title('Input Image')
    plt.subplot(132), plt.axis('off'), plt.imshow(output_tf), plt.title('Output Transformer')
    plt.subplot(133), plt.axis('off'), plt.imshow(output_smp), plt.title('Output SMP')
    plt.tight_layout()
    plt.show()

Result

(Figures 1–3: input image, Transformers output mask, SMP output mask.)

@qubvel
Collaborator

qubvel commented Nov 6, 2024

Thanks for the update, and sorry for the delay on my end! The outputs look identical! I will review this week or next and hopefully merge it. Thanks for your patience!

@omarequalmars

omarequalmars commented Nov 28, 2024

It would be amazing if the map_weights() function could be integrated into from_pretrained() so that it can handle .ckpt/.pth files. PyTorch Lightning saves model checkpoints (including the last one) as .pth/.ckpt, and I can't restore them with from_pretrained(). Maybe a ckpt parameter could be added to from_pretrained() so that when ckpt=True it takes the checkpoint file, maps the weights, and reproduces the output expected from load_state_dict() (a rough sketch of such a helper follows below).

Edit: does map_weights() work for all the SMP models?
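
A minimal sketch of how a Lightning checkpoint can currently be loaded by hand (the helper name and the "model." prefix are assumptions about how the LightningModule wraps the SMP model; this is not part of the SMP API):

import torch
import segmentation_models_pytorch as smp

def load_lightning_checkpoint(model, ckpt_path, prefix="model."):
    # Lightning checkpoints store the weights under "state_dict", with keys
    # prefixed by the attribute name of the wrapped module (assumed "model.").
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    state_dict = checkpoint.get("state_dict", checkpoint)
    state_dict = {
        k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)
    }
    model.load_state_dict(state_dict)
    return model

model = smp.Unet(encoder_name="resnet34", encoder_weights=None, classes=1)
model = load_lightning_checkpoint(model, "last.ckpt")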

@qubvel qubvel self-requested a review November 29, 2024 14:26
Collaborator

@qubvel qubvel left a comment


Thanks, it looks good now! I'm just going to add the conversion script to this PR as well and then merge it. Thanks a lot for your contribution, @brianhou0208, and I'm very sorry for the long wait!

@qubvel
Collaborator

qubvel commented Nov 29, 2024

@omarequalmars indeed, might be very useful! Can you create a separate issue for this request?

@qubvel
Collaborator

qubvel commented Nov 29, 2024

Here are the converted weights that can be loaded directly into Segformer:
https://huggingface.co/collections/smp-hub/segformer-6749eb4923dea2c355f29a1f
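
For example, assuming an SMP version with Hugging Face Hub support, a converted checkpoint can be loaded like this (the checkpoint id below is illustrative; see the collection above for the exact names):

import segmentation_models_pytorch as smp

# Illustrative checkpoint id from the smp-hub collection.
model = smp.from_pretrained("smp-hub/segformer-b0-512x1024-city-160k").eval()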

@qubvel qubvel merged commit b93cf54 into qubvel-org:main Nov 29, 2024
12 checks passed
@brianhou0208 brianhou0208 deleted the segformer branch December 25, 2024 14:29
@qubvel
Collaborator

qubvel commented Jan 8, 2025

@brianhou0208 I'd like to mention you for your contribution. Do you have a LinkedIn/X account? If so, can you connect with me (see my profile page)?

@brianhou0208
Contributor Author

@qubvel here is my LinkedIn
It seems I can't send you a message directly there; maybe I don't know how to use it :)
