A collection of activation fns and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.

- These activations are not compatible with jit scripting or ONNX export of the model, please use either
- the JIT or basic versions of the activations.
+ These activations are not compatible with jit scripting or ONNX export of the model, please use
+ basic versions of the activations.

Hacked together by / Copyright 2020 Ross Wightman
"""
from torch.nn import functional as F

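A minimal usage sketch, not part of the patch, assuming the `*Me` modules defined below are in scope: they share a constructor signature, so any of them can be swapped in through a single `act_layer` reference, and `inplace` is accepted only for interface compatibility.

import torch
import torch.nn as nn

act_layer = SwishMe  # or MishMe, HardSigmoidMe, HardSwishMe, HardMishMe
block = nn.Sequential(nn.Linear(8, 8), act_layer(inplace=True))  # `inplace` is ignored
out = block(torch.randn(2, 8))
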
- @torch.jit.script
- def swish_jit_fwd(x):
+ def swish_fwd(x):
    return x.mul(torch.sigmoid(x))


- @torch.jit.script
- def swish_jit_bwd(x, grad_output):
+ def swish_bwd(x, grad_output):
    x_sigmoid = torch.sigmoid(x)
    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))


- class SwishJitAutoFn(torch.autograd.Function):
-     """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
+ class SwishAutoFn(torch.autograd.Function):
+     """ optimised Swish w/ memory-efficient checkpoint
    Inspired by conversation btw Jeremy Howard & Adam Paszke
    https://twitter.com/jeremyphoward/status/1188251041835315200
    """
@@ -37,123 +35,117 @@ def symbolic(g, x):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
-         return swish_jit_fwd(x)
+         return swish_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
-         return swish_jit_bwd(x, grad_output)
+         return swish_bwd(x, grad_output)


def swish_me(x, inplace=False):
-     return SwishJitAutoFn.apply(x)
+     return SwishAutoFn.apply(x)


class SwishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(SwishMe, self).__init__()

    def forward(self, x):
-         return SwishJitAutoFn.apply(x)
+         return SwishAutoFn.apply(x)

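A minimal sanity-check sketch, not part of the patch, assuming the renamed SwishAutoFn above is in scope: gradcheck compares the hand-written swish_bwd, which recomputes sigmoid(x) from the saved input rather than storing the activation, against numerical gradients.

import torch

x = torch.randn(8, dtype=torch.double, requires_grad=True)
# Forward saves only x; backward recomputes sigmoid(x), trading compute for memory.
assert torch.autograd.gradcheck(SwishAutoFn.apply, (x,))
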
- @torch.jit.script
- def mish_jit_fwd(x):
+ def mish_fwd(x):
    return x.mul(torch.tanh(F.softplus(x)))


- @torch.jit.script
- def mish_jit_bwd(x, grad_output):
+ def mish_bwd(x, grad_output):
    x_sigmoid = torch.sigmoid(x)
    x_tanh_sp = F.softplus(x).tanh()
    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))


- class MishJitAutoFn(torch.autograd.Function):
+ class MishAutoFn(torch.autograd.Function):
    """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
-     A memory efficient, jit scripted variant of Mish
+     A memory efficient variant of Mish
    """
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
-         return mish_jit_fwd(x)
+         return mish_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
-         return mish_jit_bwd(x, grad_output)
+         return mish_bwd(x, grad_output)


def mish_me(x, inplace=False):
-     return MishJitAutoFn.apply(x)
+     return MishAutoFn.apply(x)


class MishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(MishMe, self).__init__()

    def forward(self, x):
-         return MishJitAutoFn.apply(x)
+         return MishAutoFn.apply(x)

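A minimal sketch, not part of the patch, assuming mish_bwd above is in scope: the hand-derived gradient tanh(softplus(x)) + x * sigmoid(x) * (1 - tanh(softplus(x))**2) should match autograd applied to the reference expression x * tanh(softplus(x)).

import torch
import torch.nn.functional as F

x = torch.randn(16, dtype=torch.double, requires_grad=True)
ref = x * torch.tanh(F.softplus(x))                     # reference Mish
grad_auto, = torch.autograd.grad(ref.sum(), x)
grad_manual = mish_bwd(x.detach(), torch.ones_like(x))  # hand-written backward
assert torch.allclose(grad_auto, grad_manual)
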
- @torch.jit.script
- def hard_sigmoid_jit_fwd(x, inplace: bool = False):
+ def hard_sigmoid_fwd(x, inplace: bool = False):
    return (x + 3).clamp(min=0, max=6).div(6.)


- @torch.jit.script
- def hard_sigmoid_jit_bwd(x, grad_output):
+ def hard_sigmoid_bwd(x, grad_output):
    m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
    return grad_output * m


- class HardSigmoidJitAutoFn(torch.autograd.Function):
+ class HardSigmoidAutoFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
-         return hard_sigmoid_jit_fwd(x)
+         return hard_sigmoid_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
-         return hard_sigmoid_jit_bwd(x, grad_output)
+         return hard_sigmoid_bwd(x, grad_output)


def hard_sigmoid_me(x, inplace: bool = False):
-     return HardSigmoidJitAutoFn.apply(x)
+     return HardSigmoidAutoFn.apply(x)


class HardSigmoidMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSigmoidMe, self).__init__()

    def forward(self, x):
-         return HardSigmoidJitAutoFn.apply(x)
+         return HardSigmoidAutoFn.apply(x)

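A minimal sketch, not part of the patch, assuming hard_sigmoid_fwd above is in scope and a PyTorch build that provides F.hardsigmoid: both compute relu6(x + 3) / 6, so they should agree pointwise.

import torch
import torch.nn.functional as F

x = torch.randn(1000)
assert torch.allclose(hard_sigmoid_fwd(x), F.hardsigmoid(x))
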
- @torch.jit.script
- def hard_swish_jit_fwd(x):
+ def hard_swish_fwd(x):
    return x * (x + 3).clamp(min=0, max=6).div(6.)


- @torch.jit.script
- def hard_swish_jit_bwd(x, grad_output):
+ def hard_swish_bwd(x, grad_output):
    m = torch.ones_like(x) * (x >= 3.)
    m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
    return grad_output * m


- class HardSwishJitAutoFn(torch.autograd.Function):
-     """A memory efficient, jit-scripted HardSwish activation"""
+ class HardSwishAutoFn(torch.autograd.Function):
+     """A memory efficient HardSwish activation"""
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
-         return hard_swish_jit_fwd(x)
+         return hard_swish_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
-         return hard_swish_jit_bwd(x, grad_output)
+         return hard_swish_bwd(x, grad_output)

    @staticmethod
    def symbolic(g, self):
@@ -164,55 +156,53 @@ def symbolic(g, self):


def hard_swish_me(x, inplace=False):
-     return HardSwishJitAutoFn.apply(x)
+     return HardSwishAutoFn.apply(x)


class HardSwishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSwishMe, self).__init__()

    def forward(self, x):
-         return HardSwishJitAutoFn.apply(x)
+         return HardSwishAutoFn.apply(x)

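A minimal sketch, not part of the patch, assuming hard_swish_fwd and hard_swish_bwd above are in scope: away from the kinks at +/-3, autograd through the forward expression reproduces the piecewise gradient coded in hard_swish_bwd (0 for x < -3, x/3 + 1/2 on [-3, 3], 1 for x > 3).

import torch

x = torch.randn(1000, dtype=torch.double, requires_grad=True)
grad_auto, = torch.autograd.grad(hard_swish_fwd(x).sum(), x)
grad_manual = hard_swish_bwd(x.detach(), torch.ones_like(x))
assert torch.allclose(grad_auto, grad_manual)
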
- @torch.jit.script
- def hard_mish_jit_fwd(x):
+ def hard_mish_fwd(x):
    return 0.5 * x * (x + 2).clamp(min=0, max=2)


- @torch.jit.script
- def hard_mish_jit_bwd(x, grad_output):
+ def hard_mish_bwd(x, grad_output):
    m = torch.ones_like(x) * (x >= -2.)
    m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
    return grad_output * m


- class HardMishJitAutoFn(torch.autograd.Function):
-     """ A memory efficient, jit scripted variant of Hard Mish
+ class HardMishAutoFn(torch.autograd.Function):
+     """ A memory efficient variant of Hard Mish
    Experimental, based on notes by Mish author Diganta Misra at
    https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
    """
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
-         return hard_mish_jit_fwd(x)
+         return hard_mish_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
-         return hard_mish_jit_bwd(x, grad_output)
+         return hard_mish_bwd(x, grad_output)


def hard_mish_me(x, inplace: bool = False):
-     return HardMishJitAutoFn.apply(x)
+     return HardMishAutoFn.apply(x)


class HardMishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardMishMe, self).__init__()

    def forward(self, x):
-         return HardMishJitAutoFn.apply(x)
+         return HardMishAutoFn.apply(x)

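A minimal sketch, not part of the patch, assuming hard_mish_fwd and hard_mish_bwd above are in scope: on [-2, 0] the forward is 0.5 * x * (x + 2) = 0.5 * x**2 + x, so the gradient is x + 1, while it is 0 below -2 and 1 above 0; the check below confirms the hand-written backward against autograd at a few sample points.

import torch

x = torch.tensor([-3.0, -1.5, -0.5, 0.5, 2.0], dtype=torch.double, requires_grad=True)
grad_auto, = torch.autograd.grad(hard_mish_fwd(x).sum(), x)
grad_manual = hard_mish_bwd(x.detach(), torch.ones_like(x))
assert torch.allclose(grad_auto, grad_manual)  # expected: [0.0, -0.5, 0.5, 1.0, 1.0]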