Commit 70ae7f0

Merge pull request #250 from rwightman/vision_transformer

Vision Transformer

2 parents 80078c4 + be53107, commit 70ae7f0

15 files changed: +434 -33 lines

Diff for: README.md

+9
@@ -2,6 +2,12 @@
 
 ## What's New
 
+### Oct 13, 2020
+* Initial impl of Vision Transformer models. Both patch and hybrid (CNN backbone) variants. Currently trying to train...
+* Adafactor and AdaHessian (FP32 only, no AMP) optimizers
+* EdgeTPU-M (`efficientnet_em`) model trained in PyTorch, 79.3 top-1
+* Pip release, doc updates pending a few more changes...
+
 ### Sept 18, 2020
 * New ResNet 'D' weights. 72.7 (top-1) ResNet-18-D, 77.1 ResNet-34-D, 80.5 ResNet-50-D
 * Added a few untrained defs for other ResNet models (66D, 101D, 152D, 200/200D)
@@ -124,6 +130,7 @@ A full version of the list below with source links can be found in the [document
 * SelecSLS - https://arxiv.org/abs/1907.00837
 * Selective Kernel Networks - https://arxiv.org/abs/1903.06586
 * TResNet - https://arxiv.org/abs/2003.13630
+* Vision Transformer - https://openreview.net/forum?id=YicbFdNTTy
 * VovNet V2 and V1 - https://arxiv.org/abs/1911.06667
 * Xception - https://arxiv.org/abs/1610.02357
 * Xception (Modified Aligned, Gluon) - https://arxiv.org/abs/1802.02611
@@ -162,6 +169,8 @@ Several (less common) features that I often utilize in my projects are included.
 * `lookahead` adapted from impl by [Liam](https://github.com/alphadl/lookahead.pytorch) (https://arxiv.org/abs/1907.08610)
 * `fused<name>` optimizers by name with [NVIDIA Apex](https://github.com/NVIDIA/apex/tree/master/apex/optimizers) installed
 * `adamp` and `sgdp` by [Naver ClovAI](https://github.com/clovaai) (https://arxiv.org/abs/2006.08217)
+* `adafactor` adapted from [FAIRSeq impl](https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py) (https://arxiv.org/abs/1804.04235)
+* `adahessian` by [David Samuel](https://github.com/davda54/ada-hessian) (https://arxiv.org/abs/2006.00719)
 * Random Erasing from [Zhun Zhong](https://github.com/zhunzhong07/Random-Erasing/blob/master/transforms.py) (https://arxiv.org/abs/1708.04896)
 * Mixup (https://arxiv.org/abs/1710.09412)
 * CutMix (https://arxiv.org/abs/1905.04899)
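
Note: a minimal sketch of using one of the newly listed optimizers directly. The class name Adafactor and its export from timm.optim are assumptions based on the changelog entry above; FAIRSeq-style defaults (relative step sizing, no explicit lr) are used to keep the example simple.

import torch
from timm.optim import Adafactor  # assumed export, added alongside this PR

model = torch.nn.Linear(16, 4)
optimizer = Adafactor(model.parameters())  # defaults: relative step size, no explicit lr
loss = model(torch.randn(8, 16)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()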

Diff for: avg_checkpoints.py

+5 -1

@@ -103,7 +103,11 @@ def main():
             v = v.clamp(float32_info.min, float32_info.max)
         final_state_dict[k] = v.to(dtype=torch.float32)
 
-    torch.save(final_state_dict, args.output)
+    try:
+        torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
+    except:
+        torch.save(final_state_dict, args.output)
+
     with open(args.output, 'rb') as f:
         sha_hash = hashlib.sha256(f.read()).hexdigest()
     print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash))

Diff for: clean_checkpoint.py

+5 -1

@@ -57,7 +57,11 @@ def main():
             new_state_dict[name] = v
         print("=> Loaded state_dict from '{}'".format(args.checkpoint))
 
-        torch.save(new_state_dict, _TEMP_NAME)
+        try:
+            torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
+        except:
+            torch.save(new_state_dict, _TEMP_NAME)
+
         with open(_TEMP_NAME, 'rb') as f:
             sha_hash = hashlib.sha256(f.read()).hexdigest()
Diff for: tests/test_models.py

+3 -3

@@ -15,9 +15,9 @@
 
 if 'GITHUB_ACTIONS' in os.environ:  # and 'Linux' in platform.system():
     # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
-    EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d']
+    EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d', 'vit_*']
 else:
-    EXCLUDE_FILTERS = []
+    EXCLUDE_FILTERS = ['vit_*']
 MAX_FWD_SIZE = 384
 MAX_BWD_SIZE = 128
 MAX_FWD_FEAT_SIZE = 448
@@ -68,7 +68,7 @@ def test_model_backward(model_name, batch_size):
 
 
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models())
+@pytest.mark.parametrize('model_name', list_models(exclude_filters=['vit_*']))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_default_cfgs(model_name, batch_size):
     """Run a single forward pass with each model"""

Diff for: timm/models/__init__.py

+1
@@ -21,6 +21,7 @@
 from .senet import *
 from .sknet import *
 from .tresnet import *
+from .vision_transformer import *
 from .vovnet import *
 from .xception import *
 from .xception_aligned import *
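
With vision_transformer registered in timm.models, the new defs are reachable through the usual factory. A small sketch; the model name vit_base_patch16_224 is an assumption taken from the new module, and pretrained weights were still being trained at this point, hence pretrained=False.

import torch
from timm import create_model

model = create_model('vit_base_patch16_224', pretrained=False)
logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])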

Diff for: timm/models/layers/__init__.py

+1
@@ -16,6 +16,7 @@
 from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
 from .eca import EcaModule, CecaModule
 from .evo_norm import EvoNormBatch2d, EvoNormSample2d
+from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple
 from .inplace_abn import InplaceAbn
 from .mixed_conv2d import MixedConv2d
 from .norm_act import BatchNormAct2d

Diff for: timm/models/layers/cond_conv2d.py

+5 -5

@@ -13,7 +13,7 @@
 from torch import nn as nn
 from torch.nn import functional as F
 
-from .helpers import tup_pair
+from .helpers import to_2tuple
 from .conv2d_same import conv2d_same
 from .padding import get_padding_value
 
@@ -46,13 +46,13 @@ def __init__(self, in_channels, out_channels, kernel_size=3,
 
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.kernel_size = tup_pair(kernel_size)
-        self.stride = tup_pair(stride)
+        self.kernel_size = to_2tuple(kernel_size)
+        self.stride = to_2tuple(stride)
         padding_val, is_padding_dynamic = get_padding_value(
             padding, kernel_size, stride=stride, dilation=dilation)
         self.dynamic_padding = is_padding_dynamic  # if in forward to work with torchscript
-        self.padding = tup_pair(padding_val)
-        self.dilation = tup_pair(dilation)
+        self.padding = to_2tuple(padding_val)
+        self.dilation = to_2tuple(dilation)
         self.groups = groups
         self.num_experts = num_experts

Diff for: timm/models/layers/drop.py

+2 -1

@@ -150,7 +150,8 @@ def drop_path(x, drop_prob: float = 0., training: bool = False):
     if drop_prob == 0. or not training:
         return x
     keep_prob = 1 - drop_prob
-    random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device)
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
     random_tensor.floor_()  # binarize
     output = x.div(keep_prob) * random_tensor
     return output
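
The shape change above is what lets the same stochastic-depth helper serve both conv feature maps and ViT token sequences: the mask is one value per sample, broadcast over whatever trailing dimensions the tensor has. A short usage sketch:

import torch
from timm.models.layers import drop_path

feats = torch.randn(8, 64, 56, 56)   # (B, C, H, W) conv feature map
tokens = torch.randn(8, 197, 768)    # (B, tokens, dim) ViT block input

assert drop_path(feats, drop_prob=0.1, training=True).shape == feats.shape
assert drop_path(tokens, drop_prob=0.1, training=True).shape == tokens.shape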

Diff for: timm/models/layers/helpers.py

+5 -5

@@ -15,11 +15,11 @@ def parse(x):
     return parse
 
 
-tup_single = _ntuple(1)
-tup_pair = _ntuple(2)
-tup_triple = _ntuple(3)
-tup_quadruple = _ntuple(4)
-ntup = _ntuple
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
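
For reference, the renamed helpers are thin wrappers around an n-tuple parser: a scalar is repeated n times, an existing iterable is passed through unchanged. A minimal self-contained sketch of that behaviour (the real implementation lives in timm/models/layers/helpers.py):

from itertools import repeat
import collections.abc


def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


to_2tuple = _ntuple(2)

assert to_2tuple(3) == (3, 3)        # scalar -> pair
assert to_2tuple((3, 5)) == (3, 5)   # iterable passed through unchanged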

Diff for: timm/models/layers/median_pool.py

+4 -4

@@ -3,7 +3,7 @@
 """
 import torch.nn as nn
 import torch.nn.functional as F
-from .helpers import tup_pair, tup_quadruple
+from .helpers import to_2tuple, to_4tuple
 
 
 class MedianPool2d(nn.Module):
@@ -17,9 +17,9 @@ class MedianPool2d(nn.Module):
     """
     def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
        super(MedianPool2d, self).__init__()
-        self.k = tup_pair(kernel_size)
-        self.stride = tup_pair(stride)
-        self.padding = tup_quadruple(padding)  # convert to l, r, t, b
+        self.k = to_2tuple(kernel_size)
+        self.stride = to_2tuple(stride)
+        self.padding = to_4tuple(padding)  # convert to l, r, t, b
         self.same = same
 
     def _padding(self, x):

Diff for: timm/models/layers/pool2d_same.py

+6 -6

@@ -7,7 +7,7 @@
 import torch.nn.functional as F
 from typing import List, Tuple, Optional
 
-from .helpers import tup_pair
+from .helpers import to_2tuple
 from .padding import pad_same, get_padding_value
 
 
@@ -22,8 +22,8 @@ class AvgPool2dSame(nn.AvgPool2d):
     """ Tensorflow like 'SAME' wrapper for 2D average pooling
     """
     def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
-        kernel_size = tup_pair(kernel_size)
-        stride = tup_pair(stride)
+        kernel_size = to_2tuple(kernel_size)
+        stride = to_2tuple(stride)
         super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
 
     def forward(self, x):
@@ -42,9 +42,9 @@ class MaxPool2dSame(nn.MaxPool2d):
     """ Tensorflow like 'SAME' wrapper for 2D max pooling
     """
     def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True):
-        kernel_size = tup_pair(kernel_size)
-        stride = tup_pair(stride)
-        dilation = tup_pair(dilation)
+        kernel_size = to_2tuple(kernel_size)
+        stride = to_2tuple(stride)
+        dilation = to_2tuple(dilation)
         super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode, count_include_pad)
 
     def forward(self, x):

Diff for: timm/models/rexnet.py

+2 -1

@@ -17,6 +17,7 @@
 from .helpers import build_model_with_cfg
 from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath
 from .registry import register_model
+from .efficientnet_builder import efficientnet_init_weights
 
 
 def _cfg(url=''):
@@ -186,7 +187,7 @@ def __init__(self, in_chans=3, num_classes=1000, global_pool='avg', output_strid
 
         self.head = ClassifierHead(self.num_features, num_classes, global_pool, drop_rate)
 
-        # FIXME weight init, the original appears to use PyTorch defaults
+        efficientnet_init_weights(self)
 
     def get_classifier(self):
         return self.head.fc
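
The swap above replaces the FIXME with the shared EfficientNet-style initializer, which walks the module tree and re-initializes the Conv2d, BatchNorm2d and Linear layers it finds. A small usage sketch; the one-argument call mirrors its use here, and the exact keyword options are not shown.

import torch.nn as nn
from timm.models.efficientnet_builder import efficientnet_init_weights

block = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
    nn.BatchNorm2d(32),
    nn.Flatten(),
    nn.Linear(32 * 112 * 112, 10),
)
efficientnet_init_weights(block)  # re-initializes conv/BN/linear layers in EfficientNet-paper style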
