Skip to content

Commit 477a78e

Browse files
committed
Fix optimizer factory regressin for optimizers like sgd/momentum that don't have an eps arg
1 parent 27a93e9 commit 477a78e

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

Diff for: timm/optim/optim_factory.py

+4
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,10 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
6161
opt_split = opt_lower.split('_')
6262
opt_lower = opt_split[-1]
6363
if opt_lower == 'sgd' or opt_lower == 'nesterov':
64+
del opt_args['eps']
6465
optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
6566
elif opt_lower == 'momentum':
67+
del opt_args['eps']
6668
optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
6769
elif opt_lower == 'adam':
6870
optimizer = optim.Adam(parameters, **opt_args)
@@ -93,8 +95,10 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
9395
elif opt_lower == 'nvnovograd':
9496
optimizer = NvNovoGrad(parameters, **opt_args)
9597
elif opt_lower == 'fusedsgd':
98+
del opt_args['eps']
9699
optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
97100
elif opt_lower == 'fusedmomentum':
101+
del opt_args['eps']
98102
optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
99103
elif opt_lower == 'fusedadam':
100104
optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)

0 commit comments

Comments
 (0)