
Commit e861b74

Pass --model-kwargs (and --opt-kwargs for train) from the command line through to the model __init__. Update some models to improve arg overlay. Cleanup along the way.
1 parent add3fb8 commit e861b74
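
The diffs below import a ParseKwargs argparse action from timm.utils; its implementation is not part of this commit. As a rough sketch of the assumed behavior (illustrative only, not the actual timm source): the action collects space-separated key=value tokens into a dict and literal-evals each value so numbers, bools, and tuples arrive typed, ready to be splatted into create_model(...) and on into the model __init__.

import argparse
import ast


class ParseKwargsSketch(argparse.Action):
    # Illustrative stand-in for timm.utils.ParseKwargs (assumed behavior, not the real source).
    def __call__(self, parser, namespace, values, option_string=None):
        kwargs = {}
        for pair in values:
            key, value = pair.split('=', maxsplit=1)
            try:
                kwargs[key] = ast.literal_eval(value)  # 'True' -> bool, '1e-6' -> float, '(3,3,9,3)' -> tuple
            except (ValueError, SyntaxError):
                kwargs[key] = value  # plain names like act_layer=silu stay strings
        setattr(namespace, self.dest, kwargs)


parser = argparse.ArgumentParser()
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargsSketch)
args = parser.parse_args(['--model-kwargs', 'conv_mlp=True', 'ls_init_value=1e-6', 'act_layer=silu'])
print(args.model_kwargs)  # {'conv_mlp': True, 'ls_init_value': 1e-06, 'act_layer': 'silu'}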

File tree

15 files changed: +775, -360 lines


benchmark.py

Lines changed: 17 additions & 9 deletions
@@ -22,7 +22,7 @@
 from timm.layers import set_fast_norm
 from timm.models import create_model, is_model, list_models
 from timm.optim import create_optimizer_v2
-from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry
+from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs

 has_apex = False
 try:
@@ -108,12 +108,15 @@
                     help='Enable gradient checkpointing through model blocks/stages')
 parser.add_argument('--amp', action='store_true', default=False,
                     help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
+parser.add_argument('--amp-dtype', default='float16', type=str,
+                    help='lower precision AMP dtype (default: float16). Overrides --precision arg if args.amp True.')
 parser.add_argument('--precision', default='float32', type=str,
                     help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--fast-norm', default=False, action='store_true',
                     help='enable experimental fast-norm')
+parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)

 # codegen (model compilation) options
 scripting_group = parser.add_mutually_exclusive_group()
@@ -124,7 +127,6 @@
 scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                              help="Enable AOT Autograd optimization.")

-
 # train optimizer parameters
 parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                     help='Optimizer (default: "sgd"')
@@ -168,19 +170,21 @@ def count_params(model: nn.Module):


 def resolve_precision(precision: str):
-    assert precision in ('amp', 'float16', 'bfloat16', 'float32')
-    use_amp = False
+    assert precision in ('amp', 'amp_bfloat16', 'float16', 'bfloat16', 'float32')
+    amp_dtype = None  # amp disabled
     model_dtype = torch.float32
     data_dtype = torch.float32
     if precision == 'amp':
-        use_amp = True
+        amp_dtype = torch.float16
+    elif precision == 'amp_bfloat16':
+        amp_dtype = torch.bfloat16
     elif precision == 'float16':
         model_dtype = torch.float16
         data_dtype = torch.float16
     elif precision == 'bfloat16':
         model_dtype = torch.bfloat16
         data_dtype = torch.bfloat16
-    return use_amp, model_dtype, data_dtype
+    return amp_dtype, model_dtype, data_dtype


 def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
@@ -228,9 +232,12 @@ def __init__(
         self.model_name = model_name
         self.detail = detail
         self.device = device
-        self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
+        self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision)
         self.channels_last = kwargs.pop('channels_last', False)
-        self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=torch.float16) if self.use_amp else suppress
+        if self.amp_dtype is not None:
+            self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=self.amp_dtype)
+        else:
+            self.amp_autocast = suppress

         if fuser:
             set_jit_fuser(fuser)
@@ -243,6 +250,7 @@ def __init__(
             drop_rate=kwargs.pop('drop', 0.),
             drop_path_rate=kwargs.pop('drop_path', None),
             drop_block_rate=kwargs.pop('drop_block', None),
+            **kwargs.pop('model_kwargs', {}),
         )
         self.model.to(
             device=self.device,
@@ -560,7 +568,7 @@ def _try_run(
 def benchmark(args):
     if args.amp:
         _logger.warning("Overriding precision to 'amp' since --amp flag set.")
-        args.precision = 'amp'
+        args.precision = 'amp' if args.amp_dtype == 'float16' else '_'.join(['amp', args.amp_dtype])
     _logger.info(f'Benchmarking in {args.precision} precision. '
                  f'{"NHWC" if args.channels_last else "NCHW"} layout. '
                  f'torchscript {"enabled" if args.torchscript else "disabled"}')
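
A quick restatement of the flag-to-precision mapping introduced above (the helper name is illustrative, not part of the script): plain --amp keeps the existing 'amp' behavior (float16 autocast), while --amp together with --amp-dtype bfloat16 now resolves to the new 'amp_bfloat16' precision, which resolve_precision() maps to a torch.bfloat16 autocast.

def precision_from_flags(amp: bool, amp_dtype: str = 'float16', precision: str = 'float32') -> str:
    # Mirrors the override in benchmark() above.
    if amp:
        precision = 'amp' if amp_dtype == 'float16' else '_'.join(['amp', amp_dtype])
    return precision

assert precision_from_flags(amp=True) == 'amp'
assert precision_from_flags(amp=True, amp_dtype='bfloat16') == 'amp_bfloat16'
assert precision_from_flags(amp=False, precision='bfloat16') == 'bfloat16'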

inference.py

Lines changed: 12 additions & 2 deletions
@@ -20,7 +20,7 @@
 from timm.data import create_dataset, create_loader, resolve_data_config
 from timm.layers import apply_test_time_pool
 from timm.models import create_model
-from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser
+from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs

 try:
     from apex import amp
@@ -72,6 +72,8 @@
                     metavar='N', help='mini-batch size (default: 256)')
 parser.add_argument('--img-size', default=None, type=int,
                     metavar='N', help='Input image dimension, uses model default if empty')
+parser.add_argument('--in-chans', type=int, default=None, metavar='N',
+                    help='Image input channels (default: None => 3)')
 parser.add_argument('--input-size', default=None, nargs=3, type=int,
                     metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
 parser.add_argument('--use-train-size', action='store_true', default=False,
@@ -110,6 +112,7 @@
                     help='lower precision AMP dtype (default: float16)')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
+parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)

 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', default=False, action='store_true',
@@ -170,12 +173,19 @@ def main():
         set_jit_fuser(args.fuser)

     # create model
+    in_chans = 3
+    if args.in_chans is not None:
+        in_chans = args.in_chans
+    elif args.input_size is not None:
+        in_chans = args.input_size[0]
+
     model = create_model(
         args.model,
         num_classes=args.num_classes,
-        in_chans=3,
+        in_chans=in_chans,
         pretrained=args.pretrained,
         checkpoint_path=args.checkpoint,
+        **args.model_kwargs,
     )
     if args.num_classes is None:
         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
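
The input-channel resolution added above follows a simple precedence: an explicit --in-chans wins, otherwise the leading dimension of --input-size is used, and the previous hard-coded 3 remains the fallback. Restated as a standalone helper (the function name is illustrative, not part of the script):

def resolve_in_chans(in_chans_arg=None, input_size_arg=None) -> int:
    # Mirrors the logic in main() above: explicit flag > input-size channel dim > default of 3.
    if in_chans_arg is not None:
        return in_chans_arg
    if input_size_arg is not None:
        return input_size_arg[0]
    return 3

assert resolve_in_chans() == 3
assert resolve_in_chans(input_size_arg=(4, 224, 224)) == 4
assert resolve_in_chans(in_chans_arg=1, input_size_arg=(4, 224, 224)) == 1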

timm/models/byobnet.py

Lines changed: 22 additions & 2 deletions
@@ -218,7 +218,10 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):


 def interleave_blocks(
-        types: Tuple[str, str], d, every: Union[int, List[int]] = 1, first: bool = False, **kwargs
+        types: Tuple[str, str], d,
+        every: Union[int, List[int]] = 1,
+        first: bool = False,
+        **kwargs,
 ) -> Tuple[ByoBlockCfg]:
     """ interleave 2 block types in stack
     """
@@ -1587,15 +1590,32 @@ def __init__(
             in_chans=3,
             global_pool='avg',
             output_stride=32,
-            zero_init_last=True,
             img_size=None,
             drop_rate=0.,
             drop_path_rate=0.,
+            zero_init_last=True,
+            **kwargs,
     ):
+        """
+
+        Args:
+            cfg (ByoModelCfg): Model architecture configuration
+            num_classes (int): Number of classifier classes (default: 1000)
+            in_chans (int): Number of input channels (default: 3)
+            global_pool (str): Global pooling type (default: 'avg')
+            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
+            img_size (Union[int, Tuple[int]): Image size for fixed image size models (i.e. self-attn)
+            drop_rate (float): Dropout rate (default: 0.)
+            drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
+            zero_init_last (bool): Zero-init last weight of residual path
+            kwargs (dict): Extra kwargs overlayed onto cfg
+        """
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.grad_checkpointing = False
+
+        cfg = replace(cfg, **kwargs)  # overlay kwargs onto cfg
         layers = get_layer_fns(cfg)
         if cfg.fixed_input_size:
             assert img_size is not None, 'img_size argument is required for fixed input size model'
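
The cfg = replace(cfg, **kwargs) line is where --model-kwargs values actually land for cfg-driven models like ByobNet (and CspNet below): dataclasses.replace returns a copy of the architecture config with just the named fields overridden, and raises TypeError for unknown field names, so typos in the extra kwargs fail loudly. A minimal sketch with a made-up config (DemoCfg is not a timm class):

from dataclasses import dataclass, replace

@dataclass
class DemoCfg:
    # Hypothetical stand-in for ByoModelCfg; the real config has different fields.
    act_layer: str = 'relu'
    width_factor: float = 1.0

cfg = DemoCfg()
cfg = replace(cfg, act_layer='silu')  # e.g. --model-kwargs act_layer=silu
print(cfg)                            # DemoCfg(act_layer='silu', width_factor=1.0)
try:
    replace(cfg, not_a_field=0)       # unknown fields are rejected by dataclasses.replace
except TypeError as err:
    print(err)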

timm/models/convnext.py

Lines changed: 24 additions & 12 deletions
@@ -167,7 +167,7 @@ def __init__(
                 conv_bias=conv_bias,
                 use_grn=use_grn,
                 act_layer=act_layer,
-                norm_layer=norm_layer if conv_mlp else norm_layer_cl
+                norm_layer=norm_layer if conv_mlp else norm_layer_cl,
             ))
             in_chs = out_chs
         self.blocks = nn.Sequential(*stage_blocks)
@@ -184,16 +184,6 @@ def forward(self, x):
 class ConvNeXt(nn.Module):
     r""" ConvNeXt
     A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
-
-    Args:
-        in_chans (int): Number of input image channels. Default: 3
-        num_classes (int): Number of classes for classification head. Default: 1000
-        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
-        dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
-        drop_rate (float): Head dropout rate
-        drop_path_rate (float): Stochastic depth rate. Default: 0.
-        ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
-        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
     """

     def __init__(
@@ -218,6 +208,28 @@ def __init__(
             drop_rate=0.,
             drop_path_rate=0.,
     ):
+        """
+        Args:
+            in_chans (int): Number of input image channels (default: 3)
+            num_classes (int): Number of classes for classification head (default: 1000)
+            global_pool (str): Global pooling type (default: 'avg')
+            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
+            depths (tuple(int)): Number of blocks at each stage. (default: [3, 3, 9, 3])
+            dims (tuple(int)): Feature dimension at each stage. (default: [96, 192, 384, 768])
+            kernel_sizes (Union[int, List[int]]: Depthwise convolution kernel-sizes for each stage (default: 7)
+            ls_init_value (float): Init value for Layer Scale (default: 1e-6)
+            stem_type (str): Type of stem (default: 'patch')
+            patch_size (int): Stem patch size for patch stem (default: 4)
+            head_init_scale (float): Init scaling value for classifier weights and biases (default: 1)
+            head_norm_first (bool): Apply normalization before global pool + head (default: False)
+            conv_mlp (bool): Use 1x1 conv in MLP, improves speed for small networks w/ chan last (default: False)
+            conv_bias (bool): Use bias layers w/ all convolutions (default: True)
+            use_grn (bool): Use Global Response Norm (ConvNeXt-V2) in MLP (default: False)
+            act_layer (Union[str, nn.Module]): Activation Layer
+            norm_layer (Union[str, nn.Module]): Normalization Layer
+            drop_rate (float): Head dropout rate (default: 0.)
+            drop_path_rate (float): Stochastic depth rate (default: 0.)
+        """
         super().__init__()
         assert output_stride in (8, 16, 32)
         kernel_sizes = to_ntuple(4)(kernel_sizes)
@@ -279,7 +291,7 @@ def __init__(
                 use_grn=use_grn,
                 act_layer=act_layer,
                 norm_layer=norm_layer,
-                norm_layer_cl=norm_layer_cl
+                norm_layer_cl=norm_layer_cl,
             ))
             prev_chs = out_chs
         # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2

timm/models/cspnet.py

Lines changed: 25 additions & 10 deletions
@@ -12,7 +12,7 @@

 Hacked together by / Copyright 2020 Ross Wightman
 """
-from dataclasses import dataclass, asdict
+from dataclasses import dataclass, asdict, replace
 from functools import partial
 from typing import Any, Dict, Optional, Tuple, Union

@@ -518,7 +518,7 @@ def __init__(
             cross_linear=False,
             block_dpr=None,
             block_fn=BottleneckBlock,
-            **block_kwargs
+            **block_kwargs,
     ):
         super(CrossStage, self).__init__()
         first_dilation = first_dilation or dilation
@@ -558,7 +558,7 @@ def __init__(
                 bottle_ratio=bottle_ratio,
                 groups=groups,
                 drop_path=block_dpr[i] if block_dpr is not None else 0.,
-                **block_kwargs
+                **block_kwargs,
             ))
             prev_chs = block_out_chs

@@ -597,7 +597,7 @@ def __init__(
             cross_linear=False,
             block_dpr=None,
             block_fn=BottleneckBlock,
-            **block_kwargs
+            **block_kwargs,
     ):
         super(CrossStage3, self).__init__()
         first_dilation = first_dilation or dilation
@@ -635,7 +635,7 @@ def __init__(
                 bottle_ratio=bottle_ratio,
                 groups=groups,
                 drop_path=block_dpr[i] if block_dpr is not None else 0.,
-                **block_kwargs
+                **block_kwargs,
             ))
             prev_chs = block_out_chs

@@ -668,7 +668,7 @@ def __init__(
             avg_down=False,
             block_fn=BottleneckBlock,
             block_dpr=None,
-            **block_kwargs
+            **block_kwargs,
     ):
         super(DarkStage, self).__init__()
         first_dilation = first_dilation or dilation
@@ -715,7 +715,7 @@ def create_csp_stem(
         padding='',
         act_layer=nn.ReLU,
         norm_layer=nn.BatchNorm2d,
-        aa_layer=None
+        aa_layer=None,
 ):
     stem = nn.Sequential()
     feature_info = []
@@ -738,7 +738,7 @@ def create_csp_stem(
             stride=conv_stride,
             padding=padding if i == 0 else '',
             act_layer=act_layer,
-            norm_layer=norm_layer
+            norm_layer=norm_layer,
         ))
         stem_stride *= conv_stride
         prev_chs = chs
@@ -800,7 +800,7 @@ def create_csp_stages(
     cfg: CspModelCfg,
     drop_path_rate: float,
     output_stride: int,
-    stem_feat: Dict[str, Any]
+    stem_feat: Dict[str, Any],
 ):
     cfg_dict = asdict(cfg.stages)
     num_stages = len(cfg.stages.depth)
@@ -868,12 +868,27 @@ def __init__(
             global_pool='avg',
             drop_rate=0.,
             drop_path_rate=0.,
-            zero_init_last=True
+            zero_init_last=True,
+            **kwargs,
     ):
+        """
+        Args:
+            cfg (CspModelCfg): Model architecture configuration
+            in_chans (int): Number of input channels (default: 3)
+            num_classes (int): Number of classifier classes (default: 1000)
+            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
+            global_pool (str): Global pooling type (default: 'avg')
+            drop_rate (float): Dropout rate (default: 0.)
+            drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
+            zero_init_last (bool): Zero-init last weight of residual path
+            kwargs (dict): Extra kwargs overlayed onto cfg
+        """
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         assert output_stride in (8, 16, 32)
+
+        cfg = replace(cfg, **kwargs)  # overlay kwargs onto cfg
         layer_args = dict(
             act_layer=cfg.act_layer,
             norm_layer=cfg.norm_layer,
