Enable any resolution for Unet #1029

Merged (7 commits) on Jan 13, 2025
Changes from 5 commits
10 changes: 9 additions & 1 deletion segmentation_models_pytorch/base/model.py
@@ -1,8 +1,11 @@
import torch
from typing import TypeVar, Type

from . import initialization as init
from .hub_mixin import SMPHubMixin

T = TypeVar("T", bound="SegmentationModel")


class SegmentationModel(torch.nn.Module, SMPHubMixin):
"""Base class for all segmentation models."""
@@ -11,6 +14,11 @@ class SegmentationModel(torch.nn.Module, SMPHubMixin):
# set to False
requires_divisible_input_shape = True

# Fix type-hint for models, to avoid HubMixin signature
def __new__(cls: Type[T], *args, **kwargs) -> T:
instance = super().__new__(cls, *args, **kwargs)
return instance

def initialize(self):
init.initialize_decoder(self.decoder)
init.initialize_head(self.segmentation_head)
@@ -42,7 +50,7 @@ def check_input_shape(self, x):
def forward(self, x):
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""

if not torch.jit.is_tracing() or self.requires_divisible_input_shape:
if torch.jit.is_tracing() or self.requires_divisible_input_shape:
self.check_input_shape(x)

features = self.encoder(x)
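The guard above is what gates the divisibility check: with the updated condition, check_input_shape is skipped for models that set requires_divisible_input_shape = False, as Unet does later in this diff. A minimal sketch of the resulting behaviour, assuming random weights (encoder_weights=None, to avoid a download) and an input size that is not divisible by 32:

import torch
import segmentation_models_pytorch as smp

# Unet (see unet/model.py below) sets requires_divisible_input_shape = False,
# so check_input_shape is skipped and sizes not divisible by 2**5 are accepted.
model = smp.Unet("resnet34", encoder_weights=None, classes=2)
model.eval()

x = torch.rand(1, 3, 250, 250)  # 250 is not divisible by 32
with torch.inference_mode():
    mask = model(x)

print(mask.shape)  # expected: torch.Size([1, 2, 250, 250])
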
95 changes: 62 additions & 33 deletions segmentation_models_pytorch/decoders/unet/decoder.py
@@ -2,19 +2,22 @@
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, Sequence
from segmentation_models_pytorch.base import modules as md


class DecoderBlock(nn.Module):
def __init__(
self,
in_channels,
skip_channels,
out_channels,
use_batchnorm=True,
attention_type=None,
in_channels: int,
skip_channels: int,
out_channels: int,
use_batchnorm: bool = True,
attention_type: Optional[str] = None,
interpolation_mode: str = "nearest",
):
super().__init__()
self.interpolate_mode = interpolation_mode
self.conv1 = md.Conv2dReLU(
in_channels + skip_channels,
out_channels,
@@ -34,19 +37,32 @@ def __init__(
)
self.attention2 = md.Attention(attention_type, in_channels=out_channels)

def forward(self, x, skip=None):
x = F.interpolate(x, scale_factor=2, mode="nearest")
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.attention1(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.attention2(x)
return x
def forward(
self,
feature_map: torch.Tensor,
target_height: int,
target_width: int,
Comment on lines +45 to +46

Collaborator: Is this an intentional backwards-incompatible change? It is no longer possible to use U-Net without specifying height/width. Could we instead default to the same height/width as the input, like we previously did?

Collaborator Author: Hmm, I'm not sure I follow: this is just a single layer, and the decoder passes the height and width in. Can you please specify what is broken?

Collaborator: Ah, the problem was that I'm using smp.decoders.unet.decoder.UnetDecoder directly. I got this failure in CI when I tried the main branch of SMP: https://github.com/microsoft/torchgeo/actions/runs/14042455085/job/39315632706?pr=2669. Perhaps we are relying on some undocumented implementation details in https://github.com/microsoft/torchgeo/blob/main/torchgeo/models/fcsiam.py?

Collaborator Author: Part of the failures are related to the modified decoder forward: previously, input features were unpacked with *features when passed into the decoder's forward, but now the star is removed and they should be provided as a list.

Collaborator: I know it's easy to support the new API, although it's harder to support both. I guess the question is whether this is intentional or not. Maybe it would help to add a "backwards-incompatible" label to PRs like this. Even better would be to deprecate the old syntax before completely removing it.

Collaborator Author: Yes, that was intentional; it was added in the PR that fixes support for torch.script/compile/export. These are internal modules, so I hope it will not break too many things. But you are right, I will add the label, and I will also highlight it in the release notes.
skip_connection: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Upsample feature map to the given spatial shape, concatenate with skip connection,
apply attention block (if specified) and then apply two convolutions.
"""
feature_map = F.interpolate(
feature_map, size=(target_height, target_width), mode=self.interpolate_mode
)
if skip_connection is not None:
feature_map = torch.cat([feature_map, skip_connection], dim=1)
feature_map = self.attention1(feature_map)
feature_map = self.conv1(feature_map)
feature_map = self.conv2(feature_map)
feature_map = self.attention2(feature_map)
return feature_map
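
The substantive change in this block is that upsampling now targets an explicit (target_height, target_width) instead of a fixed scale_factor=2, which is what lets feature maps with odd sizes line up with their skip connections. A small sketch of calling the block directly with made-up shapes, assuming the signature shown above:

import torch
from segmentation_models_pytorch.decoders.unet.decoder import DecoderBlock

block = DecoderBlock(in_channels=64, skip_channels=32, out_channels=16)
block.eval()

feature_map = torch.rand(2, 64, 13, 13)      # coarse feature map
skip_connection = torch.rand(2, 32, 27, 27)  # skip at an odd resolution (not an exact 2x factor)

out = block(feature_map, target_height=27, target_width=27, skip_connection=skip_connection)
print(out.shape)  # expected: torch.Size([2, 16, 27, 27])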


class CenterBlock(nn.Sequential):
def __init__(self, in_channels, out_channels, use_batchnorm=True):
"""Center block of the Unet decoder. Applied to the last feature map of the encoder."""

def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
conv1 = md.Conv2dReLU(
in_channels,
out_channels,
@@ -67,12 +83,12 @@ def __init__(self, in_channels, out_channels, use_batchnorm=True):
class UnetDecoder(nn.Module):
def __init__(
self,
encoder_channels,
decoder_channels,
n_blocks=5,
use_batchnorm=True,
attention_type=None,
center=False,
encoder_channels: Sequence[int],
decoder_channels: Sequence[int],
n_blocks: int = 5,
use_batchnorm: bool = True,
attention_type: Optional[str] = None,
add_center_block: bool = False,
):
super().__init__()

@@ -94,31 +110,44 @@ def __init__(
skip_channels = list(encoder_channels[1:]) + [0]
out_channels = decoder_channels

if center:
if add_center_block:
self.center = CenterBlock(
head_channels, head_channels, use_batchnorm=use_batchnorm
)
else:
self.center = nn.Identity()

# combine decoder keyword arguments
kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
blocks = [
DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
]
self.blocks = nn.ModuleList(blocks)

def forward(self, *features):
self.blocks = nn.ModuleList()
for block_in_channels, block_skip_channels, block_out_channels in zip(
in_channels, skip_channels, out_channels
):
block = DecoderBlock(
block_in_channels,
block_skip_channels,
block_out_channels,
use_batchnorm=use_batchnorm,
attention_type=attention_type,
)
self.blocks.append(block)

def forward(self, *features: torch.Tensor) -> torch.Tensor:
# spatial shapes of features: [hw, hw/2, hw/4, hw/8, ...]
spatial_shapes = [feature.shape[2:] for feature in features]
spatial_shapes = spatial_shapes[::-1]

features = features[1:] # remove first skip with same spatial resolution
features = features[::-1] # reverse channels to start from head of encoder

head = features[0]
skips = features[1:]
skip_connections = features[1:]

x = self.center(head)

for i, decoder_block in enumerate(self.blocks):
skip = skips[i] if i < len(skips) else None
x = decoder_block(x, skip)
# upsample to the next spatial shape
height, width = spatial_shapes[i + 1]
skip_connection = skip_connections[i] if i < len(skip_connections) else None
x = decoder_block(x, height, width, skip_connection=skip_connection)

return x
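
For code that constructs UnetDecoder directly, as came up in the review thread above, the keyword argument is now add_center_block rather than center, and in this revision the features are still passed unpacked into forward. A minimal sketch under those assumptions, with an illustrative encoder and random weights:

import torch
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder

encoder = smp.encoders.get_encoder("resnet18", in_channels=3, depth=5, weights=None)
decoder = UnetDecoder(
    encoder_channels=encoder.out_channels,
    decoder_channels=(256, 128, 64, 32, 16),
    n_blocks=5,
    add_center_block=False,
)

x = torch.rand(1, 3, 250, 250)  # resolution not divisible by 32
features = encoder(x)           # list of feature maps, finest to coarsest
out = decoder(*features)        # unpacked, matching the *features signature above
print(out.shape)                # expected: torch.Size([1, 16, 250, 250])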
46 changes: 39 additions & 7 deletions segmentation_models_pytorch/decoders/unet/model.py
@@ -1,4 +1,4 @@
from typing import Any, Optional, Union, Tuple, Callable
from typing import Any, Optional, Union, Callable, Sequence

from segmentation_models_pytorch.base import (
ClassificationHead,
@@ -12,10 +21,21 @@


class Unet(SegmentationModel):
"""Unet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder*
and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial
resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *concatenation*
for fusing decoder blocks with skip connections.
"""
U-Net is a fully convolutional neural network architecture designed for semantic image segmentation.

It consists of two main parts:

1. An encoder (downsampling path) that extracts increasingly abstract features
2. A decoder (upsampling path) that gradually recovers spatial details

The key is the use of skip connections between corresponding encoder and decoder layers.
These connections allow the decoder to access fine-grained details from earlier encoder layers,
which helps produce more precise segmentation masks.

The skip connections work by concatenating feature maps from the encoder directly into the decoder
at corresponding resolutions. This helps preserve important spatial information that would
otherwise be lost during the encoding process.

Args:
encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
@@ -51,19 +62,39 @@ class Unet(SegmentationModel):
Returns:
``torch.nn.Module``: Unet

Example:
.. code-block:: python

import torch
import segmentation_models_pytorch as smp

model = smp.Unet("resnet18", encoder_weights="imagenet", classes=5)
model.eval()

# generate random images
images = torch.rand(2, 3, 256, 256)

with torch.inference_mode():
mask = model(images)

print(mask.shape)
# torch.Size([2, 5, 256, 256])

.. _Unet:
https://arxiv.org/abs/1505.04597

"""

requires_divisible_input_shape = False

@supports_config_loading
def __init__(
self,
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: Optional[str] = "imagenet",
decoder_use_batchnorm: bool = True,
decoder_channels: Tuple[int, ...] = (256, 128, 64, 32, 16),
decoder_channels: Sequence[int] = (256, 128, 64, 32, 16),
decoder_attention_type: Optional[str] = None,
in_channels: int = 3,
classes: int = 1,
@@ -81,12 +112,13 @@ def __init__(
**kwargs,
)

add_center_block = encoder_name.startswith("vgg")
self.decoder = UnetDecoder(
encoder_channels=self.encoder.out_channels,
decoder_channels=decoder_channels,
n_blocks=encoder_depth,
use_batchnorm=decoder_use_batchnorm,
center=True if encoder_name.startswith("vgg") else False,
add_center_block=add_center_block,
attention_type=decoder_attention_type,
)

6 changes: 3 additions & 3 deletions tests/encoders/base.py
@@ -82,7 +82,7 @@ def test_in_channels(self):
encoder.eval()

# forward
with torch.no_grad():
with torch.inference_mode():
encoder.forward(sample)

def test_depth(self):
@@ -110,7 +110,7 @@ def test_depth(self):
encoder.eval()

# forward
with torch.no_grad():
with torch.inference_mode():
features = encoder.forward(sample)

# check number of features
@@ -187,7 +187,7 @@ def test_dilated(self):
encoder.eval()

# forward
with torch.no_grad():
with torch.inference_mode():
features = encoder.forward(sample)

height_strides, width_strides = self.get_features_output_strides(
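The test updates replace torch.no_grad() with torch.inference_mode(), which disables gradient tracking and also skips view and version-counter bookkeeping; the trade-off is that tensors created under inference mode cannot be used in autograd afterwards. A brief sketch of the difference:

import torch

x = torch.ones(3, requires_grad=True)

with torch.no_grad():
    y = x * 2  # gradients are not tracked, but y behaves like an ordinary tensor

with torch.inference_mode():
    z = x * 2  # slightly faster: no grad tracking and no version counters

# y can still be fed into autograd-enabled code later;
# doing the same with z raises a RuntimeError, since z is an inference tensor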