[feat] Adding SegFormer #944
Conversation
Cool! May I ask how inference speed compares to UNet?
Hi @JulienMaille, here are the results for model inference speed and FLOPs/MACs/params.
Result
FLOPs/MACs/params were calculated using calflops (a usage sketch follows below).
Code
```python
import torch
import segmentation_models_pytorch as smp
import time

T0 = 5   # warm-up duration in seconds
T1 = 10  # measurement duration in seconds


def get_throughput(model, batch_size=100, resolution=128):
    model.eval()
    model.to('cuda')
    inputs = torch.randn(batch_size, 3, resolution, resolution, device='cuda')
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    # warm up for T0 seconds
    start = time.time()
    while time.time() - start < T0:
        model(inputs)

    # time forward passes until at least T1 seconds have accumulated
    timing = []
    torch.cuda.synchronize()
    while sum(timing) < T1:
        start = time.time()
        model(inputs)
        torch.cuda.synchronize()
        timing.append(time.time() - start)

    timing = torch.as_tensor(timing, dtype=torch.float32)
    print(batch_size / timing.mean().item(), 'images/s @ batch size', batch_size)


model = smp.Unet(encoder_weights=None)
get_throughput(model)
```
Reference code from RepViT.
Environment
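The FLOPs/MACs/params numbers were obtained with calflops; here is a minimal sketch of such a call (the model and arguments are illustrative, not necessarily the exact command used for the table above):

```python
from calflops import calculate_flops
import segmentation_models_pytorch as smp

# illustrative model; swap in any arch/encoder to compare
model = smp.create_model(arch="segformer", encoder_name="mit_b0", encoder_weights=None, classes=1)

flops, macs, params = calculate_flops(
    model=model,
    input_shape=(1, 3, 512, 1024),  # (batch, channels, height, width)
    output_as_string=True,
    print_results=False,
)
print(f"FLOPs: {flops}  MACs: {macs}  Params: {params}")
```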
Hi @brianhou0208, thanks a lot for working on this model! Were you able to load the original pretrained weights for the whole model and check the inference?
Hi @qubvel, I just tried it. It seems okay to load the original pretrained weights for the entire model. Here are a few key points to note when loading them:
Due to hardware limitations, I did not test whether the model outputs are exactly the same; I only checked that the weights were mapped correctly.
Loading the original pretrained weights
```python
import torch
import segmentation_models_pytorch as smp

original_checkpoint = torch.load(
    "./segformer.b0.512x1024.city.160k.pth", map_location="cpu", weights_only=False
)
# read num_classes and the decoder embedding dim from the config string stored in the checkpoint
num_classes = int(
    original_checkpoint["meta"]["config"].split("num_classes=")[1].split(",\n")[0]
)
decoder_dims = int(original_checkpoint["meta"]["config"].split("embed_dim=")[1][:3])
smp_state_dict = map_weights(original_checkpoint)
print(
    f"Pretrain Weight setting: num_classes={num_classes}, decoder_embed_dim={decoder_dims}"
)

model = smp.create_model(
    in_channels=3,
    classes=num_classes,
    arch="segformer",
    encoder_name="mit_b0",
    encoder_weights=None,
    decoder_segmentation_channels=decoder_dims,
).eval()
model.load_state_dict(smp_state_dict, strict=False)

x = torch.rand(1, 3, 512, 1024)
with torch.no_grad():
    y = model(x)
print(y.shape)
# >> torch.Size([1, 19, 512, 1024])
```
Mapping weights
```python
def map_weights(state_dict: dict):
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
mapped_state_dict = {
# Map backbone to encoder
k.replace("backbone", "encoder"): v
for k, v in state_dict.items()
if k.startswith("backbone")
}
# Map the linear_cX layers to MLP stages
for i in range(4):
base = f"decode_head.linear_c{i+1}.proj"
mapped_state_dict[f"decoder.mlp_stage.{3-i}.linear.weight"] = state_dict[
f"{base}.weight"
]
mapped_state_dict[f"decoder.mlp_stage.{3-i}.linear.bias"] = state_dict[
f"{base}.bias"
]
# Map fuse_stage components
fuse_base = "decode_head.linear_fuse"
mapped_state_dict.update(
{
"decoder.fuse_stage.0.weight": state_dict[f"{fuse_base}.conv.weight"],
"decoder.fuse_stage.1.weight": state_dict[f"{fuse_base}.bn.weight"],
"decoder.fuse_stage.1.bias": state_dict[f"{fuse_base}.bn.bias"],
"decoder.fuse_stage.1.running_mean": state_dict[
f"{fuse_base}.bn.running_mean"
],
"decoder.fuse_stage.1.running_var": state_dict[
f"{fuse_base}.bn.running_var"
],
"decoder.fuse_stage.1.num_batches_tracked": state_dict[
f"{fuse_base}.bn.num_batches_tracked"
],
}
)
# Map final layer components
mapped_state_dict["segmentation_head.0.weight"] = state_dict[
"decode_head.linear_pred.weight"
]
mapped_state_dict["segmentation_head.0.bias"] = state_dict[
"decode_head.linear_pred.bias"
]
return mapped_state_dict |
Thanks! It would be great to have it tested against the original checkpoint or against the checkpoint on Hugging Face, or at least to run some inference and make sure it outputs meaningful masks 😄
Here is example code comparing Hugging Face Transformers and SMP.
Example Code
```python
import torch
import requests
from PIL import Image
import segmentation_models_pytorch as smp
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
def segformer_smp(inputs, path="./segformer.b0.512x1024.city.160k.pth"):
    # map_weights() is the conversion helper defined in the earlier comment
    original_checkpoint = torch.load(
        path, map_location="cpu", weights_only=False
    )
    num_classes = int(
        original_checkpoint["meta"]["config"].split("num_classes=")[1].split(",\n")[0]
    )
    decoder_dims = int(original_checkpoint["meta"]["config"].split("embed_dim=")[1][:3])
    smp_state_dict = map_weights(original_checkpoint)
    print(
        f"Pretrain Weight setting: num_classes={num_classes}, decoder_embed_dim={decoder_dims}"
    )
    model = smp.create_model(
        in_channels=3,
        classes=num_classes,
        arch="segformer",
        encoder_name="mit_b0",
        encoder_weights=None,
        decoder_segmentation_channels=decoder_dims,
    ).eval()
    model.load_state_dict(smp_state_dict, strict=False)
    with torch.no_grad():
        output = model(inputs['pixel_values'])
    output = torch.softmax(output, dim=1)
    output = torch.argmax(output, dim=1)
    return output


def segformer_tf(inputs, path):
    model = SegformerForSemanticSegmentation.from_pretrained(path)
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # HF SegFormer returns logits at 1/4 of the input resolution; upsample back to full size
    logits = torch.nn.functional.interpolate(
        logits, scale_factor=4, mode='bilinear', align_corners=True
    )
    logits = torch.softmax(logits, dim=1)
    logits = torch.argmax(logits, dim=1)
    return logits


if __name__ == "__main__":
    model_path_hugging = "nvidia/segformer-b0-finetuned-cityscapes-512-1024"
    url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
    image = Image.open(requests.get(url, stream=True).raw).resize((512, 512))
    processor = SegformerImageProcessor.from_pretrained(model_path_hugging)
    inputs = processor(images=image, return_tensors="pt")

    output_tf = segformer_tf(inputs, model_path_hugging).squeeze()
    output_smp = segformer_smp(inputs).squeeze()

    plt.subplot(131), plt.axis('off'), plt.imshow(image), plt.title('Input Image')
    plt.subplot(132), plt.axis('off'), plt.imshow(output_tf), plt.title('Output Transformer')
    plt.subplot(133), plt.axis('off'), plt.imshow(output_smp), plt.title('Output SMP')
    plt.tight_layout()
    plt.show()
```
Result
Thanks for the update and sorry for the delay on my end! Outputs look identical! I will review this week or next and hopefully merge it. Thanks for your patience!
The map_weights() function would be amazing if it could be integrated into from_pretrained so that it can deal with .ckpt/.pth files. PyTorch Lightning saves model checkpoints (including the last one) as .pth/.ckpt files, and I can't restore them with from_pretrained(). Maybe we could add a ckpt parameter to the from_pretrained() method such that if ckpt=True it takes the ckpt file, maps the weights, and reproduces the result expected from load_state_dict(); a rough sketch of the idea is shown below. Edit: does map_weights() work for all the SMP models?
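As a rough illustration of the idea (the helper name is hypothetical, and it assumes the LightningModule stores the network under a `model` attribute):

```python
import torch
import segmentation_models_pytorch as smp


def load_lightning_checkpoint(model, ckpt_path):
    # hypothetical helper: Lightning .ckpt files keep the weights under "state_dict",
    # with keys prefixed by the LightningModule attribute name (assumed "model." here)
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    state_dict = checkpoint.get("state_dict", checkpoint)
    state_dict = {k.removeprefix("model."): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    return model


# illustrative usage
model = smp.create_model(arch="segformer", encoder_name="mit_b0", encoder_weights=None, classes=1)
model = load_lightning_checkpoint(model, "last.ckpt")
```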
Thanks, it looks good now! I'm just going to add the conversion script to this PR as well and then merge it. Thanks a lot for your contribution, @brianhou0208, and I'm very sorry for the long wait!
@omarequalmars indeed, that might be very useful! Can you create a separate issue for this request?
Here are the converted weights that can be directly loaded into segformer:
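For reference, a minimal sketch of loading such a converted checkpoint, assuming it is published on the Hugging Face Hub and that the model's hub mixin is available (the repo id below is only illustrative):

```python
import torch
import segmentation_models_pytorch as smp

# illustrative repo id; replace with the actual converted-weights repository
model = smp.Segformer.from_pretrained("smp-hub/segformer-b0-512x1024-city-160k").eval()

with torch.no_grad():
    mask = model(torch.rand(1, 3, 512, 1024))
print(mask.shape)  # expected: [1, num_classes, 512, 1024]
```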
@brianhou0208 I want to mention you for your contributions. Do you have a LinkedIn/X account? If so, can you connect with me (see my profile page)?
Hi @qubvel,
I have completed SegFormer and passed tests/test_models.py locally. Please check it.
Issue
Resolves #941
Result
I trained the model using binary_segmentation_intro.ipynb, and the results are shown below (a minimal setup sketch follows).
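A minimal sketch of how the new architecture can be set up for that notebook's binary task (names and choices here are illustrative, not the exact training configuration):

```python
import segmentation_models_pytorch as smp

# SegFormer with a MiT-B0 encoder configured for single-class (binary) segmentation
model = smp.create_model(
    arch="segformer",
    encoder_name="mit_b0",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
)
# Dice loss on raw logits, as commonly used for binary masks
loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
```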
Reference