                    help='use the multi-epochs-loader to save time at the beginning of every epoch')
 group.add_argument('--log-wandb', action='store_true', default=False,
                    help='log training and validation metrics to wandb')
+group.add_argument('--wandb-project', default=None, type=str,
+                   help='wandb project name')
 group.add_argument('--wandb-tags', default=[], type=str, nargs='+',
                    help='wandb tags')
 group.add_argument('--wandb-resume-id', default='', type=str, metavar='ID',
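For reference, a hypothetical invocation using the new `--wandb-project` flag could look like `python train.py --data-dir /data/imagenet --model resnet50 --log-wandb --wandb-project my-timm-runs` (the dataset path, model, and project name here are illustrative). When `--wandb-project` is left at its `None` default, `wandb.init` falls back to wandb's own project resolution (e.g. the `WANDB_PROJECT` environment variable, if set).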
@@ -832,20 +834,21 @@ def main():
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
             f.write(args_text)
 
-    if utils.is_primary(args) and args.log_wandb:
-        if has_wandb:
-            assert not args.wandb_resume_id or args.resume
-            wandb.init(
-                project=args.experiment,
-                config=args,
-                tags=args.wandb_tags,
-                resume='must' if args.wandb_resume_id else None,
-                id=args.wandb_resume_id if args.wandb_resume_id else None,
-            )
-        else:
-            _logger.warning(
-                "You've requested to log metrics to wandb but package not found. "
-                "Metrics not being logged to wandb, try `pip install wandb`")
+        if args.log_wandb:
+            if has_wandb:
+                assert not args.wandb_resume_id or args.resume
+                wandb.init(
+                    project=args.wandb_project,
+                    name=exp_name,
+                    config=args,
+                    tags=args.wandb_tags,
+                    resume="must" if args.wandb_resume_id else None,
+                    id=args.wandb_resume_id if args.wandb_resume_id else None,
+                )
+            else:
+                _logger.warning(
+                    "You've requested to log metrics to wandb but package not found. "
+                    "Metrics not being logged to wandb, try `pip install wandb`")
 
     # setup learning rate schedule and starting epoch
     updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
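For context, here is a minimal standalone sketch of what the rewritten block does at runtime; the project name, run name, tags, and resume id below are hypothetical stand-ins for the parsed CLI flags:

    import wandb

    # Stand-ins for args.wandb_project, exp_name, args.wandb_tags,
    # and args.wandb_resume_id from the diff above.
    wandb_project = 'my-timm-runs'
    exp_name = '20240101-000000-resnet50-224'
    wandb_tags = ['baseline']
    wandb_resume_id = ''  # empty string means "start a fresh run"

    # resume='must' makes wandb.init fail unless a run with the given id
    # already exists, which is why the diff asserts that --wandb-resume-id
    # is only passed together with --resume (i.e. when a training
    # checkpoint is also being restored).
    run = wandb.init(
        project=wandb_project,
        name=exp_name,
        tags=wandb_tags,
        resume='must' if wandb_resume_id else None,
        id=wandb_resume_id if wandb_resume_id else None,
    )
    run.finish()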