 
 Hacked together by / Copyright 2020 Ross Wightman
 """
+from functools import partial
+
 from torch import nn as nn
 
+from .grn import GlobalResponseNorm
 from .helpers import to_2tuple
 
 
 class Mlp(nn.Module):
     """ MLP as used in Vision Transformer, MLP-Mixer and related networks
     """
-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
+    def __init__(
+            self,
+            in_features,
+            hidden_features=None,
+            out_features=None,
+            act_layer=nn.GELU,
+            bias=True,
+            drop=0.,
+            use_conv=False,
+    ):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
         bias = to_2tuple(bias)
         drop_probs = to_2tuple(drop)
+        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
 
-        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
+        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
         self.act = act_layer()
         self.drop1 = nn.Dropout(drop_probs[0])
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
+        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
         self.drop2 = nn.Dropout(drop_probs[1])
 
     def forward(self, x):
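
The new `use_conv` flag swaps `nn.Linear` for a 1x1 `nn.Conv2d` via `partial(nn.Conv2d, kernel_size=1)`, so the same `Mlp` module can run on channels-last token tensors or NCHW feature maps. A minimal usage sketch; the `timm.layers` import path and the tensor shapes are assumptions for illustration:

```python
import torch
from timm.layers import Mlp  # import path assumed

# Default path: nn.Linear over the last dim of (B, N, C) tokens.
tokens = torch.randn(2, 196, 768)
mlp = Mlp(768, hidden_features=3072)
print(mlp(tokens).shape)       # torch.Size([2, 196, 768])

# use_conv=True routes through 1x1 convs, so input is an NCHW feature map.
fmap = torch.randn(2, 768, 14, 14)
mlp_conv = Mlp(768, hidden_features=3072, use_conv=True)
print(mlp_conv(fmap).shape)    # torch.Size([2, 768, 14, 14])
```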
@@ -36,18 +49,29 @@ class GluMlp(nn.Module):
     """ MLP w/ GLU style gating
     See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
     """
-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, bias=True, drop=0.):
+    def __init__(
+            self,
+            in_features,
+            hidden_features=None,
+            out_features=None,
+            act_layer=nn.Sigmoid,
+            bias=True,
+            drop=0.,
+            use_conv=False,
+    ):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
         assert hidden_features % 2 == 0
         bias = to_2tuple(bias)
         drop_probs = to_2tuple(drop)
+        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+        self.chunk_dim = 1 if use_conv else -1
 
-        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
+        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
         self.act = act_layer()
         self.drop1 = nn.Dropout(drop_probs[0])
-        self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias[1])
+        self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1])
         self.drop2 = nn.Dropout(drop_probs[1])
 
     def init_weights(self):
@@ -58,7 +82,7 @@ def init_weights(self):
 
     def forward(self, x):
         x = self.fc1(x)
-        x, gates = x.chunk(2, dim=-1)
+        x, gates = x.chunk(2, dim=self.chunk_dim)
         x = x * self.act(gates)
         x = self.drop1(x)
         x = self.fc2(x)
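
In `GluMlp`, `fc1` projects to `hidden_features` and the result is split in half into value and gate along `self.chunk_dim` (dim 1 for the conv/NCHW path, last dim otherwise), which is why `fc2` takes `hidden_features // 2` inputs. A short sketch under the same assumed import path:

```python
import torch
from timm.layers import GluMlp  # import path assumed

# hidden_features must be even: half feeds fc2, half becomes the sigmoid gate.
glu = GluMlp(256, hidden_features=512)
x = torch.randn(2, 49, 256)
print(glu(x).shape)         # torch.Size([2, 49, 256])
print(glu.fc2.in_features)  # 256 == hidden_features // 2

# With use_conv=True the chunk happens along dim 1 (channels of an NCHW map).
glu_conv = GluMlp(256, hidden_features=512, use_conv=True)
print(glu_conv(torch.randn(2, 256, 7, 7)).shape)  # torch.Size([2, 256, 7, 7])
```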
@@ -70,8 +94,15 @@ class GatedMlp(nn.Module):
     """ MLP as used in gMLP
     """
     def __init__(
-            self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
-            gate_layer=None, bias=True, drop=0.):
+            self,
+            in_features,
+            hidden_features=None,
+            out_features=None,
+            act_layer=nn.GELU,
+            gate_layer=None,
+            bias=True,
+            drop=0.,
+    ):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
@@ -104,8 +135,15 @@ class ConvMlp(nn.Module):
     """ MLP using 1x1 convs that keeps spatial dims
     """
     def __init__(
-            self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU,
-            norm_layer=None, bias=True, drop=0.):
+            self,
+            in_features,
+            hidden_features=None,
+            out_features=None,
+            act_layer=nn.ReLU,
+            norm_layer=None,
+            bias=True,
+            drop=0.,
+    ):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
@@ -124,3 +162,40 @@ def forward(self, x):
         x = self.drop(x)
         x = self.fc2(x)
         return x
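
`GatedMlp` and `ConvMlp` only have their signatures reflowed here; their behavior is unchanged. For reference, `ConvMlp` always works on NCHW maps via 1x1 convs, with an optional norm layer between `fc1` and the activation. A small sketch, import path assumed:

```python
import torch
from torch import nn
from timm.layers import ConvMlp  # import path assumed

# 1x1-conv MLP that preserves spatial dims, with an optional norm_layer.
conv_mlp = ConvMlp(64, hidden_features=256, norm_layer=nn.BatchNorm2d)
print(conv_mlp(torch.randn(2, 64, 28, 28)).shape)  # torch.Size([2, 64, 28, 28])
```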
+
+
+class GlobalResponseNormMlp(nn.Module):
+    """ MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d
+    """
+    def __init__(
+            self,
+            in_features,
+            hidden_features=None,
+            out_features=None,
+            act_layer=nn.GELU,
+            bias=True,
+            drop=0.,
+            use_conv=False,
+    ):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        bias = to_2tuple(bias)
+        drop_probs = to_2tuple(drop)
+        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+        self.act = act_layer()
+        self.drop1 = nn.Dropout(drop_probs[0])
+        self.grn = GlobalResponseNorm(hidden_features, channels_last=not use_conv)
+        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+        self.drop2 = nn.Dropout(drop_probs[1])
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop1(x)
+        x = self.grn(x)
+        x = self.fc2(x)
+        x = self.drop2(x)
+        return x
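
The new `GlobalResponseNormMlp` mirrors `Mlp` but inserts `GlobalResponseNorm` between the activation and `fc2`; `channels_last=not use_conv` keeps the normalization axes consistent with the layout the chosen layer type expects. A hedged sketch, assuming the class is exported alongside the other Mlp variants and using illustrative shapes:

```python
import torch
from timm.layers import GlobalResponseNormMlp  # import path assumed

# Channels-last NHWC input uses the nn.Linear path and GRN with channels_last=True.
grn_mlp = GlobalResponseNormMlp(96, hidden_features=384)
print(grn_mlp(torch.randn(2, 56, 56, 96)).shape)       # torch.Size([2, 56, 56, 96])

# NCHW input uses 1x1 convs and GRN with channels_last=False.
grn_mlp_conv = GlobalResponseNormMlp(96, hidden_features=384, use_conv=True)
print(grn_mlp_conv(torch.randn(2, 96, 56, 56)).shape)  # torch.Size([2, 96, 56, 56])
```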