forked from qubvel-org/segmentation_models.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodules.py
178 lines (150 loc) · 5.71 KB
/
modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
from typing import Any, Dict, Union
import torch
import torch.nn as nn
try:
from inplace_abn import InPlaceABN
except ImportError:
InPlaceABN = None
def get_norm_layer(
    use_norm: Union[bool, str, Dict[str, Any]], out_channels: int
) -> nn.Module:
    """Build a normalization layer from a flexible ``use_norm`` specification.

    Args:
        use_norm: One of
            * ``True`` -> ``nn.BatchNorm2d``; ``False`` -> ``nn.Identity``;
            * a string (matched case-insensitively) from ``supported_norms``;
              ``"inplace"`` additionally configures InPlaceABN with
              ``activation="leaky_relu"`` and ``activation_param=0.0``;
            * a dict with a ``"type"`` key (one of ``supported_norms``, matched
              case-sensitively); every other key is forwarded as a keyword
              argument to the layer constructor.
        out_channels: Passed as the first positional argument to the layer
            constructor (not used for ``"identity"``).

    Returns:
        The constructed normalization module.

    Raises:
        ValueError: if ``use_norm`` is not a bool/str/dict, names an unknown
            normalization, or is a dict without a ``"type"`` key.
        RuntimeError: if ``"inplace"`` is requested but the ``inplace_abn``
            package could not be imported.
    """
    supported_norms = ("inplace", "batchnorm", "identity", "layernorm", "instancenorm")

    # Step 1: convert every accepted form of ``use_norm`` into a dict.
    if use_norm is True:
        norm_params = {"type": "batchnorm"}
    elif use_norm is False:
        norm_params = {"type": "identity"}
    elif isinstance(use_norm, str):
        norm_str = use_norm.lower()
        if norm_str == "inplace":
            # InPlaceABN fuses the activation into the norm; slope 0.0 makes
            # its leaky_relu behave like a plain ReLU.
            norm_params = {
                "type": "inplace",
                "activation": "leaky_relu",
                "activation_param": 0.0,
            }
        elif norm_str in supported_norms:
            norm_params = {"type": norm_str}
        else:
            raise ValueError(
                f"Unrecognized normalization type string provided: {use_norm}. Should be in {supported_norms}"
            )
    elif isinstance(use_norm, dict):
        norm_params = use_norm
    else:
        raise ValueError(
            "use_norm must be a dictionary, boolean, or string. Please refer to the documentation."
        )

    # Step 2: validate the dict form (user-supplied dicts reach here unchecked).
    if "type" not in norm_params:
        raise ValueError(
            f"Malformed dictionary given in use_norm: {use_norm}. Should contain key 'type'."
        )
    norm_type = norm_params["type"]
    if norm_type not in supported_norms:
        raise ValueError(
            f"Unrecognized normalization type string provided: {use_norm}. Should be in {supported_norms}"
        )
    if norm_type == "inplace" and InPlaceABN is None:
        raise RuntimeError(
            "In order to use `use_batchnorm='inplace'` or `use_norm='inplace'` the inplace_abn package must be installed. "
            "To install see: https://github.com/mapillary/inplace_abn"
        )

    # Step 3: build the layer; every key except "type" is a constructor kwarg.
    norm_kwargs = {k: v for k, v in norm_params.items() if k != "type"}
    if norm_type == "inplace":
        norm = InPlaceABN(out_channels, **norm_kwargs)
    elif norm_type == "batchnorm":
        norm = nn.BatchNorm2d(out_channels, **norm_kwargs)
    elif norm_type == "identity":
        norm = nn.Identity()
    elif norm_type == "layernorm":
        # NOTE(review): nn.LayerNorm(out_channels) normalizes over the trailing
        # dimension only; on an NCHW feature map that is the width axis, not
        # channels. Confirm callers feed channels-last input here.
        norm = nn.LayerNorm(out_channels, **norm_kwargs)
    elif norm_type == "instancenorm":
        norm = nn.InstanceNorm2d(out_channels, **norm_kwargs)
    else:  # unreachable after validation; kept as a defensive guard
        raise ValueError(f"Unrecognized normalization type: {norm_type}")
    return norm
class Conv2dReLU(nn.Sequential):
    """``nn.Conv2d`` followed by a normalization layer and an activation.

    The normalization layer is built by ``get_norm_layer(use_norm, out_channels)``.
    Submodules are registered positionally as ``0`` (conv), ``1`` (norm)
    and ``2`` (activation).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_norm="batchnorm",
    ):
        norm = get_norm_layer(use_norm, out_channels)

        # Only BatchNorm2d drops the conv bias.
        # NOTE(review): InPlaceABN / InstanceNorm2d also subtract a per-channel
        # mean, which would make the bias redundant. It is kept here so that
        # existing checkpoints (which contain conv bias for those norms) still load.
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not isinstance(norm, nn.BatchNorm2d),
        )

        # InPlaceABN applies its own activation (configured by get_norm_layer).
        # Appending an in-place ReLU would modify its output a second time and
        # override dict-configured activations. The previous name check compared
        # against "Inplace", which never matches the class name "InPlaceABN".
        if InPlaceABN is not None and isinstance(norm, InPlaceABN):
            activation = nn.Identity()
        else:
            activation = nn.ReLU(inplace=True)

        super(Conv2dReLU, self).__init__(conv, norm, activation)
class SCSEModule(nn.Module):
    """Concurrent spatial and channel squeeze-and-excitation (scSE) block.

    The output is the sum of two re-weightings of the input:
    a per-channel gate computed from globally average-pooled features
    (``cSE``), and a per-pixel gate from a 1x1 conv to a single channel
    (``sSE``).

    Args:
        in_channels: Number of channels of the input feature map.
        reduction: Bottleneck divisor for the channel gate's hidden width
            (``in_channels // reduction``).
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        hidden = in_channels // reduction
        channel_gate = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, in_channels, 1),
            nn.Sigmoid(),
        ]
        self.cSE = nn.Sequential(*channel_gate)
        self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())

    def forward(self, x):
        channel_weighted = x * self.cSE(x)
        spatial_weighted = x * self.sSE(x)
        return channel_weighted + spatial_weighted
class ArgMax(nn.Module):
    """Module form of ``argmax`` over a fixed dimension.

    Args:
        dim: Dimension to reduce. ``None`` returns the index into the
            flattened tensor.
    """

    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return x.argmax(dim=self.dim)
class Clamp(nn.Module):
    """Module form of ``clamp``: limits every element to ``[min, max]``.

    Args:
        min: Lower bound (default 0).
        max: Upper bound (default 1).
    """

    def __init__(self, min=0, max=1):
        super().__init__()
        self.min = min
        self.max = max

    def forward(self, x):
        return x.clamp(self.min, self.max)
class Activation(nn.Module):
    """Wrap an activation selected by name (or built by a callable) as a module.

    Args:
        name: ``None`` or ``"identity"``; ``"sigmoid"``; ``"softmax2d"``
            (softmax over ``dim=1``); ``"softmax"``; ``"logsoftmax"``;
            ``"tanh"``; ``"argmax"``; ``"argmax2d"`` (argmax over ``dim=1``);
            ``"clamp"``; or a callable invoked as ``name(**params)`` to build
            the module.
        **params: Forwarded to the selected module's constructor. They are not
            passed for ``"sigmoid"`` and ``"tanh"``.

    Raises:
        ValueError: if ``name`` is none of the above.
    """

    def __init__(self, name, **params):
        super().__init__()
        if name is None or name == "identity":
            self.activation = nn.Identity(**params)
        elif name == "sigmoid":
            self.activation = nn.Sigmoid()
        elif name == "softmax2d":
            # Softmax over the channel axis of an NCHW tensor.
            self.activation = nn.Softmax(dim=1, **params)
        elif name == "softmax":
            self.activation = nn.Softmax(**params)
        elif name == "logsoftmax":
            self.activation = nn.LogSoftmax(**params)
        elif name == "tanh":
            self.activation = nn.Tanh()
        elif name == "argmax":
            self.activation = ArgMax(**params)
        elif name == "argmax2d":
            self.activation = ArgMax(dim=1, **params)
        elif name == "clamp":
            self.activation = Clamp(**params)
        elif callable(name):
            self.activation = name(**params)
        else:
            # List every accepted value (identity and softmax2d were missing).
            raise ValueError(
                f"Activation should be callable/sigmoid/softmax/softmax2d/logsoftmax/tanh/"
                f"argmax/argmax2d/clamp/identity/None; got {name}"
            )

    def forward(self, x):
        return self.activation(x)
class Attention(nn.Module):
    """Choose an attention module by name.

    Args:
        name: ``None`` for a pass-through ``nn.Identity``, or ``"scse"`` for
            ``SCSEModule``.
        **params: Forwarded to the chosen module's constructor.

    Raises:
        ValueError: for any other ``name``.
    """

    def __init__(self, name, **params):
        super().__init__()
        if name is None:
            layer = nn.Identity(**params)
        elif name == "scse":
            layer = SCSEModule(**params)
        else:
            raise ValueError(f"Attention {name} is not implemented")
        self.attention = layer

    def forward(self, x):
        return self.attention(x)