Wandb Support #550

Merged · 10 commits · Apr 15, 2021
9 changes: 7 additions & 2 deletions timm/utils/summary.py
@@ -5,7 +5,10 @@
 import csv
 import os
 from collections import OrderedDict
-
+try:
+    import wandb
+except ImportError:
+    pass
 
 def get_outdir(path, *paths, inc=False):
     outdir = os.path.join(path, *paths)
@@ -23,10 +26,12 @@ def get_outdir(path, *paths, inc=False):
     return outdir
 
 
-def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False):
+def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False):
     rowd = OrderedDict(epoch=epoch)
     rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
     rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
+    if log_wandb:
+        wandb.log(rowd)
     with open(filename, mode='a') as cf:
         dw = csv.DictWriter(cf, fieldnames=rowd.keys())
         if write_header:  # first iteration (epoch == 1 can't be used)
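With this change, update_summary mirrors each epoch's combined train/eval row to wandb before appending it to the CSV. A minimal usage sketch, not from this PR: it assumes wandb is installed and a run has been started (train.py does this under --log-wandb, per the diff below); the project name, metric values, and file path are illustrative.

    import wandb
    from timm.utils.summary import update_summary

    wandb.init(project='timm-experiment')  # normally done by train.py when --log-wandb is set
    update_summary(
        epoch=1,
        train_metrics={'loss': 2.31},
        eval_metrics={'loss': 2.05, 'top1': 48.2},
        filename='./summary.csv',
        write_header=True,   # only when starting a fresh summary.csv
        log_wandb=True)      # sends the same row to the active run via wandb.log

Note that update_summary itself never checks whether the import succeeded; callers are expected to pass log_wandb=True only when wandb is actually available, as train.py does below.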
19 changes: 17 additions & 2 deletions train.py
@@ -52,6 +52,12 @@
 except AttributeError:
     pass
 
+try:
+    import wandb
+    has_wandb = True
+except ImportError:
+    has_wandb = False
+
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('train')
 
@@ -271,6 +277,8 @@
                     help='use the multi-epochs-loader to save time at the beginning of every epoch')
 parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')
+parser.add_argument('--log-wandb', action='store_true', default=False,
+                    help='log training and validation metrics to wandb')
 
 
 def _parse_args():
@@ -293,7 +301,14 @@ def _parse_args():
 def main():
     setup_default_logging()
     args, args_text = _parse_args()
 
+    if args.log_wandb:
+        if has_wandb:
+            wandb.init(project=args.experiment, config=args)
+        else:
+            _logger.warning("You've requested to log metrics to wandb but package not found. "
+                            "Metrics not being logged to wandb, try `pip install wandb`")
+
     args.prefetcher = not args.no_prefetcher
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
@@ -593,7 +608,7 @@ def main():
 
             update_summary(
                 epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
-                write_header=best_metric is None)
+                write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
 
             if saver is not None:
                 # save proper checkpoint with eval metric
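Taken together, wandb remains an optional dependency: both files wrap the import in try/except, and the update_summary call only enables logging when the package actually imported (`args.log_wandb and has_wandb`). To opt in (hypothetical invocation; the data path and model are placeholders): install and authenticate with `pip install wandb` and `wandb login`, then run e.g. `python train.py /data/imagenet --model resnet50 --log-wandb --experiment my-run`, where `--experiment` also names the wandb project via `wandb.init(project=args.experiment, config=args)`.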