Skip to content

Commit 1ab93c9

Browse files
authored
Merge pull request huggingface#550 from amaarora/wandb
Wandb Support
2 parents aa98957 + cb9245e commit 1ab93c9

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

timm/utils/summary.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
import csv
66
import os
77
from collections import OrderedDict
8-
8+
try:
9+
import wandb
10+
except ImportError:
11+
pass
912

1013
def get_outdir(path, *paths, inc=False):
1114
outdir = os.path.join(path, *paths)
@@ -23,10 +26,12 @@ def get_outdir(path, *paths, inc=False):
2326
return outdir
2427

2528

26-
def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False):
29+
def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False):
2730
rowd = OrderedDict(epoch=epoch)
2831
rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
2932
rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
33+
if log_wandb:
34+
wandb.log(rowd)
3035
with open(filename, mode='a') as cf:
3136
dw = csv.DictWriter(cf, fieldnames=rowd.keys())
3237
if write_header: # first iteration (epoch == 1 can't be used)

train.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@
5252
except AttributeError:
5353
pass
5454

55+
try:
56+
import wandb
57+
has_wandb = True
58+
except ImportError:
59+
has_wandb = False
60+
5561
torch.backends.cudnn.benchmark = True
5662
_logger = logging.getLogger('train')
5763

@@ -271,6 +277,8 @@
271277
help='use the multi-epochs-loader to save time at the beginning of every epoch')
272278
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
273279
help='convert model torchscript for inference')
280+
parser.add_argument('--log-wandb', action='store_true', default=False,
281+
help='log training and validation metrics to wandb')
274282

275283

276284
def _parse_args():
@@ -293,7 +301,14 @@ def _parse_args():
293301
def main():
294302
setup_default_logging()
295303
args, args_text = _parse_args()
296-
304+
305+
if args.log_wandb:
306+
if has_wandb:
307+
wandb.init(project=args.experiment, config=args)
308+
else:
309+
_logger.warning("You've requested to log metrics to wandb but package not found. "
310+
"Metrics not being logged to wandb, try `pip install wandb`")
311+
297312
args.prefetcher = not args.no_prefetcher
298313
args.distributed = False
299314
if 'WORLD_SIZE' in os.environ:
@@ -593,7 +608,7 @@ def main():
593608

594609
update_summary(
595610
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
596-
write_header=best_metric is None)
611+
write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
597612

598613
if saver is not None:
599614
# save proper checkpoint with eval metric

0 commit comments

Comments
 (0)