Skip to content

Commit 9c467e3

Browse files
authored
SchedulerCallback: save lr and momentum only for 0-th param group (#1028)
* SchedulerCallback: save lr and momentum only for 0-th param group * Codestyle
1 parent b061b78 commit 9c467e3

File tree

1 file changed

+16
-32
lines changed

1 file changed

+16
-32
lines changed

catalyst/callbacks/scheduler.py

Lines changed: 16 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -136,48 +136,32 @@ def _update_lr_and_momentum_in_metrics_dict(
136136
momentum_list: List[Union[float, None]],
137137
):
138138
"""Update learning rate and momentum in metrics_dict
139+
(consider only 0-th param group)
139140
140141
Args:
141142
metrics_dict: batch_metrics or epoch_metrics
142143
lr_list: lr for each param group
143144
momentum_list: momentum for each param group
144145
145146
"""
146-
if len(lr_list) == 1:
147-
lr = lr_list[0]
148-
momentum = momentum_list[0]
147+
# todo: consider saving lr and momentum for all param groups ?
148+
lr = lr_list[0]
149+
momentum = momentum_list[0]
150+
151+
lr_key = (
152+
f"lr/{self.scheduler_key}"
153+
if self.scheduler_key is not None
154+
else "lr"
155+
)
156+
metrics_dict[lr_key] = lr
149157

150-
lr_key = (
151-
f"lr/{self.scheduler_key}"
158+
if momentum is not None:
159+
momentum_key = (
160+
f"momentum/{self.scheduler_key}"
152161
if self.scheduler_key is not None
153-
else "lr"
162+
else "momentum"
154163
)
155-
metrics_dict[lr_key] = lr
156-
157-
if momentum is not None:
158-
momentum_key = (
159-
f"momentum/{self.scheduler_key}"
160-
if self.scheduler_key is not None
161-
else "momentum"
162-
)
163-
metrics_dict[momentum_key] = momentum
164-
165-
else:
166-
for i, (lr, momentum) in enumerate(zip(lr_list, momentum_list)):
167-
lr_key = (
168-
f"lr/{self.scheduler_key}/group_{i}"
169-
if self.scheduler_key is not None
170-
else f"lr/group_{i}"
171-
)
172-
metrics_dict[lr_key] = lr
173-
174-
if momentum is not None:
175-
momentum_key = (
176-
f"momentum/{self.scheduler_key}/group_{i}"
177-
if self.scheduler_key is not None
178-
else f"momentum/group_{i}"
179-
)
180-
metrics_dict[momentum_key] = momentum
164+
metrics_dict[momentum_key] = momentum
181165

182166
def step_batch(self, runner: "IRunner") -> None:
183167
"""Perform scheduler step and update batch metrics

0 commit comments

Comments
 (0)