""" Adafactor Optimizer

Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py

Original header/copyright below.

"""
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import math


class Adafactor(torch.optim.Optimizer):
    """Implements Adafactor algorithm.
    This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
    (see https://arxiv.org/abs/1804.04235)

    Note that this optimizer internally adjusts the learning rate depending on the
    *scale_parameter* and *warmup_init* options. The relative step size is used
    automatically whenever no external `lr` is given.

    To use a manual (external) learning rate schedule, pass a value for `lr` (which disables
    the relative step size) and set `scale_parameter=False`. See the usage sketch at the end
    of this file.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): external learning rate (default: None)
        eps (float): regularization constant for the square gradient (default: 1e-30)
        eps_scale (float): regularization constant for the parameter scale (default: 1e-3)
        clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0)
        decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8)
        betas (tuple[float, float], optional): only the first value is used, as the coefficient for
            computing running averages of the gradient; kept for compatibility with the standard
            `betas` argument (default: None)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
        warmup_init (bool): time-dependent learning rate computation depends on
            whether warm-up initialization is being used; requires `lr=None` (default: False)
    """

    def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0,
                 decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False):
        relative_step = lr is None
        if warmup_init and not relative_step:
            raise ValueError('warmup_init requires relative_step=True')

        beta1 = None if betas is None else betas[0]  # make it compat with standard betas arg
        defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate,
                        beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
                        relative_step=relative_step, warmup_init=warmup_init)
        super(Adafactor, self).__init__(params, defaults)

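    # Learning rate helper. In relative-step mode the step size follows the Adafactor paper:
    # lr_t = min(min_step, 1/sqrt(t)), where min_step is 1e-6*t with warmup_init else 1e-2,
    # optionally scaled by the parameter's root mean square (floored at eps_scale).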
    @staticmethod
    def _get_lr(param_group, param_state):
        if param_group['relative_step']:
            min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
            lr_t = min(min_step, 1.0 / math.sqrt(param_state['step']))
            param_scale = 1.0
            if param_group['scale_parameter']:
                param_scale = max(param_group['eps_scale'], param_state['RMS'])
            param_group['lr'] = lr_t * param_scale
        return param_group['lr']

    @staticmethod
    def _get_options(param_group, param_shape):
        factored = len(param_shape) >= 2
        use_first_moment = param_group['beta1'] is not None
        return factored, use_first_moment

    @staticmethod
    def _rms(tensor):
        return tensor.norm(2) / (tensor.numel() ** 0.5)

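    # Factored second-moment approximation: rebuilds a per-element statistic from the row and
    # column running averages, V_ij ~= row_i * col_j / mean(row), and returns its inverse
    # square root (the factor that is multiplied into the gradient).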
    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError('Adafactor does not support sparse gradients.')

                state = self.state[p]
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)
                # State Initialization
                if len(state) == 0:
                    state['step'] = 0

                    if use_first_moment:
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(grad)
                    if factored:
                        state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
                        state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
                    else:
                        state['exp_avg_sq'] = torch.zeros_like(grad)

                    state['RMS'] = 0
                else:
                    if use_first_moment:
                        state['exp_avg'] = state['exp_avg'].to(grad)
                    if factored:
                        state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
                        state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
                    else:
                        state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)

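                # Keep a float32 working copy for half-precision parameters; the result is copied
                # back to the original dtype at the end of the step.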
                p_data_fp32 = p.data
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state['step'] += 1
                state['RMS'] = self._rms(p_data_fp32)
                lr_t = self._get_lr(group, state)

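                # Second-moment decay schedule from the paper: beta2_t = 1 - t^decay_rate
                # (decay_rate defaults to -0.8, so beta2_t approaches 1 as training progresses).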
                beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
                update = grad ** 2 + group['eps']
                if factored:
                    exp_avg_sq_row = state['exp_avg_sq_row']
                    exp_avg_sq_col = state['exp_avg_sq_col']

                    exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1))
                    exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2))
                    #exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t)  # pytorch 1.6+
                    #exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)

                    # Approximation of exponential moving average of square of gradient
                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state['exp_avg_sq']

                    exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update)
                    #exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)  # pytorch 1.6+
                    update = exp_avg_sq.rsqrt().mul_(grad)

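                # Update clipping: rescale so the root mean square of the update does not exceed
                # clip_threshold, then apply the step size lr_t.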
                update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0))
                update.mul_(lr_t)

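                # Optional first moment (momentum) smoothing of the update when a betas tuple was given.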
                if use_first_moment:
                    exp_avg = state['exp_avg']
                    exp_avg.mul_(group["beta1"]).add_(1 - group["beta1"], update)
                    #exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1'])  # pytorch 1.6+
                    update = exp_avg

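                # L2-style weight decay, applied to the parameter and scaled by the same adaptive lr_t.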
                if group['weight_decay'] != 0:
                    p_data_fp32.add_(-group["weight_decay"] * lr_t, p_data_fp32)
                    #p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t)  # pytorch 1.6+

                p_data_fp32.add_(-update)

                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)

        return loss
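

# Minimal usage sketch (illustrative only, not part of the original fairseq file). It shows the
# two modes described in the class docstring: the default relative-step mode (lr=None) and a
# manual external learning rate. The toy model and data below are placeholders.
if __name__ == '__main__':
    model = torch.nn.Linear(8, 4)

    # Default: time-dependent relative step size, optionally scaled by parameter RMS.
    optimizer = Adafactor(model.parameters(), lr=None, scale_parameter=True, warmup_init=False)

    # Manual (external) learning rate instead:
    # optimizer = Adafactor(model.parameters(), lr=1e-3, scale_parameter=False)

    x, y = torch.randn(16, 8), torch.randn(16, 4)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()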