Skip to content

Commit 80078c4

Browse files
committed
Add Adafactor and Adahessian optimizers, cleanup optimizer arg passing, add gradient clipping support.
1 parent fcb6258 commit 80078c4

File tree

7 files changed

+406
-63
lines changed

7 files changed

+406
-63
lines changed

Diff for: sotabench_setup.sh

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ pip install -r requirements-sotabench.txt
77
apt-get update
88
apt-get install -y libjpeg-dev zlib1g-dev libpng-dev libwebp-dev
99
pip uninstall -y pillow
10-
CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
10+
CFLAGS="${CFLAGS} -mavx2" pip install -U --no-cache-dir --force-reinstall --no-binary :all:--compile https://github.com/mrT23/pillow-simd/zipball/simd/7.0.x
11+
#CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
1112

1213
# FIXME this shouldn't be needed but sb dataset upload functionality doesn't seem to work
1314
apt-get install wget

Diff for: timm/optim/__init__.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
from .nadam import Nadam
2-
from .rmsprop_tf import RMSpropTF
1+
from .adamp import AdamP
32
from .adamw import AdamW
4-
from .radam import RAdam
3+
from .adafactor import Adafactor
4+
from .adahessian import Adahessian
5+
from .lookahead import Lookahead
6+
from .nadam import Nadam
57
from .novograd import NovoGrad
68
from .nvnovograd import NvNovoGrad
7-
from .lookahead import Lookahead
8-
from .adamp import AdamP
9+
from .radam import RAdam
10+
from .rmsprop_tf import RMSpropTF
911
from .sgdp import SGDP
10-
from .optim_factory import create_optimizer
12+
13+
from .optim_factory import create_optimizer

Diff for: timm/optim/adafactor.py

+174
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
""" Adafactor Optimizer
2+
3+
Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
4+
5+
Original header/copyright below.
6+
7+
"""
8+
# Copyright (c) Facebook, Inc. and its affiliates.
9+
#
10+
# This source code is licensed under the MIT license found in the
11+
# LICENSE file in the root directory of this source tree.
12+
import torch
13+
import math
14+
15+
16+
class Adafactor(torch.optim.Optimizer):
17+
"""Implements Adafactor algorithm.
18+
This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
19+
(see https://arxiv.org/abs/1804.04235)
20+
21+
Note that this optimizer internally adjusts the learning rate depending on the
22+
*scale_parameter*, *relative_step* and *warmup_init* options.
23+
24+
To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
25+
`relative_step=False`.
26+
27+
Arguments:
28+
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
29+
lr (float, optional): external learning rate (default: None)
30+
eps (tuple[float, float]): regularization constants for square gradient
31+
and parameter scale respectively (default: (1e-30, 1e-3))
32+
clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0)
33+
decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8)
34+
beta1 (float): coefficient used for computing running averages of gradient (default: None)
35+
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
36+
scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
37+
relative_step (bool): if True, time-dependent learning rate is computed
38+
instead of external learning rate (default: True)
39+
warmup_init (bool): time-dependent learning rate computation depends on
40+
whether warm-up initialization is being used (default: False)
41+
"""
42+
43+
def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0,
44+
decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False):
45+
relative_step = lr is None
46+
if warmup_init and not relative_step:
47+
raise ValueError('warmup_init requires relative_step=True')
48+
49+
beta1 = None if betas is None else betas[0] # make it compat with standard betas arg
50+
defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate,
51+
beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
52+
relative_step=relative_step, warmup_init=warmup_init)
53+
super(Adafactor, self).__init__(params, defaults)
54+
55+
@staticmethod
56+
def _get_lr(param_group, param_state):
57+
if param_group['relative_step']:
58+
min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
59+
lr_t = min(min_step, 1.0 / math.sqrt(param_state['step']))
60+
param_scale = 1.0
61+
if param_group['scale_parameter']:
62+
param_scale = max(param_group['eps_scale'], param_state['RMS'])
63+
param_group['lr'] = lr_t * param_scale
64+
return param_group['lr']
65+
66+
@staticmethod
67+
def _get_options(param_group, param_shape):
68+
factored = len(param_shape) >= 2
69+
use_first_moment = param_group['beta1'] is not None
70+
return factored, use_first_moment
71+
72+
@staticmethod
73+
def _rms(tensor):
74+
return tensor.norm(2) / (tensor.numel() ** 0.5)
75+
76+
def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
77+
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
78+
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
79+
return torch.mul(r_factor, c_factor)
80+
81+
def step(self, closure=None):
82+
"""Performs a single optimization step.
83+
Arguments:
84+
closure (callable, optional): A closure that reevaluates the model and returns the loss.
85+
"""
86+
loss = None
87+
if closure is not None:
88+
loss = closure()
89+
90+
for group in self.param_groups:
91+
for p in group['params']:
92+
if p.grad is None:
93+
continue
94+
grad = p.grad.data
95+
if grad.dtype in {torch.float16, torch.bfloat16}:
96+
grad = grad.float()
97+
if grad.is_sparse:
98+
raise RuntimeError('Adafactor does not support sparse gradients.')
99+
100+
state = self.state[p]
101+
grad_shape = grad.shape
102+
103+
factored, use_first_moment = self._get_options(group, grad_shape)
104+
# State Initialization
105+
if len(state) == 0:
106+
state['step'] = 0
107+
108+
if use_first_moment:
109+
# Exponential moving average of gradient values
110+
state['exp_avg'] = torch.zeros_like(grad)
111+
if factored:
112+
state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
113+
state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
114+
else:
115+
state['exp_avg_sq'] = torch.zeros_like(grad)
116+
117+
state['RMS'] = 0
118+
else:
119+
if use_first_moment:
120+
state['exp_avg'] = state['exp_avg'].to(grad)
121+
if factored:
122+
state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
123+
state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
124+
else:
125+
state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)
126+
127+
p_data_fp32 = p.data
128+
if p.data.dtype in {torch.float16, torch.bfloat16}:
129+
p_data_fp32 = p_data_fp32.float()
130+
131+
state['step'] += 1
132+
state['RMS'] = self._rms(p_data_fp32)
133+
lr_t = self._get_lr(group, state)
134+
135+
beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
136+
update = grad ** 2 + group['eps']
137+
if factored:
138+
exp_avg_sq_row = state['exp_avg_sq_row']
139+
exp_avg_sq_col = state['exp_avg_sq_col']
140+
141+
exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1))
142+
exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2))
143+
#exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) # pytorch 1.6+
144+
#exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
145+
146+
# Approximation of exponential moving average of square of gradient
147+
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
148+
update.mul_(grad)
149+
else:
150+
exp_avg_sq = state['exp_avg_sq']
151+
152+
exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update)
153+
#exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) # pytorch 1.6+
154+
update = exp_avg_sq.rsqrt().mul_(grad)
155+
156+
update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0))
157+
update.mul_(lr_t)
158+
159+
if use_first_moment:
160+
exp_avg = state['exp_avg']
161+
exp_avg.mul_(group["beta1"]).add_(1 - group["beta1"], update)
162+
#exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) # pytorch 1.6+
163+
update = exp_avg
164+
165+
if group['weight_decay'] != 0:
166+
p_data_fp32.add_(-group["weight_decay"] * lr_t, p_data_fp32)
167+
#p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t) # pytorch 1.6+
168+
169+
p_data_fp32.add_(-update)
170+
171+
if p.data.dtype in {torch.float16, torch.bfloat16}:
172+
p.data.copy_(p_data_fp32)
173+
174+
return loss

Diff for: timm/optim/adahessian.py

+156
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
""" AdaHessian Optimizer
2+
3+
Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py
4+
Originally licensed MIT, Copyright 2020, David Samuel
5+
"""
6+
import torch
7+
8+
9+
class Adahessian(torch.optim.Optimizer):
10+
"""
11+
Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second OrderOptimizer for Machine Learning"
12+
13+
Arguments:
14+
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
15+
lr (float, optional): learning rate (default: 0.1)
16+
betas ((float, float), optional): coefficients used for computing running averages of gradient and the
17+
squared hessian trace (default: (0.9, 0.999))
18+
eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
19+
weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
20+
hessian_power (float, optional): exponent of the hessian trace (default: 1.0)
21+
update_each (int, optional): compute the hessian trace approximation only after *this* number of steps
22+
(to save time) (default: 1)
23+
n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1)
24+
"""
25+
26+
def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0,
27+
hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False):
28+
if not 0.0 <= lr:
29+
raise ValueError(f"Invalid learning rate: {lr}")
30+
if not 0.0 <= eps:
31+
raise ValueError(f"Invalid epsilon value: {eps}")
32+
if not 0.0 <= betas[0] < 1.0:
33+
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
34+
if not 0.0 <= betas[1] < 1.0:
35+
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
36+
if not 0.0 <= hessian_power <= 1.0:
37+
raise ValueError(f"Invalid Hessian power value: {hessian_power}")
38+
39+
self.n_samples = n_samples
40+
self.update_each = update_each
41+
self.avg_conv_kernel = avg_conv_kernel
42+
43+
# use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
44+
self.seed = 2147483647
45+
self.generator = torch.Generator().manual_seed(self.seed)
46+
47+
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power)
48+
super(Adahessian, self).__init__(params, defaults)
49+
50+
for p in self.get_params():
51+
p.hess = 0.0
52+
self.state[p]["hessian step"] = 0
53+
54+
@property
55+
def is_second_order(self):
56+
return True
57+
58+
def get_params(self):
59+
"""
60+
Gets all parameters in all param_groups with gradients
61+
"""
62+
63+
return (p for group in self.param_groups for p in group['params'] if p.requires_grad)
64+
65+
def zero_hessian(self):
66+
"""
67+
Zeros out the accumalated hessian traces.
68+
"""
69+
70+
for p in self.get_params():
71+
if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0:
72+
p.hess.zero_()
73+
74+
@torch.no_grad()
75+
def set_hessian(self):
76+
"""
77+
Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
78+
"""
79+
80+
params = []
81+
for p in filter(lambda p: p.grad is not None, self.get_params()):
82+
if self.state[p]["hessian step"] % self.update_each == 0: # compute the trace only each `update_each` step
83+
params.append(p)
84+
self.state[p]["hessian step"] += 1
85+
86+
if len(params) == 0:
87+
return
88+
89+
if self.generator.device != params[0].device: # hackish way of casting the generator to the right device
90+
self.generator = torch.Generator(params[0].device).manual_seed(self.seed)
91+
92+
grads = [p.grad for p in params]
93+
94+
for i in range(self.n_samples):
95+
# Rademacher distribution {-1.0, 1.0}
96+
zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
97+
h_zs = torch.autograd.grad(
98+
grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1)
99+
for h_z, z, p in zip(h_zs, zs, params):
100+
p.hess += h_z * z / self.n_samples # approximate the expected values of z*(H@z)
101+
102+
@torch.no_grad()
103+
def step(self, closure=None):
104+
"""
105+
Performs a single optimization step.
106+
Arguments:
107+
closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None)
108+
"""
109+
110+
loss = None
111+
if closure is not None:
112+
loss = closure()
113+
114+
self.zero_hessian()
115+
self.set_hessian()
116+
117+
for group in self.param_groups:
118+
for p in group['params']:
119+
if p.grad is None or p.hess is None:
120+
continue
121+
122+
if self.avg_conv_kernel and p.dim() == 4:
123+
p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone()
124+
125+
# Perform correct stepweight decay as in AdamW
126+
p.mul_(1 - group['lr'] * group['weight_decay'])
127+
128+
state = self.state[p]
129+
130+
# State initialization
131+
if len(state) == 1:
132+
state['step'] = 0
133+
# Exponential moving average of gradient values
134+
state['exp_avg'] = torch.zeros_like(p)
135+
# Exponential moving average of Hessian diagonal square values
136+
state['exp_hessian_diag_sq'] = torch.zeros_like(p)
137+
138+
exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
139+
beta1, beta2 = group['betas']
140+
state['step'] += 1
141+
142+
# Decay the first and second moment running average coefficient
143+
exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
144+
exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2)
145+
146+
bias_correction1 = 1 - beta1 ** state['step']
147+
bias_correction2 = 1 - beta2 ** state['step']
148+
149+
k = group['hessian_power']
150+
denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps'])
151+
152+
# make update
153+
step_size = group['lr'] / bias_correction1
154+
p.addcdiv_(exp_avg, denom, value=-step_size)
155+
156+
return loss

0 commit comments

Comments
 (0)