193
193
help = 'disable fast prefetcher' )
194
194
parser .add_argument ('--output' , default = '' , type = str , metavar = 'PATH' ,
195
195
help = 'path to output folder (default: none, current dir)' )
196
- parser .add_argument ('--eval-metric' , default = 'prec1 ' , type = str , metavar = 'EVAL_METRIC' ,
197
- help = 'Best metric (default: "prec1 "' )
196
+ parser .add_argument ('--eval-metric' , default = 'top1 ' , type = str , metavar = 'EVAL_METRIC' ,
197
+ help = 'Best metric (default: "top1 "' )
198
198
parser .add_argument ('--tta' , type = int , default = 0 , metavar = 'N' ,
199
199
help = 'Test/inference time augmentation (oversampling) factor. 0=None (default: 0)' )
200
200
parser .add_argument ("--local_rank" , default = 0 , type = int )
@@ -596,8 +596,8 @@ def train_epoch(
596
596
def validate (model , loader , loss_fn , args , log_suffix = '' ):
597
597
batch_time_m = AverageMeter ()
598
598
losses_m = AverageMeter ()
599
- prec1_m = AverageMeter ()
600
- prec5_m = AverageMeter ()
599
+ top1_m = AverageMeter ()
600
+ top5_m = AverageMeter ()
601
601
602
602
model .eval ()
603
603
@@ -621,20 +621,20 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
621
621
target = target [0 :target .size (0 ):reduce_factor ]
622
622
623
623
loss = loss_fn (output , target )
624
- prec1 , prec5 = accuracy (output , target , topk = (1 , 5 ))
624
+ acc1 , acc5 = accuracy (output , target , topk = (1 , 5 ))
625
625
626
626
if args .distributed :
627
627
reduced_loss = reduce_tensor (loss .data , args .world_size )
628
- prec1 = reduce_tensor (prec1 , args .world_size )
629
- prec5 = reduce_tensor (prec5 , args .world_size )
628
+ acc1 = reduce_tensor (acc1 , args .world_size )
629
+ acc5 = reduce_tensor (acc5 , args .world_size )
630
630
else :
631
631
reduced_loss = loss .data
632
632
633
633
torch .cuda .synchronize ()
634
634
635
635
losses_m .update (reduced_loss .item (), input .size (0 ))
636
- prec1_m .update (prec1 .item (), output .size (0 ))
637
- prec5_m .update (prec5 .item (), output .size (0 ))
636
+ top1_m .update (acc1 .item (), output .size (0 ))
637
+ top5_m .update (acc5 .item (), output .size (0 ))
638
638
639
639
batch_time_m .update (time .time () - end )
640
640
end = time .time ()
@@ -644,13 +644,12 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
644
644
'{0}: [{1:>4d}/{2}] '
645
645
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
646
646
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
647
- 'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
648
- 'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})' .format (
649
- log_name , batch_idx , last_idx ,
650
- batch_time = batch_time_m , loss = losses_m ,
651
- top1 = prec1_m , top5 = prec5_m ))
647
+ 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
648
+ 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})' .format (
649
+ log_name , batch_idx , last_idx , batch_time = batch_time_m ,
650
+ loss = losses_m , top1 = top1_m , top5 = top5_m ))
652
651
653
- metrics = OrderedDict ([('loss' , losses_m .avg ), ('prec1 ' , prec1_m .avg ), ('prec5 ' , prec5_m .avg )])
652
+ metrics = OrderedDict ([('loss' , losses_m .avg ), ('top1 ' , top1_m .avg ), ('top5 ' , top5_m .avg )])
654
653
655
654
return metrics
656
655
0 commit comments