@@ -136,48 +136,32 @@ def _update_lr_and_momentum_in_metrics_dict(
136
136
momentum_list : List [Union [float , None ]],
137
137
):
138
138
"""Update learning rate and momentum in metrics_dict
139
+ (consider only 0-th param group)
139
140
140
141
Args:
141
142
metrics_dict: batch_metrics or epoch_metrics
142
143
lr_list: lr for each param group
143
144
momentum_list: momentum for each param group
144
145
145
146
"""
146
- if len (lr_list ) == 1 :
147
- lr = lr_list [0 ]
148
- momentum = momentum_list [0 ]
147
+ # todo: consider saving lr and momentum for all param groups ?
148
+ lr = lr_list [0 ]
149
+ momentum = momentum_list [0 ]
150
+
151
+ lr_key = (
152
+ f"lr/{ self .scheduler_key } "
153
+ if self .scheduler_key is not None
154
+ else "lr"
155
+ )
156
+ metrics_dict [lr_key ] = lr
149
157
150
- lr_key = (
151
- f"lr/{ self .scheduler_key } "
158
+ if momentum is not None :
159
+ momentum_key = (
160
+ f"momentum/{ self .scheduler_key } "
152
161
if self .scheduler_key is not None
153
- else "lr "
162
+ else "momentum "
154
163
)
155
- metrics_dict [lr_key ] = lr
156
-
157
- if momentum is not None :
158
- momentum_key = (
159
- f"momentum/{ self .scheduler_key } "
160
- if self .scheduler_key is not None
161
- else "momentum"
162
- )
163
- metrics_dict [momentum_key ] = momentum
164
-
165
- else :
166
- for i , (lr , momentum ) in enumerate (zip (lr_list , momentum_list )):
167
- lr_key = (
168
- f"lr/{ self .scheduler_key } /group_{ i } "
169
- if self .scheduler_key is not None
170
- else f"lr/group_{ i } "
171
- )
172
- metrics_dict [lr_key ] = lr
173
-
174
- if momentum is not None :
175
- momentum_key = (
176
- f"momentum/{ self .scheduler_key } /group_{ i } "
177
- if self .scheduler_key is not None
178
- else f"momentum/group_{ i } "
179
- )
180
- metrics_dict [momentum_key ] = momentum
164
+ metrics_dict [momentum_key ] = momentum
181
165
182
166
def step_batch (self , runner : "IRunner" ) -> None :
183
167
"""Perform scheduler step and update batch metrics
0 commit comments