Skip to content

Commit f54897c

Browse files
committed
make wandb not required but rather optional as huggingface_hub
1 parent f13f750 commit f54897c

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
torch>=1.4.0
22
torchvision>=0.5.0
3-
pyyaml
4-
wandb
3+
pyyaml

timm/utils/summary.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
"""
55
import csv
66
import os
7-
import wandb
87
from collections import OrderedDict
9-
8+
try:
9+
import wandb
10+
except ImportError:
11+
pass
1012

1113
def get_outdir(path, *paths, inc=False):
1214
outdir = os.path.join(path, *paths)
@@ -28,8 +30,6 @@ def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=Fa
2830
rowd = OrderedDict(epoch=epoch)
2931
rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
3032
rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
31-
if log_wandb:
32-
wandb.log(rowd)
3333
with open(filename, mode='a') as cf:
3434
dw = csv.DictWriter(cf, fieldnames=rowd.keys())
3535
if write_header: # first iteration (epoch == 1 can't be used)

train.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
from contextlib import suppress
2424
from datetime import datetime
2525

26-
import wandb
27-
2826
import torch
2927
import torch.nn as nn
3028
import torchvision.utils
@@ -54,6 +52,12 @@
5452
except AttributeError:
5553
pass
5654

55+
try:
56+
import wandb
57+
has_wandb = True
58+
except ModuleNotFoundError:
59+
has_wandb = False
60+
5761
torch.backends.cudnn.benchmark = True
5862
_logger = logging.getLogger('train')
5963

@@ -274,7 +278,7 @@
274278
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
275279
help='convert model torchscript for inference')
276280
parser.add_argument('--log-wandb', action='store_true', default=False,
277-
help='use wandb for training and validation logs')
281+
help='log training and validation metrics to wandb')
278282

279283

280284
def _parse_args():
@@ -299,8 +303,12 @@ def main():
299303
args, args_text = _parse_args()
300304

301305
if args.log_wandb:
302-
wandb.init(project=args.experiment, config=args)
303-
306+
if has_wandb:
307+
wandb.init(project=args.experiment, config=args)
308+
else:
309+
_logger.warning("You've requested to log metrics to wandb but package not found. "
310+
"Metrics not being logged to wandb, try `pip install wandb`")
311+
304312
args.prefetcher = not args.no_prefetcher
305313
args.distributed = False
306314
if 'WORLD_SIZE' in os.environ:
@@ -600,7 +608,7 @@ def main():
600608

601609
update_summary(
602610
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
603-
write_header=best_metric is None, log_wandb=args.log_wandb)
611+
write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
604612

605613
if saver is not None:
606614
# save proper checkpoint with eval metric

0 commit comments

Comments
 (0)