Commit 27a93e9

Improve test crop for ViT models. Small now 77.85, added base weights at 79.35 top-1.
1 parent d4db9e7 commit 27a93e9

File tree: 2 files changed (+10 −2 lines)

Diff for: README.md (+3)

```diff
@@ -2,6 +2,9 @@
 
 ## What's New
 
+### Oct 21, 2020
+* Weights added for Vision Transformer (ViT) models. 77.86 top-1 for 'small' and 79.35 for 'base'. Thanks to [Christof](https://www.kaggle.com/christofhenkel) for training the base model w/ lots of GPUs.
+
 ### Oct 13, 2020
 * Initial impl of Vision Transformer models. Both patch and hybrid (CNN backbone) variants. Currently trying to train...
 * Adafactor and AdaHessian (FP32 only, no AMP) optimizers
```

Diff for: timm/models/vision_transformer.py (+7 −2)

```diff
@@ -39,7 +39,7 @@ def _cfg(url='', **kwargs):
     return {
         'url': url,
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-        'crop_pct': 1.0, 'interpolation': 'bicubic',
+        'crop_pct': .9, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': '', 'classifier': 'head',
         **kwargs
@@ -51,7 +51,9 @@ def _cfg(url='', **kwargs):
     'vit_small_patch16_224': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
     ),
-    'vit_base_patch16_224': _cfg(),
+    'vit_base_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_base_p16_224-4e355ebd.pth'
+    ),
     'vit_base_patch16_384': _cfg(input_size=(3, 384, 384)),
     'vit_base_patch32_384': _cfg(input_size=(3, 384, 384)),
     'vit_large_patch16_224': _cfg(),
@@ -283,6 +285,9 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
 def vit_base_patch16_224(pretrained=False, **kwargs):
     model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
     model.default_cfg = default_cfgs['vit_base_patch16_224']
+    if pretrained:
+        load_pretrained(
+            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
     return model
```
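The "improve test crop" part of this commit is the `crop_pct` change from 1.0 to .9: at eval time the image is now resized to a slightly larger edge and then center-cropped, instead of being resized straight to 224. A minimal sketch of the arithmetic, following the common `floor(img_size / crop_pct)` convention (the helper name here is illustrative, not timm's actual API):

```python
import math

def eval_sizes(img_size: int, crop_pct: float):
    """Return (resize target, crop size) for center-crop evaluation,
    using scale = floor(img_size / crop_pct)."""
    scale_size = int(math.floor(img_size / crop_pct))
    return scale_size, img_size

# Old config: crop_pct=1.0 -> resize directly to 224, no border to crop.
print(eval_sizes(224, 1.0))  # (224, 224)
# New config: crop_pct=.9 -> resize to 248, then center-crop 224.
print(eval_sizes(224, 0.9))  # (248, 224)
```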
