Deprecate use_batchnorm in favor of generalized use_norm parameter #1095
Changes to the `Conv2dReLU` block (base modules file):

```diff
@@ -1,3 +1,5 @@
+import warnings
+
 import torch
 import torch.nn as nn
 
@@ -16,11 +18,53 @@
         padding=0,
         stride=1,
         use_batchnorm=True,
+        use_norm="batchnorm",
     ):
-        if use_batchnorm == "inplace" and InPlaceABN is None:
+        if use_batchnorm is not None:
+            warnings.warn(
+                "The usage of use_batchnorm is deprecated. Please modify your code for use_norm",
+                DeprecationWarning,
+            )
+            if use_batchnorm is True:
+                use_norm = {"type": "batchnorm"}
+            elif use_batchnorm is False:
+                use_norm = {"type": "identity"}
+            elif use_batchnorm == "inplace":
+                use_norm = {
+                    "type": "inplace",
+                    "activation": "leaky_relu",
+                    "activation_param": 0.0,
+                }
+            else:
+                raise ValueError("Unrecognized value for use_batchnorm")
+
+        if isinstance(use_norm, str):
+            norm_str = use_norm.lower()
+            if norm_str == "inplace":
+                use_norm = {
+                    "type": "inplace",
+                    "activation": "leaky_relu",
+                    "activation_param": 0.0,
+                }
+            elif norm_str in (
+                "batchnorm",
+                "identity",
+                "layernorm",
+                "groupnorm",
+                "instancenorm",
+            ):
+                use_norm = {"type": norm_str}
+            else:
+                raise ValueError("Unrecognized normalization type string provided")
+        elif isinstance(use_norm, bool):
+            use_norm = {"type": "batchnorm" if use_norm else "identity"}
+        elif not isinstance(use_norm, dict):
+            raise ValueError("use_norm must be a dictionary, boolean, or string")
+
+        if use_norm["type"] == "inplace" and InPlaceABN is None:
             raise RuntimeError(
-                "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
-                + "To install see: https://github.com/mapillary/inplace_abn"
+                "In order to use `use_batchnorm='inplace'` or `use_norm='inplace'` the inplace_abn package must be installed. "
+                "To install see: https://github.com/mapillary/inplace_abn"
            )
 
         conv = nn.Conv2d(
```
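For reference, a minimal sketch of how the new parameter is intended to be used, judging by the conversion logic above and the `__main__` examples further down in this diff. The import path is assumed from the decoder import shown below (`segmentation_models_pytorch.base.modules`); exact defaults may still change in later commits of this PR.

```python
from segmentation_models_pytorch.base.modules import Conv2dReLU  # assumed import path

# deprecated spelling: still accepted, but now emits a DeprecationWarning
Conv2dReLU(3, 16, kernel_size=3, padding=1, use_batchnorm=True)

# equivalent new spellings, canonicalized by the `use_norm` handling above
Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="batchnorm")
Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm={"type": "batchnorm"})

# extra dict keys are forwarded to the norm layer constructor
Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm={"type": "layernorm", "eps": 1e-3})

# no normalization at all
Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="identity")
```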
```diff
@@ -29,21 +73,30 @@
             kernel_size,
             stride=stride,
             padding=padding,
-            bias=not (use_batchnorm),
+            bias=use_norm["type"] != "inplace",
         )
         relu = nn.ReLU(inplace=True)
 
-        if use_batchnorm == "inplace":
-            bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
-            relu = nn.Identity()
-
-        elif use_batchnorm and use_batchnorm != "inplace":
-            bn = nn.BatchNorm2d(out_channels)
+        norm_type = use_norm["type"]
+        extra_kwargs = {k: v for k, v in use_norm.items() if k != "type"}
+
+        if norm_type == "inplace":
+            norm = InPlaceABN(out_channels, **extra_kwargs)
+            relu = nn.Identity()
+        elif norm_type == "batchnorm":
+            norm = nn.BatchNorm2d(out_channels, **extra_kwargs)
+        elif norm_type == "identity":
+            norm = nn.Identity()
+        elif norm_type == "layernorm":
+            norm = nn.LayerNorm(out_channels, **extra_kwargs)
+        elif norm_type == "groupnorm":
+            norm = nn.GroupNorm(out_channels, **extra_kwargs)
+        elif norm_type == "instancenorm":
+            norm = nn.InstanceNorm2d(out_channels, **extra_kwargs)
         else:
-            bn = nn.Identity()
+            raise ValueError(f"Unrecognized normalization type: {norm_type}")
 
-        super(Conv2dReLU, self).__init__(conv, bn, relu)
+        super(Conv2dReLU, self).__init__(conv, norm, relu)
 
 
 class SCSEModule(nn.Module):
```

Review comment (on `bias=use_norm["type"] != "inplace"`): We can initialize `norm` first and use a separate variable.
Author reply: Yeap, simpler.
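As one possible reading of that suggestion (a sketch with hypothetical names, not the code in this PR), the norm module would be built before the conv, and a separate flag derived from it would drive the conv bias and the activation:

```python
import torch.nn as nn

try:
    from inplace_abn import InPlaceABN  # optional dependency, as this module appears to use
except ImportError:
    InPlaceABN = None


def conv_norm_relu(in_channels, out_channels, kernel_size, norm, stride=1, padding=0):
    """Hypothetical refactor: the norm layer is created first (here simply passed in),
    and the bias flag is derived from it in a separate variable."""
    is_inplace = InPlaceABN is not None and isinstance(norm, InPlaceABN)
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
        bias=not is_inplace,  # same rule as the diff: drop the conv bias only for inplace ABN
    )
    relu = nn.Identity() if is_inplace else nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, relu)
```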
```diff
@@ -127,3 +180,9 @@
 
     def forward(self, x):
         return self.attention(x)
+
+
+if __name__ == "__main__":
+    print(Conv2dReLU(3, 12, 4))
+    print(Conv2dReLU(3, 12, 4, use_norm={"type": "batchnorm"}))
+    print(Conv2dReLU(3, 12, 4, use_norm={"type": "layernorm", "eps": 1e-3}))
```

Review comment (on the `if __name__ == "__main__":` block): Should be removed; instead, it would be nice to add a test.
Author reply: Yeap, added some tests around `Conv2dReLU`.
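The tests themselves are not part of this diff view. Purely as an illustration of what "some tests around Conv2dReLU" could check, a pytest-style sketch follows; the test names, assertions, and import path are assumptions, not the PR's actual test file.

```python
import pytest
import torch

from segmentation_models_pytorch.base.modules import Conv2dReLU  # assumed import path


def test_use_norm_string_and_dict_agree():
    # both spellings should end up with a BatchNorm2d as the middle layer
    a = Conv2dReLU(3, 12, kernel_size=3, padding=1, use_norm="batchnorm")
    b = Conv2dReLU(3, 12, kernel_size=3, padding=1, use_norm={"type": "batchnorm"})
    assert isinstance(a[1], torch.nn.BatchNorm2d)
    assert isinstance(b[1], torch.nn.BatchNorm2d)


def test_use_batchnorm_emits_deprecation_warning():
    with pytest.warns(DeprecationWarning):
        Conv2dReLU(3, 12, kernel_size=3, padding=1, use_batchnorm=True)


def test_forward_shape_is_preserved():
    block = Conv2dReLU(3, 12, kernel_size=3, padding=1, use_norm="batchnorm")
    out = block(torch.rand(2, 3, 16, 16))
    assert out.shape == (2, 12, 16, 16)
```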
Changes to the decoder file (`TransposeX2`, `DecoderBlock`, and the decoder that uses them):

```diff
@@ -1,12 +1,18 @@
 import torch
 import torch.nn as nn
 
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Union
 from segmentation_models_pytorch.base import modules
 
 
 class TransposeX2(nn.Sequential):
-    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        use_batchnorm: Union[bool, str, None] = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = True,
+    ):
         super().__init__()
         layers = [
             nn.ConvTranspose2d(
```

Review comment (on the added `use_norm` parameter): We should have only `use_norm` here.
Author reply: Made a proposition.
```diff
@@ -15,14 +21,20 @@
             nn.ReLU(inplace=True),
         ]
 
-        if use_batchnorm:
+        if use_batchnorm or use_norm:
             layers.insert(1, nn.BatchNorm2d(out_channels))
 
         super().__init__(*layers)
 
 
 class DecoderBlock(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        use_batchnorm: Union[bool, str, None] = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = True,
+    ):
         super().__init__()
 
         self.block = nn.Sequential(
```

Review comment (on the added `use_norm` parameter): Same here, let's move it to the model and leave only `use_norm`.
Author reply: Made a proposition.
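A rough sketch of that direction (hypothetical helper, not this commit's code): the deprecated `use_batchnorm` would be folded into `use_norm` once at the model level, so `TransposeX2` and `DecoderBlock` only ever receive `use_norm`.

```python
import warnings
from typing import Any, Dict, Optional, Union


def resolve_use_norm(
    use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm",
    use_batchnorm: Optional[Union[bool, str]] = None,
) -> Dict[str, Any]:
    """Hypothetical model-level translation of the deprecated argument."""
    if use_batchnorm is not None:
        warnings.warn(
            "use_batchnorm is deprecated, please use use_norm instead",
            DeprecationWarning,
        )
        use_norm = "inplace" if use_batchnorm == "inplace" else bool(use_batchnorm)
    if isinstance(use_norm, bool):
        use_norm = "batchnorm" if use_norm else "identity"
    if isinstance(use_norm, str):
        use_norm = {"type": use_norm}
    return use_norm


# the decoder would then forward a single, already-resolved argument:
# DecoderBlock(in_ch, out_ch, use_norm=resolve_use_norm(use_norm, use_batchnorm))
```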
```diff
@@ -31,6 +43,7 @@
                 in_channels // 4,
                 kernel_size=1,
                 use_batchnorm=use_batchnorm,
+                use_norm=use_norm,
             ),
             TransposeX2(
                 in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm
@@ -40,6 +53,7 @@
                 out_channels,
                 kernel_size=1,
                 use_batchnorm=use_batchnorm,
+                use_norm=use_norm,
             ),
         )
 
```
```diff
@@ -58,7 +72,8 @@
         encoder_channels: List[int],
         prefinal_channels: int = 32,
         n_blocks: int = 5,
-        use_batchnorm: bool = True,
+        use_batchnorm: Union[bool, str, None] = True,
+        use_norm: Union[bool, str, Dict[str, Any]] = True,
     ):
         super().__init__()
 
@@ -71,7 +86,12 @@
 
         self.blocks = nn.ModuleList(
             [
-                DecoderBlock(channels[i], channels[i + 1], use_batchnorm=use_batchnorm)
+                DecoderBlock(
+                    channels[i],
+                    channels[i + 1],
+                    use_batchnorm=use_batchnorm,
+                    use_norm=use_norm,
+                )
                 for i in range(n_blocks)
             ]
         )
```
Review comment (general): Let's have a separate function `get_norm_layer` which will validate the input params and return the norm layer.
Author reply: Good catch. Much simpler. Done the changes.
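For illustration, a sketch of what such a helper could look like, reusing the validation logic already present in the `Conv2dReLU` hunk above. The name `get_norm_layer` comes from the review comment; the exact signature and the group-norm argument handling here are assumptions.

```python
from typing import Any, Dict, Union

import torch.nn as nn

try:
    from inplace_abn import InPlaceABN  # optional dependency
except ImportError:
    InPlaceABN = None


def get_norm_layer(use_norm: Union[bool, str, Dict[str, Any]], out_channels: int) -> nn.Module:
    """Validate `use_norm` and return the corresponding normalization layer."""
    # canonicalize the argument into a {"type": ..., **kwargs} dict
    if isinstance(use_norm, bool):
        use_norm = {"type": "batchnorm" if use_norm else "identity"}
    elif isinstance(use_norm, str):
        use_norm = {"type": use_norm.lower()}
    elif not isinstance(use_norm, dict):
        raise ValueError("use_norm must be a dictionary, boolean, or string")

    norm_type = use_norm["type"]
    kwargs = {k: v for k, v in use_norm.items() if k != "type"}

    if norm_type == "inplace":
        if InPlaceABN is None:
            raise RuntimeError(
                "use_norm='inplace' requires the inplace_abn package: "
                "https://github.com/mapillary/inplace_abn"
            )
        return InPlaceABN(out_channels, **kwargs)
    if norm_type == "batchnorm":
        return nn.BatchNorm2d(out_channels, **kwargs)
    if norm_type == "identity":
        return nn.Identity()
    if norm_type == "layernorm":
        return nn.LayerNorm(out_channels, **kwargs)
    if norm_type == "groupnorm":
        # assumes the caller supplies num_groups, e.g. {"type": "groupnorm", "num_groups": 8}
        return nn.GroupNorm(num_channels=out_channels, **kwargs)
    if norm_type == "instancenorm":
        return nn.InstanceNorm2d(out_channels, **kwargs)
    raise ValueError(f"Unrecognized normalization type: {norm_type}")
```

`Conv2dReLU` and the decoder blocks could then call `norm = get_norm_layer(use_norm, out_channels)` and derive the conv bias from the returned module, which also ties in the earlier "initialize norm first" suggestion.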