Skip to content

Commit c4fb98f

Browse files
authored
Merge pull request #2398 from huggingface/caojiaolong-main
Merging wandb project name chages w/ addition
2 parents 2d0ac6f + c173886 commit c4fb98f

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

train.py

+17-14
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,8 @@
390390
help='use the multi-epochs-loader to save time at the beginning of every epoch')
391391
group.add_argument('--log-wandb', action='store_true', default=False,
392392
help='log training and validation metrics to wandb')
393+
group.add_argument('--wandb-project', default=None, type=str,
394+
help='wandb project name')
393395
group.add_argument('--wandb-tags', default=[], type=str, nargs='+',
394396
help='wandb tags')
395397
group.add_argument('--wandb-resume-id', default='', type=str, metavar='ID',
@@ -832,20 +834,21 @@ def main():
832834
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
833835
f.write(args_text)
834836

835-
if utils.is_primary(args) and args.log_wandb:
836-
if has_wandb:
837-
assert not args.wandb_resume_id or args.resume
838-
wandb.init(
839-
project=args.experiment,
840-
config=args,
841-
tags=args.wandb_tags,
842-
resume='must' if args.wandb_resume_id else None,
843-
id=args.wandb_resume_id if args.wandb_resume_id else None,
844-
)
845-
else:
846-
_logger.warning(
847-
"You've requested to log metrics to wandb but package not found. "
848-
"Metrics not being logged to wandb, try `pip install wandb`")
837+
if args.log_wandb:
838+
if has_wandb:
839+
assert not args.wandb_resume_id or args.resume
840+
wandb.init(
841+
project=args.wandb_project,
842+
name=exp_name,
843+
config=args,
844+
tags=args.wandb_tags,
845+
resume="must" if args.wandb_resume_id else None,
846+
id=args.wandb_resume_id if args.wandb_resume_id else None,
847+
)
848+
else:
849+
_logger.warning(
850+
"You've requested to log metrics to wandb but package not found. "
851+
"Metrics not being logged to wandb, try `pip install wandb`")
849852

850853
# setup learning rate schedule and starting epoch
851854
updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps

0 commit comments

Comments
 (0)