Skip to content

Commit 6e5553d

Browse files
authored
Add ConvNeXt-V2 support (model additions and weights) (huggingface#1614)
* Add ConvNeXt-V2 support (model additions and weights) * ConvNeXt-V2 weights on HF Hub, tweaking some tests * Update README, fixing convnextv2 tests
1 parent 3698e79 commit 6e5553d

File tree

6 files changed

+386
-28
lines changed

6 files changed

+386
-28
lines changed

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ For a few months now, `timm` has been part of the Hugging Face ecosystem. Yearly
2828
If you have a couple of minutes and want to participate in shaping the future of the ecosystem, please share your thoughts:
2929
[**hf.co/oss-survey**](https://hf.co/oss-survey) 🙏
3030

31+
### Jan 5, 2023
32+
* ConvNeXt-V2 models and weights added to existing `convnext.py`
33+
* Paper: [ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](http://arxiv.org/abs/2301.00808)
34+
* Reference impl: https://github.com/facebookresearch/ConvNeXt-V2 (NOTE: weights currently CC-BY-NC)
35+
3136
### Dec 23, 2022 🎄☃
3237
* Add FlexiViT models and weights from https://github.com/google-research/big_vision (check out paper at https://arxiv.org/abs/2212.08013)
3338
* NOTE currently resizing is static on model creation, on-the-fly dynamic / train patch size sampling is a WIP
@@ -396,6 +401,7 @@ A full version of the list below with source links can be found in the [document
396401
* CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399
397402
* CoAtNet (Convolution and Attention) - https://arxiv.org/abs/2106.04803
398403
* ConvNeXt - https://arxiv.org/abs/2201.03545
404+
* ConvNeXt-V2 - http://arxiv.org/abs/2301.00808
399405
* ConViT (Soft Convolutional Inductive Biases Vision Transformers)- https://arxiv.org/abs/2103.10697
400406
* CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929
401407
* DeiT - https://arxiv.org/abs/2012.12877
@@ -418,6 +424,7 @@ A full version of the list below with source links can be found in the [document
418424
* Single-Path NAS - https://arxiv.org/abs/1904.02877
419425
* TinyNet - https://arxiv.org/abs/2010.14819
420426
* EVA - https://arxiv.org/abs/2211.07636
427+
* FlexiViT - https://arxiv.org/abs/2212.08013
421428
* GCViT (Global Context Vision Transformer) - https://arxiv.org/abs/2206.09959
422429
* GhostNet - https://arxiv.org/abs/1911.11907
423430
* gMLP - https://arxiv.org/abs/2105.08050

tests/test_models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
3939
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
4040
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
41-
'swin*giant*']
41+
'swin*giant*', 'convnextv2_huge*']
4242
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*']
4343
else:
4444
EXCLUDE_FILTERS = []
@@ -129,7 +129,7 @@ def test_model_backward(model_name, batch_size):
129129

130130

131131
@pytest.mark.timeout(300)
132-
@pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS))
132+
@pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS, include_tags=True))
133133
@pytest.mark.parametrize('batch_size', [1])
134134
def test_model_default_cfgs(model_name, batch_size):
135135
"""Run a single forward pass with each model"""
@@ -191,7 +191,7 @@ def test_model_default_cfgs(model_name, batch_size):
191191

192192

193193
@pytest.mark.timeout(300)
194-
@pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS, exclude_filters=NON_STD_EXCLUDE_FILTERS))
194+
@pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS, exclude_filters=NON_STD_EXCLUDE_FILTERS, include_tags=True))
195195
@pytest.mark.parametrize('batch_size', [1])
196196
def test_model_default_cfgs_non_std(model_name, batch_size):
197197
"""Run a single forward pass with each model"""
@@ -304,7 +304,7 @@ def test_model_forward_torchscript(model_name, batch_size):
304304

305305

306306
@pytest.mark.timeout(120)
307-
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS))
307+
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS, include_tags=True))
308308
@pytest.mark.parametrize('batch_size', [1])
309309
def test_model_forward_features(model_name, batch_size):
310310
"""Run a single forward pass with each model in feature extraction mode"""

timm/layers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from .inplace_abn import InplaceAbn
2727
from .linear import Linear
2828
from .mixed_conv2d import MixedConv2d
29-
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
29+
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp
3030
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
3131
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
3232
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm

timm/layers/grn.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
""" Global Response Normalization Module
2+
3+
Based on the GRN layer presented in
4+
`ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808
5+
6+
This implementation
7+
* works for both NCHW and NHWC tensor layouts
8+
* uses affine param names matching existing torch norm layers
9+
* slightly improves eager mode performance via fused addcmul
10+
11+
Hacked together by / Copyright 2023 Ross Wightman
12+
"""
13+
14+
import torch
15+
from torch import nn as nn
16+
17+
18+
class GlobalResponseNorm(nn.Module):
19+
""" Global Response Normalization layer
20+
"""
21+
def __init__(self, dim, eps=1e-6, channels_last=True):
22+
super().__init__()
23+
self.eps = eps
24+
if channels_last:
25+
self.spatial_dim = (1, 2)
26+
self.channel_dim = -1
27+
self.wb_shape = (1, 1, 1, -1)
28+
else:
29+
self.spatial_dim = (2, 3)
30+
self.channel_dim = 1
31+
self.wb_shape = (1, -1, 1, 1)
32+
33+
self.weight = nn.Parameter(torch.zeros(dim))
34+
self.bias = nn.Parameter(torch.zeros(dim))
35+
36+
def forward(self, x):
37+
x_g = x.norm(p=2, dim=self.spatial_dim, keepdim=True)
38+
x_n = x_g / (x_g.mean(dim=self.channel_dim, keepdim=True) + self.eps)
39+
return x + torch.addcmul(self.bias.view(self.wb_shape), self.weight.view(self.wb_shape), x * x_n)

timm/layers/mlp.py

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,38 @@
22
33
Hacked together by / Copyright 2020 Ross Wightman
44
"""
5+
from functools import partial
6+
57
from torch import nn as nn
68

9+
from .grn import GlobalResponseNorm
710
from .helpers import to_2tuple
811

912

1013
class Mlp(nn.Module):
1114
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
1215
"""
13-
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
16+
def __init__(
17+
self,
18+
in_features,
19+
hidden_features=None,
20+
out_features=None,
21+
act_layer=nn.GELU,
22+
bias=True,
23+
drop=0.,
24+
use_conv=False,
25+
):
1426
super().__init__()
1527
out_features = out_features or in_features
1628
hidden_features = hidden_features or in_features
1729
bias = to_2tuple(bias)
1830
drop_probs = to_2tuple(drop)
31+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
1932

20-
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
33+
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
2134
self.act = act_layer()
2235
self.drop1 = nn.Dropout(drop_probs[0])
23-
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
36+
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
2437
self.drop2 = nn.Dropout(drop_probs[1])
2538

2639
def forward(self, x):
@@ -36,18 +49,29 @@ class GluMlp(nn.Module):
3649
""" MLP w/ GLU style gating
3750
See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
3851
"""
39-
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, bias=True, drop=0.):
52+
def __init__(
53+
self,
54+
in_features,
55+
hidden_features=None,
56+
out_features=None,
57+
act_layer=nn.Sigmoid,
58+
bias=True,
59+
drop=0.,
60+
use_conv=False,
61+
):
4062
super().__init__()
4163
out_features = out_features or in_features
4264
hidden_features = hidden_features or in_features
4365
assert hidden_features % 2 == 0
4466
bias = to_2tuple(bias)
4567
drop_probs = to_2tuple(drop)
68+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
69+
self.chunk_dim = 1 if use_conv else -1
4670

47-
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
71+
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
4872
self.act = act_layer()
4973
self.drop1 = nn.Dropout(drop_probs[0])
50-
self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias[1])
74+
self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1])
5175
self.drop2 = nn.Dropout(drop_probs[1])
5276

5377
def init_weights(self):
@@ -58,7 +82,7 @@ def init_weights(self):
5882

5983
def forward(self, x):
6084
x = self.fc1(x)
61-
x, gates = x.chunk(2, dim=-1)
85+
x, gates = x.chunk(2, dim=self.chunk_dim)
6286
x = x * self.act(gates)
6387
x = self.drop1(x)
6488
x = self.fc2(x)
@@ -70,8 +94,15 @@ class GatedMlp(nn.Module):
7094
""" MLP as used in gMLP
7195
"""
7296
def __init__(
73-
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
74-
gate_layer=None, bias=True, drop=0.):
97+
self,
98+
in_features,
99+
hidden_features=None,
100+
out_features=None,
101+
act_layer=nn.GELU,
102+
gate_layer=None,
103+
bias=True,
104+
drop=0.,
105+
):
75106
super().__init__()
76107
out_features = out_features or in_features
77108
hidden_features = hidden_features or in_features
@@ -104,8 +135,15 @@ class ConvMlp(nn.Module):
104135
""" MLP using 1x1 convs that keeps spatial dims
105136
"""
106137
def __init__(
107-
self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU,
108-
norm_layer=None, bias=True, drop=0.):
138+
self,
139+
in_features,
140+
hidden_features=None,
141+
out_features=None,
142+
act_layer=nn.ReLU,
143+
norm_layer=None,
144+
bias=True,
145+
drop=0.,
146+
):
109147
super().__init__()
110148
out_features = out_features or in_features
111149
hidden_features = hidden_features or in_features
@@ -124,3 +162,40 @@ def forward(self, x):
124162
x = self.drop(x)
125163
x = self.fc2(x)
126164
return x
165+
166+
167+
class GlobalResponseNormMlp(nn.Module):
168+
""" MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d
169+
"""
170+
def __init__(
171+
self,
172+
in_features,
173+
hidden_features=None,
174+
out_features=None,
175+
act_layer=nn.GELU,
176+
bias=True,
177+
drop=0.,
178+
use_conv=False,
179+
):
180+
super().__init__()
181+
out_features = out_features or in_features
182+
hidden_features = hidden_features or in_features
183+
bias = to_2tuple(bias)
184+
drop_probs = to_2tuple(drop)
185+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
186+
187+
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
188+
self.act = act_layer()
189+
self.drop1 = nn.Dropout(drop_probs[0])
190+
self.grn = GlobalResponseNorm(hidden_features, channels_last=not use_conv)
191+
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
192+
self.drop2 = nn.Dropout(drop_probs[1])
193+
194+
def forward(self, x):
195+
x = self.fc1(x)
196+
x = self.act(x)
197+
x = self.drop1(x)
198+
x = self.grn(x)
199+
x = self.fc2(x)
200+
x = self.drop2(x)
201+
return x

0 commit comments

Comments
 (0)