Skip to content

Commit 0654eb3

Browse files
committed
Add model-based weight-decay skip support. Improve cross-version compatibility of optimizer factory. Fix huggingface#247
1 parent 9ab8aea commit 0654eb3

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

timm/optim/optim_factory.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,10 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
4141
opt_lower = args.opt.lower()
4242
weight_decay = args.weight_decay
4343
if weight_decay and filter_bias_and_bn:
44-
parameters = add_weight_decay(model, weight_decay)
44+
skip = {}
45+
if hasattr(model, 'no_weight_decay'):
46+
skip = model.no_weight_decay
47+
parameters = add_weight_decay(model, weight_decay, skip)
4548
weight_decay = 0.
4649
else:
4750
parameters = model.parameters()
@@ -50,9 +53,9 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
5053
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
5154

5255
opt_args = dict(lr=args.lr, weight_decay=weight_decay)
53-
if args.opt_eps is not None:
56+
if hasattr(args, 'opt_eps') and args.opt_eps is not None:
5457
opt_args['eps'] = args.opt_eps
55-
if args.opt_betas is not None:
58+
if hasattr(args, 'opt_betas') and args.opt_betas is not None:
5659
opt_args['betas'] = args.opt_betas
5760

5861
opt_split = opt_lower.split('_')

0 commit comments

Comments
 (0)