@@ -29,7 +29,7 @@
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg
+from .helpers import load_pretrained
 from .layers import DropPath, to_2tuple, trunc_normal_
 from .resnet import resnet26d, resnet50d
 from .registry import register_model
@@ -48,7 +48,9 @@ def _cfg(url='', **kwargs):
 
 default_cfgs = {
     # patch models
-    'vit_small_patch16_224': _cfg(),
+    'vit_small_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
+    ),
     'vit_base_patch16_224': _cfg(),
     'vit_base_patch16_384': _cfg(input_size=(3, 384, 384)),
     'vit_base_patch32_384': _cfg(input_size=(3, 384, 384)),
@@ -271,6 +273,9 @@ def forward(self, x, attn_mask=None):
 def vit_small_patch16_224(pretrained=False, **kwargs):
     model = VisionTransformer(patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., **kwargs)
     model.default_cfg = default_cfgs['vit_small_patch16_224']
+    if pretrained:
+        load_pretrained(
+            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
     return model
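For context on the change above: in timm, each registered model carries a default_cfg dict whose 'url' entry points at released weights, and the model factory gates weight loading on the pretrained flag. The sketch below shows roughly what that load path does, assuming torch.hub.load_state_dict_from_url as the download helper; the name load_pretrained_sketch and the head-filtering details are illustrative, not timm's exact implementation.

import torch.hub

def load_pretrained_sketch(model, num_classes=1000, in_chans=3):
    # Illustrative only: timm's real load_pretrained also adapts the first
    # conv when in_chans != 3 and supports per-model filter functions.
    cfg = model.default_cfg
    state_dict = torch.hub.load_state_dict_from_url(
        cfg['url'], map_location='cpu', progress=True)
    strict = True
    if num_classes != cfg.get('num_classes', 1000):
        # Drop the classifier weights when the requested head size differs
        # from the pretrained checkpoint's head.
        classifier = cfg.get('classifier', 'head')
        state_dict.pop(classifier + '.weight', None)
        state_dict.pop(classifier + '.bias', None)
        strict = False
    model.load_state_dict(state_dict, strict=strict)

With the URL registered, the entry point can then be exercised through the factory, e.g. timm.create_model('vit_small_patch16_224', pretrained=True).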