@@ -52,6 +52,12 @@
 except AttributeError:
     pass
 
+try:
+    import wandb
+    has_wandb = True
+except ImportError:
+    has_wandb = False
+
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('train')
 
@@ -271,6 +277,8 @@
                     help='use the multi-epochs-loader to save time at the beginning of every epoch')
 parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')
+parser.add_argument('--log-wandb', action='store_true', default=False,
+                    help='log training and validation metrics to wandb')
 
 
 def _parse_args():
@@ -293,7 +301,14 @@ def _parse_args():
 def main():
     setup_default_logging()
     args, args_text = _parse_args()
-
+
+    if args.log_wandb:
+        if has_wandb:
+            wandb.init(project=args.experiment, config=args)
+        else:
+            _logger.warning("You've requested to log metrics to wandb but package not found. "
+                            "Metrics not being logged to wandb, try `pip install wandb`")
+
     args.prefetcher = not args.no_prefetcher
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
@@ -593,7 +608,7 @@ def main():
 
         update_summary(
             epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
-            write_header=best_metric is None)
+            write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
 
         if saver is not None:
            # save proper checkpoint with eval metric
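The diff only shows the call-site change: `update_summary` now receives `log_wandb=args.log_wandb and has_wandb`, but the helper itself (in timm's summary utilities) is not part of this patch. Below is a minimal sketch of what the receiving side might look like, assuming the existing CSV-row behavior and a `log_wandb` keyword that simply mirrors the per-epoch row to `wandb.log`; the structure and field names here are assumptions, not the actual implementation.

```python
# Hypothetical sketch of update_summary() with the new log_wandb flag.
# Assumes the helper already flattens train/eval metrics into one CSV row.
from collections import OrderedDict
import csv

try:
    import wandb
    has_wandb = True
except ImportError:
    has_wandb = False


def update_summary(epoch, train_metrics, eval_metrics, filename,
                   write_header=False, log_wandb=False):
    # Flatten train/eval metrics into a single row keyed by prefix.
    rowd = OrderedDict(epoch=epoch)
    rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
    rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
    if log_wandb:
        # Caller (main()) only passes log_wandb=True when wandb was imported
        # and wandb.init() has already been called.
        wandb.log(rowd)
    with open(filename, mode='a') as cf:
        dw = csv.DictWriter(cf, fieldnames=rowd.keys())
        if write_header:
            dw.writeheader()
        dw.writerow(rowd)
```

With the patch applied, enabling the new path would look roughly like `pip install wandb` followed by running train.py with `--log-wandb`; per the `wandb.init(project=args.experiment, config=args)` call added in `main()`, the `--experiment` name doubles as the wandb project and the full argparse config is uploaded to the run.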