Skip to content

Commit e8ca458

Browse files
committed
More models in sotabench, more control over sotabench run, dataset filename extraction consistency
1 parent 9c40653 commit e8ca458

File tree

7 files changed

+103
-36
lines changed

7 files changed

+103
-36
lines changed

inference.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def main():
7373
(args.model, sum([m.numel() for m in model.parameters()])))
7474

7575
config = resolve_data_config(vars(args), model=model)
76-
model, test_time_pool = apply_test_time_pool(model, config, args)
76+
model, test_time_pool = model, False if args.no_test_pool else apply_test_time_pool(model, config)
7777

7878
if args.num_gpu > 1:
7979
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
@@ -115,9 +115,8 @@ def main():
115115
topk_ids = np.concatenate(topk_ids, axis=0).squeeze()
116116

117117
with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file:
118-
filenames = loader.dataset.filenames()
118+
filenames = loader.dataset.filenames(basename=True)
119119
for filename, label in zip(filenames, topk_ids):
120-
filename = os.path.basename(filename)
121120
out_file.write('{0},{1},{2},{3},{4},{5}\n'.format(
122121
filename, label[0], label[1], label[2], label[3], label[4]))
123122

sotabench.py

+68-15
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import torch
2-
from torchbench.image_classification import ImageNet
2+
from sotabencheval.image_classification import ImageNetEvaluator
3+
from sotabencheval.utils import is_server
34
from timm import create_model
4-
from timm.data import resolve_data_config, create_transform
5-
from timm.models import TestTimePoolHead
5+
from timm.data import resolve_data_config, create_loader, DatasetTar
6+
from timm.models import apply_test_time_pool
7+
from tqdm import tqdm
68
import os
79

810
NUM_GPU = 1
@@ -148,6 +150,10 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
148150
_entry('ese_vovnet19b_dw', 'VoVNet-19-DW-V2', '1911.06667'),
149151
_entry('ese_vovnet39b', 'VoVNet-39-V2', '1911.06667'),
150152

153+
_entry('cspresnet50', 'CSPResNet-50', '1911.11929'),
154+
_entry('cspresnext50', 'CSPResNeXt-50', '1911.11929'),
155+
_entry('cspdarknet53', 'CSPDarkNet-53', '1911.11929'),
156+
151157
_entry('tf_efficientnet_b0', 'EfficientNet-B0 (AutoAugment)', '1905.11946',
152158
model_desc='Ported from official Google AI Tensorflow weights'),
153159
_entry('tf_efficientnet_b1', 'EfficientNet-B1 (AutoAugment)', '1905.11946',
@@ -448,34 +454,81 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
448454
_entry('regnety_160', 'RegNetY-16GF', '2003.13678'),
449455
_entry('regnety_320', 'RegNetY-32GF', '2003.13678', batch_size=BATCH_SIZE // 2),
450456

457+
_entry('rexnet_100', 'ReXNet-1.0x', '2007.00992'),
458+
_entry('rexnet_130', 'ReXNet-1.3x', '2007.00992'),
459+
_entry('rexnet_150', 'ReXNet-1.5x', '2007.00992'),
460+
_entry('rexnet_200', 'ReXNet-2.0x', '2007.00992'),
451461
]
452462

463+
if is_server():
464+
DATA_ROOT = './.data/vision/imagenet'
465+
else:
466+
# local settings
467+
DATA_ROOT = './'
468+
DATA_FILENAME = 'ILSVRC2012_img_val.tar'
469+
TAR_PATH = os.path.join(DATA_ROOT, DATA_FILENAME)
470+
453471
for m in model_list:
454472
model_name = m['model']
455473
# create model from name
456474
model = create_model(model_name, pretrained=True)
457475
param_count = sum([m.numel() for m in model.parameters()])
458476
print('Model %s, %s created. Param count: %d' % (model_name, m['paper_model_name'], param_count))
459477

478+
dataset = DatasetTar(TAR_PATH)
479+
filenames = [os.path.splitext(f)[0] for f in dataset.filenames()]
480+
460481
# get appropriate transform for model's default pretrained config
461482
data_config = resolve_data_config(m['args'], model=model, verbose=True)
483+
test_time_pool = False
462484
if m['ttp']:
463-
model = TestTimePoolHead(model, model.default_cfg['pool_size'])
485+
model, test_time_pool = apply_test_time_pool(model, data_config)
464486
data_config['crop_pct'] = 1.0
465-
input_transform = create_transform(**data_config)
466487

467-
# Run the benchmark
468-
ImageNet.benchmark(
469-
model=model,
470-
model_description=m.get('model_description', None),
471-
paper_model_name=m['paper_model_name'],
488+
batch_size = m['batch_size']
489+
loader = create_loader(
490+
dataset,
491+
input_size=data_config['input_size'],
492+
batch_size=batch_size,
493+
use_prefetcher=True,
494+
interpolation=data_config['interpolation'],
495+
mean=data_config['mean'],
496+
std=data_config['std'],
497+
num_workers=6,
498+
crop_pct=data_config['crop_pct'],
499+
pin_memory=True)
500+
501+
evaluator = ImageNetEvaluator(
502+
root=DATA_ROOT,
503+
model_name=m['paper_model_name'],
472504
paper_arxiv_id=m['paper_arxiv_id'],
473-
input_transform=input_transform,
474-
batch_size=m['batch_size'],
475-
num_gpu=NUM_GPU,
476-
data_root=os.environ.get('IMAGENET_DIR', './.data/vision/imagenet')
505+
model_description=m.get('model_description', None),
477506
)
478-
507+
model.cuda()
508+
model.eval()
509+
with torch.no_grad():
510+
# warmup
511+
input = torch.randn((batch_size,) + data_config['input_size']).cuda()
512+
model(input)
513+
514+
bar = tqdm(desc="Evaluation", mininterval=5, total=50000)
515+
evaluator.reset_time()
516+
sample_count = 0
517+
for input, target in loader:
518+
output = model(input)
519+
num_samples = len(output)
520+
image_ids = [filenames[i] for i in range(sample_count, sample_count + num_samples)]
521+
output = output.cpu().numpy()
522+
evaluator.add(dict(zip(image_ids, list(output))))
523+
sample_count += num_samples
524+
bar.update(num_samples)
525+
bar.close()
526+
527+
evaluator.save()
528+
for k, v in evaluator.results.items():
529+
print(k, v)
530+
for k, v in evaluator.speed_mem_metrics.items():
531+
print(k, v)
479532
torch.cuda.empty_cache()
480533

481534

sotabench_setup.sh

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@ source /workspace/venv/bin/activate
33

44
pip install -r requirements-sotabench.txt
55

6+
apt-get install -y libjpeg-dev zlib1g-dev libpng-dev libwebp-dev
67
pip uninstall -y pillow
78
CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
89

910
# FIXME this shouldn't be needed but sb dataset upload functionality doesn't seem to work
1011
apt-get install wget
11-
wget https://onedrive.hyper.ai/down/ImageNet/data/ImageNet2012/ILSVRC2012_devkit_t12.tar.gz -P ./.data/vision/imagenet
12-
wget https://onedrive.hyper.ai/down/ImageNet/data/ImageNet2012/ILSVRC2012_img_val.tar -P ./.data/vision/imagenet
12+
#wget -q https://onedrive.hyper.ai/down/ImageNet/data/ImageNet2012/ILSVRC2012_devkit_t12.tar.gz -P ./.data/vision/imagenet
13+
wget -q https://onedrive.hyper.ai/down/ImageNet/data/ImageNet2012/ILSVRC2012_img_val.tar -P ./.data/vision/imagenet

timm/data/dataset.py

+25-11
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,21 @@ def __getitem__(self, index):
9494
def __len__(self):
9595
return len(self.samples)
9696

97-
def filenames(self, indices=[], basename=False):
98-
if indices:
99-
if basename:
100-
return [os.path.basename(self.samples[i][0]) for i in indices]
101-
else:
102-
return [self.samples[i][0] for i in indices]
103-
else:
104-
if basename:
105-
return [os.path.basename(x[0]) for x in self.samples]
106-
else:
107-
return [x[0] for x in self.samples]
97+
def filename(self, index, basename=False, absolute=False):
98+
filename = self.samples[index][0]
99+
if basename:
100+
filename = os.path.basename(filename)
101+
elif not absolute:
102+
filename = os.path.relpath(filename, self.root)
103+
return filename
104+
105+
def filenames(self, basename=False, absolute=False):
106+
fn = lambda x: x
107+
if basename:
108+
fn = os.path.basename
109+
elif not absolute:
110+
fn = lambda x: os.path.relpath(x, self.root)
111+
return [fn(x[0]) for x in self.samples]
108112

109113

110114
def _extract_tar_info(tarfile, class_to_idx=None, sort=True):
@@ -160,6 +164,16 @@ def __getitem__(self, index):
160164
def __len__(self):
161165
return len(self.samples)
162166

167+
def filename(self, index, basename=False):
168+
filename = self.samples[index][0].name
169+
if basename:
170+
filename = os.path.basename(filename)
171+
return filename
172+
173+
def filenames(self, basename=False):
174+
fn = os.path.basename if basename else lambda x: x
175+
return [fn(x[0].name) for x in self.samples]
176+
163177

164178
class AugMixDataset(torch.utils.data.Dataset):
165179
"""Dataset wrapper to perform AugMix or other clean/augmentation mixes"""

timm/models/layers/test_time_pool.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,12 @@ def forward(self, x):
3636
return x.view(x.size(0), -1)
3737

3838

39-
def apply_test_time_pool(model, config, args):
39+
def apply_test_time_pool(model, config):
4040
test_time_pool = False
4141
if not hasattr(model, 'default_cfg') or not model.default_cfg:
4242
return model, False
43-
if not args.no_test_pool and \
44-
config['input_size'][-1] > model.default_cfg['input_size'][-1] and \
45-
config['input_size'][-2] > model.default_cfg['input_size'][-2]:
43+
if (config['input_size'][-1] > model.default_cfg['input_size'][-1] and
44+
config['input_size'][-2] > model.default_cfg['input_size'][-2]):
4645
_logger.info('Target input size %s > pretrained default %s, using test time pooling' %
4746
(str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:])))
4847
model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])

timm/models/rexnet.py

+1
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def __init__(self, in_chans=3, num_classes=1000, global_pool='avg', output_strid
166166
se_rd=12, ch_div=1, drop_rate=0.2, feature_location='bottleneck'):
167167
super(ReXNetV1, self).__init__()
168168
self.drop_rate = drop_rate
169+
self.num_classes = num_classes
169170

170171
assert output_stride == 32 # FIXME support dilation
171172
stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32

validate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def validate(args):
139139
_logger.info('Model %s created, param count: %d' % (args.model, param_count))
140140

141141
data_config = resolve_data_config(vars(args), model=model)
142-
model, test_time_pool = apply_test_time_pool(model, data_config, args)
142+
model, test_time_pool = model, False if args.no_test_pool else apply_test_time_pool(model, data_config)
143143

144144
if args.torchscript:
145145
torch.jit.optimized_execution(True)

0 commit comments

Comments (0)