23
23
from contextlib import suppress
24
24
from datetime import datetime
25
25
26
- import wandb
27
-
28
26
import torch
29
27
import torch .nn as nn
30
28
import torchvision .utils
54
52
except AttributeError :
55
53
pass
56
54
55
+ try :
56
+ import wandb
57
+ has_wandb = True
58
+ except ModuleNotFoundError :
59
+ has_wandb = False
60
+
57
61
torch .backends .cudnn .benchmark = True
58
62
_logger = logging .getLogger ('train' )
59
63
274
278
parser .add_argument ('--torchscript' , dest = 'torchscript' , action = 'store_true' ,
275
279
help = 'convert model torchscript for inference' )
276
280
parser .add_argument ('--log-wandb' , action = 'store_true' , default = False ,
277
- help = 'use wandb for training and validation logs ' )
281
+ help = 'log training and validation metrics to wandb ' )
278
282
279
283
280
284
def _parse_args ():
@@ -299,8 +303,12 @@ def main():
299
303
args , args_text = _parse_args ()
300
304
301
305
if args .log_wandb :
302
- wandb .init (project = args .experiment , config = args )
303
-
306
+ if has_wandb :
307
+ wandb .init (project = args .experiment , config = args )
308
+ else :
309
+ _logger .warning ("You've requested to log metrics to wandb but package not found. "
310
+ "Metrics not being logged to wandb, try `pip install wandb`" )
311
+
304
312
args .prefetcher = not args .no_prefetcher
305
313
args .distributed = False
306
314
if 'WORLD_SIZE' in os .environ :
@@ -600,7 +608,7 @@ def main():
600
608
601
609
update_summary (
602
610
epoch , train_metrics , eval_metrics , os .path .join (output_dir , 'summary.csv' ),
603
- write_header = best_metric is None , log_wandb = args .log_wandb )
611
+ write_header = best_metric is None , log_wandb = args .log_wandb and has_wandb )
604
612
605
613
if saver is not None :
606
614
# save proper checkpoint with eval metric
0 commit comments