Commit e41125c — Merge pull request #2209 from huggingface/fcossio-vit-maxpool
"ViT pooling refactor"
2 parents: a224668 + 6254dfa

18 files changed: 73 additions, 34 deletions

requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -2,4 +2,5 @@ torch>=1.7
 torchvision
 pyyaml
 huggingface_hub
-safetensors>=0.2
+safetensors>=0.2
+numpy<2.0

timm/models/efficientnet.py

Lines changed: 1 addition & 1 deletion
@@ -156,7 +156,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.classifier
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.classifier = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

timm/models/ghostnet.py

Lines changed: 1 addition & 1 deletion
@@ -273,7 +273,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.classifier
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         # cannot meaningfully change pooling of efficient head after creation
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)

timm/models/hrnet.py

Lines changed: 1 addition & 1 deletion
@@ -739,7 +739,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.classifier
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.classifier = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

timm/models/inception_v4.py

Lines changed: 1 addition & 1 deletion
@@ -280,7 +280,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.last_linear
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.last_linear = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

timm/models/metaformer.py

Lines changed: 3 additions & 3 deletions
@@ -26,9 +26,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from collections import OrderedDict
 from functools import partial
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -548,7 +548,7 @@ def __init__(
         # if using MlpHead, dropout is handled by MlpHead
         if num_classes > 0:
             if self.use_mlp_head:
-                # FIXME hidden size
+                # FIXME not actually returning mlp hidden state right now as pre-logits.
                 final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate)
                 self.head_hidden_size = self.num_features
             else:
@@ -583,7 +583,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.head.fc
 
-    def reset_classifier(self, num_classes=0, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         if global_pool is not None:
             self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
             self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()

timm/models/nasnet.py

Lines changed: 1 addition & 1 deletion
@@ -518,7 +518,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.last_linear
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.last_linear = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

timm/models/pnasnet.py

Lines changed: 1 addition & 1 deletion
@@ -307,7 +307,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.last_linear
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.last_linear = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

timm/models/regnet.py

Lines changed: 1 addition & 1 deletion
@@ -514,7 +514,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_intermediates(

timm/models/rexnet.py

Lines changed: 2 additions & 1 deletion
@@ -12,6 +12,7 @@
 
 from functools import partial
 from math import ceil
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -229,7 +230,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)
 

timm/models/selecsls.py

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.fc
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
 

timm/models/senet.py

Lines changed: 1 addition & 1 deletion
@@ -337,7 +337,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.last_linear
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.last_linear = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)

timm/models/vision_transformer.py

Lines changed: 38 additions & 10 deletions
@@ -386,6 +386,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self._forward(x)
 
 
+def global_pool_nlc(
+        x: torch.Tensor,
+        pool_type: str = 'token',
+        num_prefix_tokens: int = 1,
+        reduce_include_prefix: bool = False,
+):
+    if not pool_type:
+        return x
+
+    if pool_type == 'token':
+        x = x[:, 0]  # class token
+    else:
+        x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
+        if pool_type == 'avg':
+            x = x.mean(dim=1)
+        elif pool_type == 'avgmax':
+            x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
+        elif pool_type == 'max':
+            x = x.amax(dim=1)
+        else:
+            assert not pool_type, f'Unknown pool type {pool_type}'
+
+    return x
+
+
 class VisionTransformer(nn.Module):
     """ Vision Transformer
 
@@ -400,7 +425,7 @@ def __init__(
             patch_size: Union[int, Tuple[int, int]] = 16,
             in_chans: int = 3,
             num_classes: int = 1000,
-            global_pool: Literal['', 'avg', 'token', 'map'] = 'token',
+            global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
             embed_dim: int = 768,
             depth: int = 12,
             num_heads: int = 12,
@@ -459,10 +484,10 @@ def __init__(
             block_fn: Transformer block layer.
         """
         super().__init__()
-        assert global_pool in ('', 'avg', 'token', 'map')
+        assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
         assert class_token or global_pool != 'token'
         assert pos_embed in ('', 'none', 'learn')
-        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
+        use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
         norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
         act_layer = get_act_layer(act_layer) or nn.GELU
 
@@ -596,10 +621,10 @@ def set_grad_checkpointing(self, enable: bool = True) -> None:
     def get_classifier(self) -> nn.Module:
         return self.head
 
-    def reset_classifier(self, num_classes: int, global_pool = None) -> None:
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
-            assert global_pool in ('', 'avg', 'token', 'map')
+            assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
             if global_pool == 'map' and self.attn_pool is None:
                 assert False, "Cannot currently add attention pooling in reset_classifier()."
             elif global_pool != 'map ' and self.attn_pool is not None:
@@ -756,13 +781,16 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.norm(x)
         return x
 
-    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+    def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
         if self.attn_pool is not None:
             x = self.attn_pool(x)
-        elif self.global_pool == 'avg':
-            x = x[:, self.num_prefix_tokens:].mean(dim=1)
-        elif self.global_pool:
-            x = x[:, 0]  # class token
+            return x
+        pool_type = self.global_pool if pool_type is None else pool_type
+        x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
+        return x
+
+    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+        x = self.pool(x)
         x = self.fc_norm(x)
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)

timm/models/vision_transformer_relpos.py

Lines changed: 1 addition & 1 deletion
@@ -381,7 +381,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.head
 
-    def reset_classifier(self, num_classes: int, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('', 'avg', 'token')

timm/models/vision_transformer_sam.py

Lines changed: 1 addition & 1 deletion
@@ -536,7 +536,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.head
 
-    def reset_classifier(self, num_classes=0, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.head.reset(num_classes, global_pool)
 
     def forward_intermediates(

timm/models/vovnet.py

Lines changed: 15 additions & 6 deletions
@@ -11,7 +11,7 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 
-from typing import List
+from typing import List, Optional
 
 import torch
 import torch.nn as nn
@@ -134,9 +134,17 @@ def __init__(
             else:
                 drop_path = None
             blocks += [OsaBlock(
-                in_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0, depthwise=depthwise,
-                attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer, drop_path=drop_path)
-            ]
+                in_chs,
+                mid_chs,
+                out_chs,
+                layer_per_block,
+                residual=residual and i > 0,
+                depthwise=depthwise,
+                attn=attn if last_block else '',
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                drop_path=drop_path
+            )]
             in_chs = out_chs
         self.blocks = nn.Sequential(*blocks)
 
@@ -252,8 +260,9 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
-        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+    def reset_classifier(self, num_classes, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
+        self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x):
         x = self.stem(x)

timm/models/xception.py

Lines changed: 1 addition & 1 deletion
@@ -174,7 +174,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.fc
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
 

timm/models/xception_aligned.py

Lines changed: 1 addition & 1 deletion
@@ -274,7 +274,7 @@ def set_grad_checkpointing(self, enable=True):
     def get_classifier(self) -> nn.Module:
         return self.head.fc
 
-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_features(self, x):

Comments: 0