Skip to content

Commit 98a7403

Browse files
authored
Merge pull request #1 from rwightman/master
merge master
2 parents 9b35818 + 13cf688 commit 98a7403

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

89 files changed

+11405
-3087
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,5 @@ venv.bak/
104104
*.tar
105105
*.pth
106106
*.gz
107+
Untitled.ipynb
108+
Testing notebook.ipynb

README.md

Lines changed: 335 additions & 73 deletions
Large diffs are not rendered by default.

avg_checkpoints.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
#!/usr/bin/env python
2+
""" Checkpoint Averaging Script
3+
4+
This script averages all model weights for checkpoints in specified path that match
5+
the specified filter wildcard. All checkpoints must be from the exact same model.
6+
7+
For any hope of decent results, the checkpoints should be from the same or child
8+
(via resumes) training session. This can be viewed as similar to maintaining running
9+
EMA (exponential moving average) of the model weights or performing SWA (stochastic
10+
weight averaging), but post-training.
11+
12+
Hacked together by Ross Wightman (https://github.com/rwightman)
13+
"""
14+
import torch
15+
import argparse
16+
import os
17+
import glob
18+
import hashlib
19+
from timm.models.helpers import load_state_dict
20+
21+
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
22+
parser.add_argument('--input', default='', type=str, metavar='PATH',
23+
help='path to base input folder containing checkpoints')
24+
parser.add_argument('--filter', default='*.pth.tar', type=str, metavar='WILDCARD',
25+
help='checkpoint filter (path wildcard)')
26+
parser.add_argument('--output', default='./averaged.pth', type=str, metavar='PATH',
27+
help='output filename')
28+
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
29+
help='Force not using ema version of weights (if present)')
30+
parser.add_argument('--no-sort', dest='no_sort', action='store_true',
31+
help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant')
32+
parser.add_argument('-n', type=int, default=10, metavar='N',
33+
help='Number of checkpoints to average')
34+
35+
36+
def checkpoint_metric(checkpoint_path):
    """Read the metric value stored in a checkpoint file.

    Args:
        checkpoint_path: Path to a checkpoint file readable by ``torch.load``.

    Returns:
        The value stored under the ``'metric'`` key of the checkpoint, or
        ``None`` when the path is empty, is not a regular file, or the
        checkpoint does not contain a ``'metric'`` entry.
    """
    if not checkpoint_path or not os.path.isfile(checkpoint_path):
        # Return None (not an empty dict) so that callers filtering with
        # `metric is not None` skip this entry instead of trying to sort it.
        return None
    print("=> Extracting metric from checkpoint '{}'".format(checkpoint_path))
    # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    metric = None
    if 'metric' in checkpoint:
        metric = checkpoint['metric']
    return metric
45+
46+
47+
def main():
    """Average the weights of the selected checkpoints and save the result.

    Steps:
      1. Parse the command line and refuse to overwrite an existing output file.
      2. Glob for checkpoints under ``--input`` matching ``--filter``.
      3. Unless ``--no-sort`` is given, keep only the ``-n`` checkpoints with
         the largest stored metric (checkpoints without a metric are dropped).
      4. Accumulate every tensor in float64, divide by the number of
         checkpoints that contained that key, clamp to the float32 range and
         save the float32 result together with a printed SHA256 of the file.

    Raises:
        SystemExit: With status 1 if the output file already exists or if no
            checkpoint weights could be loaded.
    """
    args = parser.parse_args()
    # by default use the EMA weights (if present)
    args.use_ema = not args.no_use_ema
    # by default sort by checkpoint metric (if present) and avg top n checkpoints
    args.sort = not args.no_sort

    if os.path.exists(args.output):
        print("Error: Output filename ({}) already exists.".format(args.output))
        raise SystemExit(1)

    # Join the input folder and the wildcard with exactly one path separator.
    pattern = args.input
    if not args.input.endswith(os.path.sep) and not args.filter.startswith(os.path.sep):
        pattern += os.path.sep
    pattern += args.filter
    checkpoints = glob.glob(pattern, recursive=True)

    if args.sort:
        checkpoint_metrics = []
        for c in checkpoints:
            metric = checkpoint_metric(c)
            if metric is not None:
                checkpoint_metrics.append((metric, c))
        # Ascending sort, then keep the tail: the n largest metric values.
        checkpoint_metrics.sort()
        checkpoint_metrics = checkpoint_metrics[-args.n:]
        print("Selected checkpoints:")
        for m, c in checkpoint_metrics:
            print(m, c)
        avg_checkpoints = [c for m, c in checkpoint_metrics]
    else:
        avg_checkpoints = checkpoints
        print("Selected checkpoints:")
        for c in checkpoints:
            print(c)

    avg_state_dict = {}
    avg_counts = {}
    for c in avg_checkpoints:
        # NOTE(review): load_state_dict comes from timm.models.helpers; it
        # presumably returns the (EMA, if requested and present) state dict and
        # a falsy value on failure -- confirm against that helper.
        new_state_dict = load_state_dict(c, args.use_ema)
        if not new_state_dict:
            # Report the checkpoint that failed to load (the parser defines no
            # `checkpoint` argument, so `args.checkpoint` would raise here).
            print("Error: Checkpoint ({}) doesn't exist".format(c))
            continue

        for k, v in new_state_dict.items():
            if k not in avg_state_dict:
                # Accumulate in float64 to limit rounding error in the sum.
                avg_state_dict[k] = v.clone().to(dtype=torch.float64)
                avg_counts[k] = 1
            else:
                avg_state_dict[k] += v.to(dtype=torch.float64)
                avg_counts[k] += 1

    if not avg_state_dict:
        # Nothing was selected or loaded; do not write an empty state dict.
        print("Error: No checkpoint weights were loaded, nothing to average.")
        raise SystemExit(1)

    # Per-key count so keys missing from some checkpoints are averaged correctly.
    for k, v in avg_state_dict.items():
        v.div_(avg_counts[k])

    # float32 overflow seems unlikely based on weights seen to date, but who knows
    float32_info = torch.finfo(torch.float32)
    final_state_dict = {}
    for k, v in avg_state_dict.items():
        v = v.clamp(float32_info.min, float32_info.max)
        final_state_dict[k] = v.to(dtype=torch.float32)

    torch.save(final_state_dict, args.output)
    with open(args.output, 'rb') as f:
        sha_hash = hashlib.sha256(f.read()).hexdigest()
    print("=> Saved state_dict to '{}', SHA256: {}".format(args.output, sha_hash))
112+
if __name__ == '__main__':
113+
main()

clean_checkpoint.py

100644100755
Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,30 @@
1+
#!/usr/bin/env python
2+
""" Checkpoint Cleaning Script
3+
4+
Takes training checkpoints with GPU tensors, optimizer state, extra dict keys, etc.
5+
and outputs a CPU tensor checkpoint with only the `state_dict` along with SHA256
6+
calculation for model zoo compatibility.
7+
8+
Hacked together by Ross Wightman (https://github.com/rwightman)
9+
"""
110
import torch
211
import argparse
312
import os
413
import hashlib
14+
import shutil
515
from collections import OrderedDict
616

7-
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
17+
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
818
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
919
help='path to latest checkpoint (default: none)')
10-
parser.add_argument('--output', default='./cleaned.pth', type=str, metavar='PATH',
20+
parser.add_argument('--output', default='', type=str, metavar='PATH',
1121
help='output path')
1222
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
1323
help='use ema version of weights if present')
24+
parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
25+
help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
26+
27+
_TEMP_NAME = './_checkpoint.pth'
1428

1529

1630
def main():
@@ -31,19 +45,31 @@ def main():
3145
if state_dict_key in checkpoint:
3246
state_dict = checkpoint[state_dict_key]
3347
else:
34-
print("Error: No state_dict found in checkpoint {}.".format(args.checkpoint))
35-
exit(1)
48+
state_dict = checkpoint
3649
else:
37-
state_dict = checkpoint
50+
assert False
3851
for k, v in state_dict.items():
52+
if args.clean_aux_bn and 'aux_bn' in k:
53+
# If all aux_bn keys are removed, the SplitBN layers will end up as normal and
54+
# load with the unmodified model using BatchNorm2d.
55+
continue
3956
name = k[7:] if k.startswith('module') else k
4057
new_state_dict[name] = v
4158
print("=> Loaded state_dict from '{}'".format(args.checkpoint))
4259

43-
torch.save(new_state_dict, args.output)
44-
with open(args.output, 'rb') as f:
60+
torch.save(new_state_dict, _TEMP_NAME)
61+
with open(_TEMP_NAME, 'rb') as f:
4562
sha_hash = hashlib.sha256(f.read()).hexdigest()
46-
print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash))
63+
64+
if args.output:
65+
checkpoint_root, checkpoint_base = os.path.split(args.output)
66+
checkpoint_base = os.path.splitext(checkpoint_base)[0]
67+
else:
68+
checkpoint_root = ''
69+
checkpoint_base = os.path.splitext(args.checkpoint)[0]
70+
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth'
71+
shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
72+
print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
4773
else:
4874
print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint))
4975

convert/convert_from_mxnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import mxnet as mx
66
import gluoncv
77
import torch
8-
from models.model_factory import create_model
8+
from timm import create_model
99

1010
parser = argparse.ArgumentParser(description='Convert from MXNet')
1111
parser.add_argument('--model', default='all', type=str, metavar='MODEL',

hubconf.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
dependencies = ['torch']
2+
from timm.models import registry
3+
4+
globals().update(registry._model_entrypoints)

inference.py

100644100755
Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
"""Sample PyTorch Inference script
2-
"""
1+
#!/usr/bin/env python
2+
"""PyTorch Inference Script
33
4-
from __future__ import absolute_import
5-
from __future__ import division
6-
from __future__ import print_function
4+
An example inference script that outputs top-k class ids for images in a folder into a csv.
75
6+
Hacked together by Ross Wightman (https://github.com/rwightman)
7+
"""
88
import os
99
import time
1010
import argparse
@@ -29,7 +29,7 @@
2929
help='number of data loading workers (default: 2)')
3030
parser.add_argument('-b', '--batch-size', default=256, type=int,
3131
metavar='N', help='mini-batch size (default: 256)')
32-
parser.add_argument('--img-size', default=224, type=int,
32+
parser.add_argument('--img-size', default=None, type=int,
3333
metavar='N', help='Input image dimension')
3434
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
3535
help='Override mean pixel value of dataset')

requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
torch~=1.0
2-
torchvision
1+
torch>=1.2.0
2+
torchvision>=0.4.0
3+
pyyaml

results/README.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Validation Results
2+
3+
This folder contains validation results for the models in this collection having pretrained weights. Since the focus for this repository is currently ImageNet-1k classification, all of the results are based on datasets compatible with ImageNet-1k classes.
4+
5+
## Datasets
6+
7+
There are currently results for the ImageNet validation set and 3 additional test sets.
8+
9+
### ImageNet Validation - [`results-imagenet.csv`](results-imagenet.csv)
10+
11+
* Source: http://image-net.org/challenges/LSVRC/2012/index
12+
* Paper: "ImageNet Large Scale Visual Recognition Challenge" - https://arxiv.org/abs/1409.0575
13+
14+
The standard 50,000 image ImageNet-1k validation set. Model selection during training utilizes this validation set, so it is not a true test set. Question: Does anyone have the official ImageNet-1k test set classification labels now that challenges are done?
15+
16+
### ImageNetV2 Matched Frequency - [`results-imagenetv2-matched-frequency.csv`](results-imagenetv2-matched-frequency.csv)
17+
18+
* Source: https://github.com/modestyachts/ImageNetV2
19+
* Paper: "Do ImageNet Classifiers Generalize to ImageNet?" - https://arxiv.org/abs/1902.10811
20+
21+
An ImageNet test set of 10,000 images sampled from new images roughly 10 years after the original. Care was taken to replicate the original ImageNet curation/sampling process.
22+
23+
### ImageNet-Sketch - [`results-sketch.csv`](results-sketch.csv)
24+
25+
* Source: https://github.com/HaohanWang/ImageNet-Sketch
26+
* Paper: "Learning Robust Global Representations by Penalizing Local Predictive Power" - https://arxiv.org/abs/1905.13549
27+
28+
50,000 non-photographic (or photos of such) images (sketches, doodles, mostly monochromatic) covering all 1000 ImageNet classes.
29+
30+
### ImageNet-Adversarial - [`results-imagenet-a.csv`](results-imagenet-a.csv)
31+
32+
* Source: https://github.com/hendrycks/natural-adv-examples
33+
* Paper: "Natural Adversarial Examples" - https://arxiv.org/abs/1907.07174
34+
35+
A collection of 7500 images covering 200 of the 1000 ImageNet classes. Images are naturally occurring adversarial examples that confuse typical ImageNet classifiers. This is a challenging dataset; your typical ResNet-50 will score 0% top-1.
36+
37+
## TODO
38+
* Add rank difference, and top-1/top-5 difference from ImageNet-1k validation for the 3 additional test sets
39+
* Explore adding a reduced version of ImageNet-C (Corruptions) and ImageNet-P (Perturbations) from https://github.com/hendrycks/robustness. The originals are huge and image size specific.

results/results-all.csv

Lines changed: 0 additions & 87 deletions
This file was deleted.

0 commit comments

Comments
 (0)