Deprecate use_batchnorm in favor of generalized use_norm parameter #1095

Merged (22 commits) on Apr 5, 2025
Changes from 1 commit
83 changes: 71 additions & 12 deletions segmentation_models_pytorch/base/modules.py
@@ -1,3 +1,5 @@
import warnings

import torch
import torch.nn as nn

@@ -16,11 +18,53 @@
padding=0,
stride=1,
use_batchnorm=True,
use_norm="batchnorm",
):
if use_batchnorm == "inplace" and InPlaceABN is None:
if use_batchnorm is not None:
warnings.warn(
"The usage of use_batchnorm is deprecated. Please modify your code for use_norm",
DeprecationWarning,
)
if use_batchnorm is True:
use_norm = {"type": "batchnorm"}
elif use_batchnorm is False:
use_norm = {"type": "identity"}
elif use_batchnorm == "inplace":
use_norm = {
"type": "inplace",
"activation": "leaky_relu",
"activation_param": 0.0,
}
else:
raise ValueError("Unrecognized value for use_batchnorm")

if isinstance(use_norm, str):
norm_str = use_norm.lower()
if norm_str == "inplace":
use_norm = {
"type": "inplace",
"activation": "leaky_relu",
"activation_param": 0.0,
}
elif norm_str in (
"batchnorm",
"identity",
"layernorm",
"groupnorm",
"instancenorm",
):
use_norm = {"type": norm_str}

else:
raise ValueError("Unrecognized normalization type string provided")

elif isinstance(use_norm, bool):
use_norm = {"type": "batchnorm" if use_norm else "identity"}

elif not isinstance(use_norm, dict):
raise ValueError("use_norm must be a dictionary, boolean, or string")

if use_norm["type"] == "inplace" and InPlaceABN is None:
raise RuntimeError(
"In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
+ "To install see: https://github.com/mapillary/inplace_abn"
"In order to use `use_batchnorm='inplace'` or `use_norm='inplace'` the inplace_abn package must be installed. "
"To install see: https://github.com/mapillary/inplace_abn"
)
Collaborator

Let's have a separate function get_norm_layer which will validate input params and return norm layer

Contributor Author

Good catch. Much simpler. Done the changes
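For illustration, a rough sketch of such a helper (the name `get_norm_layer` and its exact signature follow the reviewer's suggestion and are not necessarily what was merged; the "inplace" and "layernorm" branches from the PR are omitted for brevity):

```python
from typing import Any, Dict, Union

import torch.nn as nn


def get_norm_layer(
    use_norm: Union[bool, str, Dict[str, Any]], out_channels: int
) -> nn.Module:
    # Normalize the user-facing argument into {"type": ..., **kwargs} form.
    if isinstance(use_norm, bool):
        use_norm = {"type": "batchnorm" if use_norm else "identity"}
    elif isinstance(use_norm, str):
        use_norm = {"type": use_norm.lower()}
    elif not isinstance(use_norm, dict):
        raise ValueError("use_norm must be a dictionary, boolean, or string")

    norm_type = use_norm["type"]
    kwargs = {k: v for k, v in use_norm.items() if k != "type"}

    if norm_type == "identity":
        return nn.Identity()
    if norm_type == "batchnorm":
        return nn.BatchNorm2d(out_channels, **kwargs)
    if norm_type == "instancenorm":
        return nn.InstanceNorm2d(out_channels, **kwargs)
    if norm_type == "groupnorm":
        # num_groups is expected in kwargs, e.g. {"type": "groupnorm", "num_groups": 8}
        return nn.GroupNorm(num_channels=out_channels, **kwargs)
    raise ValueError(f"Unrecognized normalization type: {norm_type}")
```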


conv = nn.Conv2d(
@@ -29,21 +73,30 @@
kernel_size,
stride=stride,
padding=padding,
bias=not (use_batchnorm),
bias=use_norm["type"] != "inplace",
Collaborator

We can initialize norm first and use a separate variable

is_inplace_batchnorm = norm.__name__ == "InPlaceABN"

Contributor Author

yeap simpler
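A sketch of that idea in context (assuming the `get_norm_layer` helper sketched above; `__class__.__name__` is used because `norm` here is an instance, and the bias rule mirrors this commit, where only the fused InPlaceABN layer drops the conv bias):

```python
# Inside Conv2dReLU.__init__ (sketch): build the norm layer first,
# then derive the conv bias flag from it with a separate variable.
norm = get_norm_layer(use_norm, out_channels)
is_inplace_batchnorm = norm.__class__.__name__ == "InPlaceABN"

conv = nn.Conv2d(
    in_channels,
    out_channels,
    kernel_size,
    stride=stride,
    padding=padding,
    bias=not is_inplace_batchnorm,
)
```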

)
relu = nn.ReLU(inplace=True)

if use_batchnorm == "inplace":
bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
relu = nn.Identity()

elif use_batchnorm and use_batchnorm != "inplace":
bn = nn.BatchNorm2d(out_channels)
norm_type = use_norm["type"]
extra_kwargs = {k: v for k, v in use_norm.items() if k != "type"}

if norm_type == "inplace":
norm = InPlaceABN(out_channels, **extra_kwargs)
relu = nn.Identity()

elif norm_type == "batchnorm":
norm = nn.BatchNorm2d(out_channels, **extra_kwargs)
elif norm_type == "identity":
norm = nn.Identity()
elif norm_type == "layernorm":
norm = nn.LayerNorm(out_channels, **extra_kwargs)
elif norm_type == "groupnorm":
norm = nn.GroupNorm(num_channels=out_channels, **extra_kwargs)
elif norm_type == "instancenorm":
norm = nn.InstanceNorm2d(out_channels, **extra_kwargs)

else:
bn = nn.Identity()
raise ValueError(f"Unrecognized normalization type: {norm_type}")

super(Conv2dReLU, self).__init__(conv, bn, relu)
super(Conv2dReLU, self).__init__(conv, norm, relu)


class SCSEModule(nn.Module):
@@ -127,3 +180,9 @@

def forward(self, x):
return self.attention(x)


if __name__ == "__main__":
print(Conv2dReLU(3, 12, 4))
print(Conv2dReLU(3, 12, 4, use_norm={"type": "batchnorm"}))
print(Conv2dReLU(3, 12, 4, use_norm={"type": "layernorm", "eps": 1e-3}))

Collaborator

Should be removed; instead, it would be nice to add a test

Contributor Author

Yeap, added some tests around Conv2dReLU
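Those tests are not part of this commit; a minimal pytest sketch in that spirit could look like the following (the import path is the real module, the exact assertions are illustrative, and `use_batchnorm=None` is passed explicitly so the deprecated path does not override `use_norm` at this commit):

```python
import pytest
import torch

from segmentation_models_pytorch.base.modules import Conv2dReLU


@pytest.mark.parametrize(
    "use_norm",
    [
        True,
        False,
        "batchnorm",
        {"type": "batchnorm"},
        {"type": "groupnorm", "num_groups": 4},
    ],
)
def test_conv2drelu_forward(use_norm):
    # Every accepted form of `use_norm` should build a working block.
    block = Conv2dReLU(
        3, 12, kernel_size=3, padding=1, use_batchnorm=None, use_norm=use_norm
    )
    x = torch.rand(2, 3, 32, 32)
    out = block(x)
    assert out.shape == (2, 12, 32, 32)
```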

32 changes: 26 additions & 6 deletions segmentation_models_pytorch/decoders/linknet/decoder.py
@@ -1,12 +1,18 @@
import torch
import torch.nn as nn

from typing import List, Optional
from typing import Any, Dict, List, Optional, Union
from segmentation_models_pytorch.base import modules


class TransposeX2(nn.Sequential):
def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
def __init__(
self,
in_channels: int,
out_channels: int,
use_batchnorm: Union[bool, str, None] = True,
use_norm: Union[bool, str, Dict[str, Any]] = True,
Collaborator

We should have only use_norm here; use_batchnorm should be replaced at the top level (e.g. the Unet model)

Contributor Author

Made a proposition

):
super().__init__()
layers = [
nn.ConvTranspose2d(
@@ -15,14 +21,20 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = Tr
nn.ReLU(inplace=True),
]

if use_batchnorm:
if use_batchnorm or use_norm:
layers.insert(1, nn.BatchNorm2d(out_channels))

super().__init__(*layers)
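Regarding the review note above about handling the deprecation at the model level rather than in every decoder block, a rough sketch of such a top-level mapping (hypothetical helper name; the merged code may differ):

```python
import warnings
from typing import Any, Dict, Union


def resolve_decoder_norm(
    use_batchnorm: Union[bool, str, None],
    use_norm: Union[bool, str, Dict[str, Any]],
) -> Union[bool, str, Dict[str, Any]]:
    # Map the deprecated flag onto the new argument once, at model construction
    # time, so decoder blocks only ever receive `use_norm`.
    if use_batchnorm is not None:
        warnings.warn(
            "decoder_use_batchnorm is deprecated, use decoder_use_norm instead",
            DeprecationWarning,
        )
        if use_batchnorm == "inplace":
            return {
                "type": "inplace",
                "activation": "leaky_relu",
                "activation_param": 0.0,
            }
        return {"type": "batchnorm" if use_batchnorm else "identity"}
    return use_norm
```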


class DecoderBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
def __init__(
self,
in_channels: int,
out_channels: int,
use_batchnorm: Union[bool, str, None] = True,
use_norm: Union[bool, str, Dict[str, Any]] = True,
):
Collaborator

same here, let's move it to the model, and leave only use_norm

Contributor Author

Made a proposition

super().__init__()

self.block = nn.Sequential(
@@ -31,6 +43,7 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = Tr
in_channels // 4,
kernel_size=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
),
TransposeX2(
in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm
@@ -40,6 +53,7 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = Tr
out_channels,
kernel_size=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
),
)

@@ -58,7 +72,8 @@ def __init__(
encoder_channels: List[int],
prefinal_channels: int = 32,
n_blocks: int = 5,
use_batchnorm: bool = True,
use_batchnorm: Union[bool, str, None] = True,
use_norm: Union[bool, str, Dict[str, Any]] = True,
):
super().__init__()

@@ -71,7 +86,12 @@ def __init__(

self.blocks = nn.ModuleList(
[
DecoderBlock(channels[i], channels[i + 1], use_batchnorm=use_batchnorm)
DecoderBlock(
channels[i],
channels[i + 1],
use_batchnorm=use_batchnorm,
use_norm=use_norm,
)
for i in range(n_blocks)
]
)
26 changes: 23 additions & 3 deletions segmentation_models_pytorch/decoders/linknet/model.py
@@ -1,4 +1,4 @@
from typing import Any, Optional, Union
from typing import Any, Dict, Optional, Union

from segmentation_models_pytorch.base import (
ClassificationHead,
@@ -29,9 +27 @@ class Linknet(SegmentationModel):
Default is 5
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
other pretrained weights (see table with available weights for each encoder_name)
decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
decoder_use_batchnorm: (**Deprecated**) If **True**, BatchNorm2d layer between Conv2D and Activation layers
is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption.
Available options are **True, False, "inplace"**

**Note:** Deprecated; prefer `decoder_use_norm` and set this parameter to **None**.
decoder_use_norm: Specifies normalization between Conv2D and activation.
Accepts the following types:
- **True**: Defaults to `"batchnorm"`.
- **False**: No normalization (`nn.Identity`).
- **str**: Specifies normalization type using default parameters. Available values:
`"batchnorm"`, `"identity"`, `"layernorm"`, `"groupnorm"`, `"instancenorm"`, `"inplace"`.
- **dict**: Fully customizable normalization settings. Structure:
```python
{"type": <norm_type>, **kwargs}
```
where `norm_type` corresponds to the normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in the PyTorch documentation.

**Example**:
```python
use_norm={"type": "groupnorm", "num_groups": 8}
```
in_channels: A number of input channels for the model, default is 3 (RGB images)
classes: A number of classes for output mask (or you can think as a number of channels of output mask)
activation: An activation function to apply after the final convolution layer.
@@ -60,7 +78,8 @@ def __init__(
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: Optional[str] = "imagenet",
decoder_use_batchnorm: bool = True,
decoder_use_batchnorm: Union[bool, str, None] = True,
decoder_use_norm: Union[bool, str, Dict[str, Any]] = True,
in_channels: int = 3,
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
@@ -87,6 +106,7 @@ def __init__(
n_blocks=encoder_depth,
prefinal_channels=32,
use_batchnorm=decoder_use_batchnorm,
use_norm=decoder_use_norm,
)

self.segmentation_head = SegmentationHead(
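A usage sketch of the `decoder_use_norm` options documented above (encoder weights are disabled so the snippet runs offline; the GroupNorm parameters are illustrative, and behaviour at this intermediate commit may differ slightly from the final merged version):

```python
import torch
import segmentation_models_pytorch as smp

# Batch normalization in the decoder, requested by name:
model_bn = smp.Linknet("resnet34", encoder_weights=None, decoder_use_norm="batchnorm")

# Fully customised normalization, e.g. GroupNorm with 8 groups per decoder conv:
model_gn = smp.Linknet(
    "resnet34",
    encoder_weights=None,
    decoder_use_norm={"type": "groupnorm", "num_groups": 8},
)

# The deprecated flag still works but emits a DeprecationWarning:
model_old = smp.Linknet("resnet34", encoder_weights=None, decoder_use_batchnorm=True)

x = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    mask = model_bn(x)
print(mask.shape)  # torch.Size([1, 1, 256, 256])
```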
28 changes: 21 additions & 7 deletions segmentation_models_pytorch/decoders/manet/decoder.py
@@ -1,9 +1,9 @@
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import List, Optional

from segmentation_models_pytorch.base import modules as md


@@ -49,7 +49,8 @@ def __init__(
in_channels: int,
skip_channels: int,
out_channels: int,
use_batchnorm: bool = True,
use_batchnorm: Union[bool, str, None] = True,
use_norm: Union[bool, str, Dict[str, Any]] = True,
reduction: int = 16,
):
# MFABBlock is just a modified version of SE-blocks, one for skip, one for input
@@ -61,9 +62,14 @@ def __init__(
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
),
md.Conv2dReLU(
in_channels, skip_channels, kernel_size=1, use_batchnorm=use_batchnorm
in_channels,
skip_channels,
kernel_size=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
),
)
reduced_channels = max(1, skip_channels // reduction)
@@ -88,13 +94,15 @@ def __init__(
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
)

def forward(
@@ -119,7 +127,8 @@ def __init__(
in_channels: int,
skip_channels: int,
out_channels: int,
use_batchnorm: bool = True,
use_batchnorm: Union[bool, str, None] = True,
use_norm: Union[bool, str, Dict[str, Any]] = True,
):
super().__init__()
self.conv1 = md.Conv2dReLU(
@@ -128,13 +137,15 @@ def __init__(
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
use_norm=use_norm,
)

def forward(
@@ -155,7 +166,8 @@ def __init__(
decoder_channels: List[int],
n_blocks: int = 5,
reduction: int = 16,
use_batchnorm: bool = True,
use_batchnorm: Union[bool, str, None] = True,
use_norm: Union[bool, str, Dict[str, Any]] = True,
pab_channels: int = 64,
):
super().__init__()
@@ -182,7 +194,9 @@ def __init__(
self.center = PABBlock(head_channels, pab_channels=pab_channels)

# combine decoder keyword arguments
kwargs = dict(use_batchnorm=use_batchnorm) # no attention type here
kwargs = dict(
use_batchnorm=use_batchnorm, use_norm=use_norm
) # no attention type here
blocks = [
MFABBlock(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs)
if skip_ch > 0