
Commit e8e2d9c

Add DropPath (stochastic depth) to ReXNet and VoVNet. RegNet DropPath impl tweak and dedupe se args.
1 parent e8ca458 commit e8e2d9c

File tree

timm/models/regnet.py
timm/models/rexnet.py
timm/models/vovnet.py

3 files changed: +47 -28 lines changed


Diff for: timm/models/regnet.py (+6 -3)

@@ -195,15 +195,18 @@ class RegStage(nn.Module):
     """Stage (sequence of blocks w/ the same output shape)."""

     def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group_width,
-                 block_fn=Bottleneck, se_ratio=0., drop_path_rate=None, drop_block=None):
+                 block_fn=Bottleneck, se_ratio=0., drop_path_rates=None, drop_block=None):
         super(RegStage, self).__init__()
         block_kwargs = {}  # FIXME setup to pass various aa, norm, act layer common args
         first_dilation = 1 if dilation in (1, 2) else 2
         for i in range(depth):
             block_stride = stride if i == 0 else 1
             block_in_chs = in_chs if i == 0 else out_chs
             block_dilation = first_dilation if i == 0 else dilation
-            drop_path = DropPath(drop_path_rate[i]) if drop_path_rate is not None else None
+            if drop_path_rates is not None and drop_path_rates[i] > 0.:
+                drop_path = DropPath(drop_path_rates[i])
+            else:
+                drop_path = None
             if (block_in_chs != out_chs) or (block_stride != 1):
                 proj_block = downsample_conv(block_in_chs, out_chs, 1, block_stride, block_dilation)
             else:

@@ -301,7 +304,7 @@ def _get_stage_params(self, cfg, default_stride=2, output_stride=32, drop_path_r

         # Adjust the compatibility of ws and gws
         stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups)
-        param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width', 'drop_path_rate']
+        param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width', 'drop_path_rates']
         stage_params = [
             dict(zip(param_names, params)) for params in
             zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups,
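
For reference, DropPath implements per-sample stochastic depth: during training it zeroes the residual-branch output for a random subset of samples and rescales the survivors, so a block is effectively skipped with probability equal to its drop rate. A minimal sketch of the idea (this is not timm's exact DropPath code, and the class name DropPathSketch is ours):

import torch
import torch.nn as nn

class DropPathSketch(nn.Module):
    """Per-sample stochastic depth sketch: zero the branch output for
    some samples in the batch and rescale survivors by 1 / keep_prob."""

    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x  # identity at inference time or with a zero rate
        keep_prob = 1. - self.drop_prob
        # one Bernoulli draw per sample, broadcast over C, H, W
        mask = (torch.rand(x.shape[0], 1, 1, 1, device=x.device) < keep_prob).to(x.dtype)
        return x * mask / keep_prob

Rescaling by 1 / keep_prob keeps the expected activation unchanged, which is why the eval path can be a plain identity. The `drop_path_rates is not None and drop_path_rates[i] > 0.` guard above just avoids instantiating a no-op module for zero-rate blocks.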

Diff for: timm/models/rexnet.py (+23 -18)

@@ -15,7 +15,7 @@
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import ClassifierHead, create_act_layer, ConvBnAct
+from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath
 from .registry import register_model
 
 

@@ -56,10 +56,10 @@ def make_divisible(v, divisor=8, min_value=None):
 
 class SEWithNorm(nn.Module):
 
-    def __init__(self, channels, reduction=16, act_layer=nn.ReLU, divisor=1, reduction_channels=None,
+    def __init__(self, channels, se_ratio=1 / 12., act_layer=nn.ReLU, divisor=1, reduction_channels=None,
                  gate_layer='sigmoid'):
         super(SEWithNorm, self).__init__()
-        reduction_channels = reduction_channels or make_divisible(channels // reduction, divisor=divisor)
+        reduction_channels = reduction_channels or make_divisible(int(channels * se_ratio), divisor=divisor)
         self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
         self.bn = nn.BatchNorm2d(reduction_channels)
         self.act = act_layer(inplace=True)

@@ -76,7 +76,7 @@ def forward(self, x):
 
 
 class LinearBottleneck(nn.Module):
-    def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, use_se=True, se_rd=12, ch_div=1):
+    def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1, drop_path=None):
         super(LinearBottleneck, self).__init__()
         self.use_shortcut = stride == 1 and in_chs <= out_chs
         self.in_channels = in_chs

@@ -90,10 +90,11 @@ def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, use_se=True, se_rd=12
             self.conv_exp = None
 
         self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False)
-        self.se = SEWithNorm(dw_chs, reduction=se_rd, divisor=ch_div) if use_se else None
+        self.se = SEWithNorm(dw_chs, se_ratio=se_ratio, divisor=ch_div) if se_ratio > 0. else None
         self.act_dw = nn.ReLU6()
 
         self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False)
+        self.drop_path = drop_path
 
     def feat_channels(self, exp=False):
         return self.conv_dw.out_channels if exp else self.out_channels

@@ -107,12 +108,14 @@ def forward(self, x):
             x = self.se(x)
         x = self.act_dw(x)
         x = self.conv_pwl(x)
+        if self.drop_path is not None:
+            x = self.drop_path(x)
         if self.use_shortcut:
             x[:, 0:self.in_channels] += shortcut
         return x
 
 
-def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, use_se=True, ch_div=1):
+def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, se_ratio=0., ch_div=1):
     layers = [1, 2, 2, 3, 3, 5]
     strides = [1, 2, 2, 2, 1, 2]
     layers = [ceil(element * depth_mult) for element in layers]

@@ -127,29 +130,31 @@ def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, us
         out_chs_list.append(make_divisible(round(base_chs * width_mult), divisor=ch_div))
         base_chs += final_chs / (depth // 3 * 1.0)
 
-    if use_se:
-        use_ses = [False] * (layers[0] + layers[1]) + [True] * sum(layers[2:])
-    else:
-        use_ses = [False] * sum(layers[:])
+    se_ratios = [0.] * (layers[0] + layers[1]) + [se_ratio] * sum(layers[2:])
 
-    return zip(out_chs_list, exp_ratios, strides, use_ses)
+    return list(zip(out_chs_list, exp_ratios, strides, se_ratios))
 
 
-def _build_blocks(block_cfg, prev_chs, width_mult, se_rd=12, ch_div=1, feature_location='bottleneck'):
+def _build_blocks(
+        block_cfg, prev_chs, width_mult, ch_div=1, drop_path_rate=0., feature_location='bottleneck'):
     feat_exp = feature_location == 'expansion'
     feat_chs = [prev_chs]
     feature_info = []
     curr_stride = 2
     features = []
-    for block_idx, (chs, exp_ratio, stride, se) in enumerate(block_cfg):
+    num_blocks = len(block_cfg)
+    for block_idx, (chs, exp_ratio, stride, se_ratio) in enumerate(block_cfg):
         if stride > 1:
             fname = 'stem' if block_idx == 0 else f'features.{block_idx - 1}'
             if block_idx > 0 and feat_exp:
                 fname += '.act_dw'
             feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=fname)]
             curr_stride *= stride
+        block_dpr = drop_path_rate * block_idx / (num_blocks - 1)  # stochastic depth linear decay rule
+        drop_path = DropPath(block_dpr) if block_dpr > 0. else None
         features.append(LinearBottleneck(
-            in_chs=prev_chs, out_chs=chs, exp_ratio=exp_ratio, stride=stride, use_se=se, se_rd=se_rd, ch_div=ch_div))
+            in_chs=prev_chs, out_chs=chs, exp_ratio=exp_ratio, stride=stride, se_ratio=se_ratio,
+            ch_div=ch_div, drop_path=drop_path))
         prev_chs = chs
         feat_chs += [features[-1].feat_channels(feat_exp)]
     pen_chs = make_divisible(1280 * width_mult, divisor=ch_div)

@@ -162,8 +167,8 @@ def _build_blocks(block_cfg, prev_chs, width_mult, se_rd=12, ch_div=1, feature_l
 
 class ReXNetV1(nn.Module):
     def __init__(self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32,
-                 initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, use_se=True,
-                 se_rd=12, ch_div=1, drop_rate=0.2, feature_location='bottleneck'):
+                 initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, se_ratio=1/12.,
+                 ch_div=1, drop_rate=0.2, drop_path_rate=0., feature_location='bottleneck'):
         super(ReXNetV1, self).__init__()
         self.drop_rate = drop_rate
         self.num_classes = num_classes

@@ -173,9 +178,9 @@ def __init__(self, in_chans=3, num_classes=1000, global_pool='avg', output_strid
         stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div)
         self.stem = ConvBnAct(in_chans, stem_chs, 3, stride=2, act_layer='swish')
 
-        block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, use_se, ch_div)
+        block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, se_ratio, ch_div)
         features, self.feature_info = _build_blocks(
-            block_cfg, stem_chs, width_mult, se_rd, ch_div, feature_location)
+            block_cfg, stem_chs, width_mult, ch_div, drop_path_rate, feature_location)
         self.num_features = features[-1].out_channels
         self.features = nn.Sequential(*features)
 
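The `stochastic depth linear decay rule` comment in `_build_blocks` is the per-block schedule from the stochastic depth paper: the first block keeps a zero drop rate and the final block gets the full `drop_path_rate`, so early blocks (which feed everything downstream) are dropped least. A quick check of the values (the rate 0.1 here is just an example; the new `drop_path_rate` argument defaults to 0., and 16 is the block count implied by the default layer config [1, 2, 2, 3, 3, 5] in the diff above):

drop_path_rate = 0.1  # example value for illustration
num_blocks = 16       # sum of the default ReXNet layer counts [1, 2, 2, 3, 3, 5]

# block_dpr = drop_path_rate * block_idx / (num_blocks - 1), as in _build_blocks
block_dprs = [drop_path_rate * i / (num_blocks - 1) for i in range(num_blocks)]
print([round(r, 3) for r in block_dprs[:4]], '...', round(block_dprs[-1], 3))
# [0.0, 0.007, 0.013, 0.02] ... 0.1

Since block 0 always gets a rate of 0., the `DropPath(block_dpr) if block_dpr > 0. else None` line leaves the first block without a drop-path module entirely.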

Diff for: timm/models/vovnet.py (+18 -7)

@@ -20,7 +20,7 @@
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .registry import register_model
 from .helpers import build_model_with_cfg
-from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, ClassifierHead, \
+from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, ClassifierHead, DropPath,\
     create_attn, create_norm_act, get_norm_act_layer
 
 

@@ -179,7 +179,7 @@ def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Ten
 class OsaBlock(nn.Module):
 
     def __init__(self, in_chs, mid_chs, out_chs, layer_per_block, residual=False,
-                 depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU):
+                 depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path=None):
         super(OsaBlock, self).__init__()
 
         self.residual = residual

@@ -212,6 +212,8 @@ def __init__(self, in_chs, mid_chs, out_chs, layer_per_block, residual=False,
         else:
             self.attn = None
 
+        self.drop_path = drop_path
+
     def forward(self, x):
         output = [x]
         if self.conv_reduction is not None:

@@ -220,6 +222,8 @@ def forward(self, x):
         x = self.conv_concat(x)
         if self.attn is not None:
             x = self.attn(x)
+        if self.drop_path is not None:
+            x = self.drop_path(x)
         if self.residual:
             x = x + output[0]
         return x

@@ -228,7 +232,8 @@ def forward(self, x):
 class OsaStage(nn.Module):
 
     def __init__(self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, downsample=True,
-                 residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU):
+                 residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU,
+                 drop_path_rates=None):
         super(OsaStage, self).__init__()
 
         if downsample:

@@ -239,10 +244,15 @@ def __init__(self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, d
         blocks = []
         for i in range(block_per_stage):
             last_block = i == block_per_stage - 1
+            if drop_path_rates is not None and drop_path_rates[i] > 0.:
+                drop_path = DropPath(drop_path_rates[i])
+            else:
+                drop_path = None
             blocks += [OsaBlock(
-                in_chs if i == 0 else out_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0,
-                depthwise=depthwise, attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer)
+                in_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0, depthwise=depthwise,
+                attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer, drop_path=drop_path)
             ]
+            in_chs = out_chs
         self.blocks = nn.Sequential(*blocks)
 
     def forward(self, x):

@@ -255,7 +265,7 @@ def forward(self, x):
 class VovNet(nn.Module):
 
     def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4,
-                 output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU):
+                 output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path_rate=0.):
         """ VovNet (v2)
         """
         super(VovNet, self).__init__()

@@ -284,14 +294,15 @@ def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_ra
         current_stride = stem_stride
 
         # OSA stages
+        stage_dpr = torch.split(torch.linspace(0, drop_path_rate, sum(block_per_stage)), block_per_stage)
         in_ch_list = stem_chs[-1:] + stage_out_chs[:-1]
         stage_args = dict(residual=cfg["residual"], depthwise=cfg["depthwise"], attn=cfg["attn"], **conv_kwargs)
         stages = []
         for i in range(4):  # num_stages
             downsample = stem_stride == 2 or i > 0  # first stage has no stride/downsample if stem_stride is 4
             stages += [OsaStage(
                 in_ch_list[i], stage_conv_chs[i], stage_out_chs[i], block_per_stage[i], layer_per_block,
-                downsample=downsample, **stage_args)
+                downsample=downsample, drop_path_rates=stage_dpr[i], **stage_args)
             ]
             self.num_features = stage_out_chs[i]
             current_stride *= 2 if downsample else 1
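
The VovNet side builds the same linear ramp once over all blocks and then carves it into per-stage tensors with `torch.split`, so each `OsaStage` receives only its own slice of rates via `drop_path_rates`. A small check of what `stage_dpr` looks like (assumptions: `drop_path_rate=0.1` is an example value since the argument defaults to 0., and `[1, 1, 2, 2]` is a vovnet39-style `block_per_stage` layout):

import torch

drop_path_rate = 0.1            # example value for illustration
block_per_stage = [1, 1, 2, 2]  # assumed vovnet39-style stage layout

# one global linear ramp over all blocks, split into per-stage chunks
stage_dpr = torch.split(torch.linspace(0, drop_path_rate, sum(block_per_stage)), block_per_stage)
for i, rates in enumerate(stage_dpr):
    print(f'stage {i}:', [round(r.item(), 2) for r in rates])
# stage 0: [0.0]
# stage 1: [0.02]
# stage 2: [0.04, 0.06]
# stage 3: [0.08, 0.1]

Per-stage splitting keeps the global depth-wise schedule intact while letting `OsaStage` stay ignorant of where it sits in the network; it just indexes `drop_path_rates[i]` for each of its blocks.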
