Skip to content

Commit f31933c

Browse files
committed
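Initial Vision Transformer implementation with patch and hybrid variants. Refactor tuple helpers (tup_single/tup_pair/tup_triple/tup_quadruple/ntup renamed to to_1tuple/to_2tuple/to_3tuple/to_4tuple/to_ntuple).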
1 parent 9305313 commit f31933c

File tree

11 files changed

+408
-27
lines changed

11 files changed

+408
-27
lines changed

tests/test_models.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,9 @@
1515

1616
if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
1717
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
18-
EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d']
18+
EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d', 'vit_*']
1919
else:
20-
EXCLUDE_FILTERS = []
20+
EXCLUDE_FILTERS = ['vit_*']
2121
MAX_FWD_SIZE = 384
2222
MAX_BWD_SIZE = 128
2323
MAX_FWD_FEAT_SIZE = 448

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
from .senet import *
2222
from .sknet import *
2323
from .tresnet import *
24+
from .vision_transformer import *
2425
from .vovnet import *
2526
from .xception import *
2627
from .xception_aligned import *

timm/models/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
1717
from .eca import EcaModule, CecaModule
1818
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
19+
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple
1920
from .inplace_abn import InplaceAbn
2021
from .mixed_conv2d import MixedConv2d
2122
from .norm_act import BatchNormAct2d

timm/models/layers/cond_conv2d.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
from torch import nn as nn
1414
from torch.nn import functional as F
1515

16-
from .helpers import tup_pair
16+
from .helpers import to_2tuple
1717
from .conv2d_same import conv2d_same
1818
from .padding import get_padding_value
1919

@@ -46,13 +46,13 @@ def __init__(self, in_channels, out_channels, kernel_size=3,
4646

4747
self.in_channels = in_channels
4848
self.out_channels = out_channels
49-
self.kernel_size = tup_pair(kernel_size)
50-
self.stride = tup_pair(stride)
49+
self.kernel_size = to_2tuple(kernel_size)
50+
self.stride = to_2tuple(stride)
5151
padding_val, is_padding_dynamic = get_padding_value(
5252
padding, kernel_size, stride=stride, dilation=dilation)
5353
self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript
54-
self.padding = tup_pair(padding_val)
55-
self.dilation = tup_pair(dilation)
54+
self.padding = to_2tuple(padding_val)
55+
self.dilation = to_2tuple(dilation)
5656
self.groups = groups
5757
self.num_experts = num_experts
5858

timm/models/layers/drop.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -150,7 +150,8 @@ def drop_path(x, drop_prob: float = 0., training: bool = False):
150150
if drop_prob == 0. or not training:
151151
return x
152152
keep_prob = 1 - drop_prob
153-
random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device)
153+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
154+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
154155
random_tensor.floor_() # binarize
155156
output = x.div(keep_prob) * random_tensor
156157
return output

timm/models/layers/helpers.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -15,11 +15,11 @@ def parse(x):
1515
return parse
1616

1717

18-
tup_single = _ntuple(1)
19-
tup_pair = _ntuple(2)
20-
tup_triple = _ntuple(3)
21-
tup_quadruple = _ntuple(4)
22-
ntup = _ntuple
18+
to_1tuple = _ntuple(1)
19+
to_2tuple = _ntuple(2)
20+
to_3tuple = _ntuple(3)
21+
to_4tuple = _ntuple(4)
22+
to_ntuple = _ntuple
2323

2424

2525

timm/models/layers/median_pool.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
"""
44
import torch.nn as nn
55
import torch.nn.functional as F
6-
from .helpers import tup_pair, tup_quadruple
6+
from .helpers import to_2tuple, to_4tuple
77

88

99
class MedianPool2d(nn.Module):
@@ -17,9 +17,9 @@ class MedianPool2d(nn.Module):
1717
"""
1818
def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
1919
super(MedianPool2d, self).__init__()
20-
self.k = tup_pair(kernel_size)
21-
self.stride = tup_pair(stride)
22-
self.padding = tup_quadruple(padding) # convert to l, r, t, b
20+
self.k = to_2tuple(kernel_size)
21+
self.stride = to_2tuple(stride)
22+
self.padding = to_4tuple(padding) # convert to l, r, t, b
2323
self.same = same
2424

2525
def _padding(self, x):

timm/models/layers/pool2d_same.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import torch.nn.functional as F
88
from typing import List, Tuple, Optional
99

10-
from .helpers import tup_pair
10+
from .helpers import to_2tuple
1111
from .padding import pad_same, get_padding_value
1212

1313

@@ -22,8 +22,8 @@ class AvgPool2dSame(nn.AvgPool2d):
2222
""" Tensorflow like 'SAME' wrapper for 2D average pooling
2323
"""
2424
def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
25-
kernel_size = tup_pair(kernel_size)
26-
stride = tup_pair(stride)
25+
kernel_size = to_2tuple(kernel_size)
26+
stride = to_2tuple(stride)
2727
super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
2828

2929
def forward(self, x):
@@ -42,9 +42,9 @@ class MaxPool2dSame(nn.MaxPool2d):
4242
""" Tensorflow like 'SAME' wrapper for 2D max pooling
4343
"""
4444
def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True):
45-
kernel_size = tup_pair(kernel_size)
46-
stride = tup_pair(stride)
47-
dilation = tup_pair(dilation)
45+
kernel_size = to_2tuple(kernel_size)
46+
stride = to_2tuple(stride)
47+
dilation = to_2tuple(dilation)
4848
super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode, count_include_pad)
4949

5050
def forward(self, x):

timm/models/rexnet.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
from .helpers import build_model_with_cfg
1818
from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath
1919
from .registry import register_model
20+
from .efficientnet_builder import efficientnet_init_weights
2021

2122

2223
def _cfg(url=''):
@@ -186,7 +187,7 @@ def __init__(self, in_chans=3, num_classes=1000, global_pool='avg', output_strid
186187

187188
self.head = ClassifierHead(self.num_features, num_classes, global_pool, drop_rate)
188189

189-
# FIXME weight init, the original appears to use PyTorch defaults
190+
efficientnet_init_weights(self)
190191

191192
def get_classifier(self):
192193
return self.head.fc

0 commit comments

Comments
 (0)