Commit d4db9e7

Add small vision transformer weights. 77.42 top-1.
1 parent ccfb575 commit d4db9e7

1 file changed: +7 -2 lines

Diff for: timm/models/vision_transformer.py

@@ -29,7 +29,7 @@
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg
+from .helpers import load_pretrained
 from .layers import DropPath, to_2tuple, trunc_normal_
 from .resnet import resnet26d, resnet50d
 from .registry import register_model
@@ -48,7 +48,9 @@ def _cfg(url='', **kwargs):

 default_cfgs = {
     # patch models
-    'vit_small_patch16_224': _cfg(),
+    'vit_small_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
+    ),
     'vit_base_patch16_224': _cfg(),
     'vit_base_patch16_384': _cfg(input_size=(3, 384, 384)),
     'vit_base_patch32_384': _cfg(input_size=(3, 384, 384)),
@@ -271,6 +273,9 @@ def forward(self, x, attn_mask=None):
 def vit_small_patch16_224(pretrained=False, **kwargs):
     model = VisionTransformer(patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., **kwargs)
     model.default_cfg = default_cfgs['vit_small_patch16_224']
+    if pretrained:
+        load_pretrained(
+            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
     return model

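With the checkpoint URL registered in default_cfgs and the entrypoint routed through load_pretrained, the new weights can be exercised end to end. A minimal sketch, assuming timm is installed from a checkout at or after this commit and that the entrypoint is exposed to the model factory via the @register_model decorator the file imports:

import timm
import torch

# Build the small ViT and download the newly released checkpoint;
# pretrained=True resolves the URL from the model's default_cfg.
model = timm.create_model('vit_small_patch16_224', pretrained=True)
model.eval()

# Sanity-check with a dummy 224x224 RGB batch; the output is ImageNet-1k logits.
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])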