Commit e26adcd

Deprecate use_batchnorm in favor of generalized use_norm parameter
1 parent: c5d80bd

File tree

9 files changed: +257 −48 lines

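For orientation before the per-file diffs, here is a minimal before/after sketch of the migration (using Linknet, one of the models updated below; `encoder_weights=None` only avoids a weight download). Note that in this revision `use_batchnorm` still defaults to `True` and, when not `None`, takes precedence over `use_norm`, so the new docstring's advice to set it to `None` is followed:

```python
import segmentation_models_pytorch as smp

# Before: a boolean/"inplace" flag, limited to BatchNorm2d or InplaceABN.
old_style = smp.Linknet("resnet34", encoder_weights=None, decoder_use_batchnorm=True)

# After: use_norm accepts a bool, a string, or a dict carrying layer kwargs.
new_style = smp.Linknet(
    "resnet34",
    encoder_weights=None,
    decoder_use_batchnorm=None,  # deprecated flag off, per the new docstring
    decoder_use_norm={"type": "instancenorm"},
)
```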

segmentation_models_pytorch/base/modules.py

Lines changed: 71 additions & 12 deletions
```diff
@@ -1,3 +1,5 @@
+import warnings
+
 import torch
 import torch.nn as nn
 
@@ -16,11 +18,53 @@ def __init__(
         padding=0,
         stride=1,
         use_batchnorm=True,
+        use_norm="batchnorm",
     ):
-        if use_batchnorm == "inplace" and InPlaceABN is None:
+        if use_batchnorm is not None:
+            warnings.warn(
+                "The usage of use_batchnorm is deprecated. Please modify your code to use use_norm instead.",
+                DeprecationWarning,
+            )
+            if use_batchnorm is True:
+                use_norm = {"type": "batchnorm"}
+            elif use_batchnorm is False:
+                use_norm = {"type": "identity"}
+            elif use_batchnorm == "inplace":
+                use_norm = {
+                    "type": "inplace",
+                    "activation": "leaky_relu",
+                    "activation_param": 0.0,
+                }
+            else:
+                raise ValueError("Unrecognized value for use_batchnorm")
+
+        if isinstance(use_norm, str):
+            norm_str = use_norm.lower()
+            if norm_str == "inplace":
+                use_norm = {
+                    "type": "inplace",
+                    "activation": "leaky_relu",
+                    "activation_param": 0.0,
+                }
+            elif norm_str in (
+                "batchnorm",
+                "identity",
+                "layernorm",
+                "groupnorm",
+                "instancenorm",
+            ):
+                use_norm = {"type": norm_str}
+            else:
+                raise ValueError("Unrecognized normalization type string provided")
+        elif isinstance(use_norm, bool):
+            use_norm = {"type": "batchnorm" if use_norm else "identity"}
+        elif not isinstance(use_norm, dict):
+            raise ValueError("use_norm must be a dictionary, boolean, or string")
+
+        if use_norm["type"] == "inplace" and InPlaceABN is None:
             raise RuntimeError(
-                "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
-                + "To install see: https://github.com/mapillary/inplace_abn"
+                "In order to use `use_batchnorm='inplace'` or `use_norm='inplace'` the inplace_abn package must be installed. "
+                "To install see: https://github.com/mapillary/inplace_abn"
             )
 
         conv = nn.Conv2d(
@@ -29,21 +73,30 @@ def __init__(
             kernel_size,
             stride=stride,
             padding=padding,
-            bias=not (use_batchnorm),
+            bias=use_norm["type"] != "inplace",
         )
         relu = nn.ReLU(inplace=True)
 
-        if use_batchnorm == "inplace":
-            bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
-            relu = nn.Identity()
-
-        elif use_batchnorm and use_batchnorm != "inplace":
-            bn = nn.BatchNorm2d(out_channels)
+        norm_type = use_norm["type"]
+        extra_kwargs = {k: v for k, v in use_norm.items() if k != "type"}
 
+        if norm_type == "inplace":
+            norm = InPlaceABN(out_channels, **extra_kwargs)
+            relu = nn.Identity()
+        elif norm_type == "batchnorm":
+            norm = nn.BatchNorm2d(out_channels, **extra_kwargs)
+        elif norm_type == "identity":
+            norm = nn.Identity()
+        elif norm_type == "layernorm":
+            norm = nn.LayerNorm(out_channels, **extra_kwargs)
+        elif norm_type == "groupnorm":
+            norm = nn.GroupNorm(num_channels=out_channels, **extra_kwargs)
+        elif norm_type == "instancenorm":
+            norm = nn.InstanceNorm2d(out_channels, **extra_kwargs)
         else:
-            bn = nn.Identity()
+            raise ValueError(f"Unrecognized normalization type: {norm_type}")
 
-        super(Conv2dReLU, self).__init__(conv, bn, relu)
+        super(Conv2dReLU, self).__init__(conv, norm, relu)
 
 
 class SCSEModule(nn.Module):
@@ -127,3 +180,9 @@ def __init__(self, name, **params):
 
     def forward(self, x):
        return self.attention(x)
+
+
+if __name__ == "__main__":
+    print(Conv2dReLU(3, 12, 4))
+    print(Conv2dReLU(3, 12, 4, use_norm={"type": "batchnorm"}))
+    print(Conv2dReLU(3, 12, 4, use_norm={"type": "layernorm", "eps": 1e-3}))
```
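Complementing the `__main__` demo above, a short usage sketch of the new mapping. Caveat for this revision: `use_batchnorm` defaults to `True` and, whenever it is not `None`, is mapped onto `use_norm` and overrides it, so it is passed explicitly below:

```python
import warnings

import torch
from segmentation_models_pytorch.base.modules import Conv2dReLU

# The deprecated flag still works but warns, and is mapped onto use_norm
# internally (False -> {"type": "identity"}).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    block = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_batchnorm=False)
print(caught[0].category)  # <class 'DeprecationWarning'>

# New style: a string picks a norm with default parameters, a dict forwards
# extra kwargs to the layer (here BatchNorm2d(16, eps=1e-3)).
block = Conv2dReLU(
    3,
    16,
    kernel_size=3,
    padding=1,
    use_batchnorm=None,  # let use_norm take effect in this revision
    use_norm={"type": "batchnorm", "eps": 1e-3},
)
print(block(torch.rand(2, 3, 32, 32)).shape)  # torch.Size([2, 16, 32, 32])
```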

segmentation_models_pytorch/decoders/linknet/decoder.py

Lines changed: 26 additions & 6 deletions
```diff
@@ -1,12 +1,18 @@
 import torch
 import torch.nn as nn
 
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Union
 from segmentation_models_pytorch.base import modules
 
 
 class TransposeX2(nn.Sequential):
-    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        use_batchnorm: Union[bool, str, None] = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = True,
+    ):
         super().__init__()
         layers = [
             nn.ConvTranspose2d(
@@ -15,14 +21,20 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
             nn.ReLU(inplace=True),
         ]
 
-        if use_batchnorm:
+        if use_batchnorm or use_norm:
             layers.insert(1, nn.BatchNorm2d(out_channels))
 
         super().__init__(*layers)
 
 
 class DecoderBlock(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        use_batchnorm: Union[bool, str, None] = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = True,
+    ):
         super().__init__()
 
         self.block = nn.Sequential(
@@ -31,6 +43,7 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
                 in_channels // 4,
                 kernel_size=1,
                 use_batchnorm=use_batchnorm,
+                use_norm=use_norm,
             ),
             TransposeX2(
                 in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm
@@ -40,6 +53,7 @@ def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
                 out_channels,
                 kernel_size=1,
                 use_batchnorm=use_batchnorm,
+                use_norm=use_norm,
             ),
         )
 
@@ -58,7 +72,8 @@ def __init__(
         encoder_channels: List[int],
         prefinal_channels: int = 32,
         n_blocks: int = 5,
-        use_batchnorm: bool = True,
+        use_batchnorm: Union[bool, str, None] = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = True,
     ):
         super().__init__()
 
@@ -71,7 +86,12 @@ def __init__(
 
         self.blocks = nn.ModuleList(
             [
-                DecoderBlock(channels[i], channels[i + 1], use_batchnorm=use_batchnorm)
+                DecoderBlock(
+                    channels[i],
+                    channels[i + 1],
+                    use_batchnorm=use_batchnorm,
+                    use_norm=use_norm,
+                )
                 for i in range(n_blocks)
             ]
         )
```
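A construction-only sketch of the updated block. Two things worth noticing in this revision: `TransposeX2` still hardcodes `nn.BatchNorm2d` and only truthiness-checks its new `use_norm` argument, and `DecoderBlock` does not yet forward `use_norm` to it:

```python
from segmentation_models_pytorch.decoders.linknet.decoder import DecoderBlock

# use_batchnorm=None so the deprecated flag does not override use_norm in the
# inner Conv2dReLU layers; the TransposeX2 in the middle keeps its BatchNorm2d
# regardless of the norm requested here.
block = DecoderBlock(64, 32, use_batchnorm=None, use_norm={"type": "identity"})
print(block)
```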

segmentation_models_pytorch/decoders/linknet/model.py

Lines changed: 23 additions & 3 deletions
```diff
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Dict, Optional, Union
 
 from segmentation_models_pytorch.base import (
     ClassificationHead,
@@ -29,9 +29,27 @@ class Linknet(SegmentationModel):
             Default is 5
         encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
             other pretrained weights (see table with available weights for each encoder_name)
-        decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
+        decoder_use_batchnorm: (**Deprecated**) If **True**, BatchNorm2d layer between Conv2D and Activation layers
             is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption.
             Available options are **True, False, "inplace"**
+
+            **Note:** Deprecated, prefer using `decoder_use_norm` and set this to None.
+        decoder_use_norm: Specifies normalization between Conv2D and activation.
+            Accepts the following types:
+            - **True**: Defaults to `"batchnorm"`.
+            - **False**: No normalization (`nn.Identity`).
+            - **str**: Specifies normalization type using default parameters. Available values:
+              `"batchnorm"`, `"identity"`, `"layernorm"`, `"groupnorm"`, `"instancenorm"`, `"inplace"`.
+            - **dict**: Fully customizable normalization settings. Structure:
+              ```python
+              {"type": <norm_type>, **kwargs}
+              ```
+              where `norm_type` corresponds to the normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in the PyTorch documentation.
+
+            **Example**:
+            ```python
+            decoder_use_norm={"type": "groupnorm", "num_groups": 8}
+            ```
         in_channels: A number of input channels for the model, default is 3 (RGB images)
         classes: A number of classes for output mask (or you can think as a number of channels of output mask)
         activation: An activation function to apply after the final convolution layer.
@@ -60,7 +78,8 @@ def __init__(
         encoder_name: str = "resnet34",
         encoder_depth: int = 5,
         encoder_weights: Optional[str] = "imagenet",
-        decoder_use_batchnorm: bool = True,
+        decoder_use_batchnorm: Union[bool, str, None] = True,
+        decoder_use_norm: Union[bool, str, Dict[str, Any]] = True,
         in_channels: int = 3,
         classes: int = 1,
         activation: Optional[Union[str, callable]] = None,
@@ -87,6 +106,7 @@ def __init__(
             n_blocks=encoder_depth,
             prefinal_channels=32,
             use_batchnorm=decoder_use_batchnorm,
+            use_norm=decoder_use_norm,
         )
 
         self.segmentation_head = SegmentationHead(
```
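A runnable sketch of the documented dict form, assuming the `groupnorm` branch passes `num_groups` through to `nn.GroupNorm` as the docstring example implies; all decoder widths produced by a resnet34 Linknet are divisible by 8:

```python
import torch
import segmentation_models_pytorch as smp

model = smp.Linknet(
    encoder_name="resnet34",
    encoder_weights=None,        # skip the ImageNet download for this sketch
    decoder_use_batchnorm=None,  # deprecated flag off, per the docstring note
    decoder_use_norm={"type": "groupnorm", "num_groups": 8},
)
mask = model(torch.rand(1, 3, 64, 64))
print(mask.shape)  # torch.Size([1, 1, 64, 64])
```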

segmentation_models_pytorch/decoders/manet/decoder.py

Lines changed: 21 additions & 7 deletions
```diff
@@ -1,9 +1,9 @@
+from typing import Any, Dict, List, Optional, Union
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from typing import List, Optional
-
 from segmentation_models_pytorch.base import modules as md
 
 
@@ -49,7 +49,8 @@ def __init__(
         in_channels: int,
         skip_channels: int,
         out_channels: int,
-        use_batchnorm: bool = True,
+        use_batchnorm: Union[bool, str, None] = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = True,
         reduction: int = 16,
     ):
         # MFABBlock is just a modified version of SE-blocks, one for skip, one for input
@@ -61,9 +62,14 @@ def __init__(
                 kernel_size=3,
                 padding=1,
                 use_batchnorm=use_batchnorm,
+                use_norm=use_norm,
             ),
             md.Conv2dReLU(
-                in_channels, skip_channels, kernel_size=1, use_batchnorm=use_batchnorm
+                in_channels,
+                skip_channels,
+                kernel_size=1,
+                use_batchnorm=use_batchnorm,
+                use_norm=use_norm,
             ),
         )
         reduced_channels = max(1, skip_channels // reduction)
@@ -88,13 +94,15 @@ def __init__(
             kernel_size=3,
             padding=1,
             use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
         self.conv2 = md.Conv2dReLU(
             out_channels,
             out_channels,
             kernel_size=3,
             padding=1,
             use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
 
     def forward(
@@ -119,7 +127,8 @@ def __init__(
         in_channels: int,
         skip_channels: int,
         out_channels: int,
-        use_batchnorm: bool = True,
+        use_batchnorm: Union[bool, str, None] = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = True,
     ):
         super().__init__()
         self.conv1 = md.Conv2dReLU(
@@ -128,13 +137,15 @@ def __init__(
             kernel_size=3,
             padding=1,
             use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
         self.conv2 = md.Conv2dReLU(
             out_channels,
             out_channels,
             kernel_size=3,
             padding=1,
             use_batchnorm=use_batchnorm,
+            use_norm=use_norm,
         )
 
     def forward(
@@ -155,7 +166,8 @@ def __init__(
         decoder_channels: List[int],
         n_blocks: int = 5,
         reduction: int = 16,
-        use_batchnorm: bool = True,
+        use_batchnorm: Union[bool, str, None] = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = True,
         pab_channels: int = 64,
     ):
         super().__init__()
@@ -182,7 +194,9 @@ def __init__(
         self.center = PABBlock(head_channels, pab_channels=pab_channels)
 
         # combine decoder keyword arguments
-        kwargs = dict(use_batchnorm=use_batchnorm)  # no attention type here
+        kwargs = dict(
+            use_batchnorm=use_batchnorm, use_norm=use_norm
+        )  # no attention type here
         blocks = [
             MFABBlock(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs)
             if skip_ch > 0
```
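The same threading pattern repeats across the MAnet decoder; a construction-only sketch of one MFAB block (its `forward` also takes a skip tensor, omitted here):

```python
from segmentation_models_pytorch.decoders.manet.decoder import MFABBlock

# Every inner md.Conv2dReLU now receives use_norm alongside the
# deprecated use_batchnorm.
block = MFABBlock(
    in_channels=64,
    skip_channels=32,
    out_channels=32,
    use_batchnorm=None,       # let use_norm take effect in this revision
    use_norm="instancenorm",  # string form, default InstanceNorm2d parameters
)
print(block)
```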
