
Commit c719f7e

More forward_intermediates() updates
* add convnext, resnet, efficientformer, levit support
* remove keyword-only marker from forward_intermediates() args so that torchscript isn't broken for all :(
* use reset_classifier() consistently in prune_intermediate_layers()
1 parent 301d0bb commit c719f7e

21 files changed: +436 −139 lines
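For orientation, a minimal usage sketch of the forward_intermediates() API these changes extend; the model name, input shape, and argument values are illustrative assumptions, not part of this commit:

import torch
import timm

# convnext is one of the newly supported architectures ('convnext_tiny' is an assumed variant)
model = timm.create_model('convnext_tiny', pretrained=False)
x = torch.randn(1, 3, 224, 224)

# final features plus intermediates; indices=2 takes the last two feature levels
final, intermediates = model.forward_intermediates(x, indices=2)

# intermediates only, allowing early exit from the stage loop in eager mode
feats = model.forward_intermediates(x, indices=2, intermediates_only=True, stop_early=True)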

tests/test_models.py

Lines changed: 2 additions & 2 deletions
@@ -51,7 +51,7 @@
 FEAT_INTER_FILTERS = [
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
-    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3'
+    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet'
 ]

 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
@@ -429,7 +429,7 @@ def test_model_forward_intermediates(model_name, batch_size):
     feature_info = timm.models.FeatureInfo(model.feature_info, len(model.feature_info))
     expected_channels = feature_info.channels()
     expected_reduction = feature_info.reduction()
-    assert len(expected_channels) >= 4  # all models here should have at least 4 feature levels by default, some 5 or 6
+    assert len(expected_channels) >= 3  # all models here should have at least 3 feature levels

     input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
     if max(input_size) > MAX_FFEAT_SIZE:
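The bound drops from 4 to 3 because some of the newly covered models expose fewer feature levels; a hedged check (the levit_128 variant and its level count are assumptions, not stated in this diff):

import timm

model = timm.create_model('levit_128', features_only=True)
print(len(model.feature_info.channels()))  # expected to be 3 for levit's three stages (assumption)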

timm/models/beit.py

Lines changed: 3 additions & 3 deletions
@@ -404,7 +404,6 @@ def reset_classifier(self, num_classes, global_pool=None):
     def forward_intermediates(
             self,
             x: torch.Tensor,
-            *,
             indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             return_prefix_tokens: bool = False,
             norm: bool = False,
@@ -425,7 +424,7 @@ def forward_intermediates(
         Returns:

         """
-        assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
         reshape = output_fmt == 'NCHW'
         intermediates = []
         take_indices, max_index = feature_take_indices(len(self.blocks), indices)
@@ -437,6 +436,7 @@ def forward_intermediates(
         if self.pos_embed is not None:
             x = x + self.pos_embed
         x = self.pos_drop(x)
+
         rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
             blocks = self.blocks
@@ -482,7 +482,7 @@ def prune_intermediate_layers(
             self.norm = nn.Identity()
         if prune_head:
             self.fc_norm = nn.Identity()
-            self.head = nn.Identity()
+            self.reset_classifier(0, '')
         return take_indices

     def forward_features(self, x):
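The dropped bare `*` relates to the second commit-message bullet: TorchScript's frontend does not accept keyword-only arguments, so scripting any model whose method declared them failed. A minimal reproduction sketch (class and argument names hypothetical):

import torch
import torch.nn as nn

class Demo(nn.Module):
    def forward(self, x, *, scale: float = 1.0):  # keyword-only arg after bare *
        return x * scale

# torch.jit.script(Demo())  # expected to raise a NotSupportedError about keyword-only arguments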

timm/models/cait.py

Lines changed: 3 additions & 3 deletions
@@ -341,7 +341,6 @@ def reset_classifier(self, num_classes, global_pool=None):
     def forward_intermediates(
             self,
             x: torch.Tensor,
-            *,
             indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
             stop_early: bool = False,
@@ -358,7 +357,7 @@ def forward_intermediates(
             output_fmt: Shape of intermediate feature outputs
             intermediates_only: Only return intermediate features
         """
-        assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
         reshape = output_fmt == 'NCHW'
         intermediates = []
         take_indices, max_index = feature_take_indices(len(self.blocks), indices)
@@ -368,6 +367,7 @@ def forward_intermediates(
         x = self.patch_embed(x)
         x = x + self.pos_embed
         x = self.pos_drop(x)
+
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
             blocks = self.blocks
         else:
@@ -410,7 +410,7 @@ def prune_intermediate_layers(
             self.norm = nn.Identity()
         if prune_head:
             self.blocks_token_only = nn.ModuleList()  # prune token blocks with head
-            self.head = nn.Identity()
+            self.reset_classifier(0, '')
         return take_indices

     def forward_features(self, x):

timm/models/convnext.py

Lines changed: 67 additions & 1 deletion
@@ -39,7 +39,7 @@

 from collections import OrderedDict
 from functools import partial
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -49,6 +49,7 @@
     LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
 from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations

@@ -407,6 +408,71 @@ def get_classifier(self):
     def reset_classifier(self, num_classes=0, global_pool=None):
         self.head.reset(num_classes, global_pool)

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Union[int, List[int], Tuple[int]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
+
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        x = self.stem(x)
+        if feat_idx in take_indices:
+            intermediates.append(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index]
+        for stage in stages:
+            feat_idx += 1
+            x = stage(x)
+            if feat_idx in take_indices:
+                # NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.norm_pre(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
+        self.stages = self.stages[:max_index]  # truncate blocks w/ stem as idx 0
+        if prune_norm:
+            self.norm_pre = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         x = self.stages(x)
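Since the stem counts as index 0 here, stage outputs sit at indices 1 through len(stages); a hedged selection example (model name and shapes illustrative):

import torch
import timm

model = timm.create_model('convnext_tiny')
x = torch.randn(1, 3, 224, 224)
# take the stem output (index 0) and the second stage output (index 2)
feats = model.forward_intermediates(x, indices=[0, 2], intermediates_only=True)
print([tuple(f.shape) for f in feats])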

timm/models/efficientformer.py

Lines changed: 83 additions & 9 deletions
@@ -12,14 +12,15 @@

 Modifications and timm support by / Copyright 2022, Ross Wightman
 """
-from typing import Dict
+from typing import Dict, List, Tuple, Union

 import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp, ndgrid
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model

@@ -382,16 +383,19 @@ def __init__(
         prev_dim = embed_dims[0]

         # stochastic depth decay rule
+        self.num_stages = len(depths)
+        last_stage = self.num_stages - 1
         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
-        downsamples = downsamples or (False,) + (True,) * (len(depths) - 1)
+        downsamples = downsamples or (False,) + (True,) * (self.num_stages - 1)
         stages = []
-        for i in range(len(depths)):
+        self.feature_info = []
+        for i in range(self.num_stages):
             stage = EfficientFormerStage(
                 prev_dim,
                 embed_dims[i],
                 depths[i],
                 downsample=downsamples[i],
-                num_vit=num_vit if i == 3 else 0,
+                num_vit=num_vit if i == last_stage else 0,
                 pool_size=pool_size,
                 mlp_ratio=mlp_ratios,
                 act_layer=act_layer,
@@ -403,7 +407,7 @@ def __init__(
             )
             prev_dim = embed_dims[i]
             stages.append(stage)
-
+            self.feature_info += [dict(num_chs=embed_dims[i], reduction=2**(1+i), module=f'stages.{i}')]
         self.stages = nn.Sequential(*stages)

         # Classifier head
@@ -456,6 +460,76 @@ def reset_classifier(self, num_classes, global_pool=None):
     def set_distilled_training(self, enable=True):
         self.distilled_training = enable

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Union[int, List[int], Tuple[int]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.stem(x)
+        B, C, H, W = x.shape
+
+        last_idx = self.num_stages - 1
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+        feat_idx = 0
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx < last_idx:
+                B, C, H, W = x.shape
+            if feat_idx in take_indices:
+                if feat_idx == last_idx:
+                    x_inter = self.norm(x) if norm else x
+                    intermediates.append(x_inter.reshape(B, H // 2, W // 2, -1).permute(0, 3, 1, 2))
+                else:
+                    intermediates.append(x)

+        if intermediates_only:
+            return intermediates
+
+        if feat_idx == last_idx:
+            x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         x = self.stages(x)
@@ -534,13 +608,13 @@ def _cfg(url='', **kwargs):


 def _create_efficientformer(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for EfficientFormer models.')
-
+    out_indices = kwargs.pop('out_indices', 4)
     model = build_model_with_cfg(
         EfficientFormer, variant, pretrained,
         pretrained_filter_fn=_checkpoint_filter_fn,
-        **kwargs)
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
+        **kwargs,
+    )
     return model
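With the builder now passing a feature_cfg using the 'getter' feature class, features_only works instead of raising; a hedged sketch (the variant name and channel counts are assumptions, not stated in this diff):

import torch
import timm

model = timm.create_model('efficientformer_l1', features_only=True, out_indices=(0, 1, 2, 3))
feats = model(torch.randn(1, 3, 224, 224))
print(model.feature_info.channels())  # e.g. [48, 96, 224, 448] for the l1 variant (assumption)
print([tuple(f.shape) for f in feats])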

timm/models/efficientnet.py

Lines changed: 1 addition & 1 deletion
@@ -162,7 +162,6 @@ def reset_classifier(self, num_classes, global_pool='avg'):
     def forward_intermediates(
             self,
             x: torch.Tensor,
-            *,
             indices: Union[int, List[int], Tuple[int]] = None,
             norm: bool = False,
             stop_early: bool = False,
@@ -199,6 +198,7 @@ def forward_intermediates(
         x = self.bn1(x)
         if feat_idx in take_indices:
             intermediates.append(x)
+
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
             blocks = self.blocks
         else:
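With the keyword-only marker removed, indices can again be passed positionally, matching what TorchScript can compile; a small sketch (model name illustrative):

import torch
import timm

model = timm.create_model('efficientnet_b0')
# positional `indices` call, valid now that the bare * is gone
out, feats = model.forward_intermediates(torch.randn(1, 3, 224, 224), 3)
print(len(feats))  # last 3 feature levels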

timm/models/eva.py

Lines changed: 2 additions & 3 deletions
@@ -587,7 +587,6 @@ def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     def forward_intermediates(
             self,
             x: torch.Tensor,
-            *,
             indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             return_prefix_tokens: bool = False,
             norm: bool = False,
@@ -657,7 +656,7 @@ def prune_intermediate_layers(
             self.norm = nn.Identity()
         if prune_head:
             self.fc_norm = nn.Identity()
-            self.head = nn.Identity()
+            self.reset_classifier(0, '')
         return take_indices

     def forward_features(self, x):
@@ -718,7 +717,7 @@ def checkpoint_filter_fn(
            # fixed embedding no need to load buffer from checkpoint
            continue

-        # FIXME
+        # FIXME here while importing new weights, to remove
        # if k == 'cls_token':
        #     print('DEBUG: cls token -> reg')
        #     k = 'reg_token'
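Calling reset_classifier(0, '') rather than assigning nn.Identity() keeps the head bookkeeping (num_classes, pooling config) in sync after pruning; a hedged sketch (model name assumed):

import timm

model = timm.create_model('eva02_tiny_patch14_224')
model.prune_intermediate_layers(indices=2, prune_norm=True, prune_head=True)
assert model.num_classes == 0  # reset_classifier updates it; a bare nn.Identity() swap would not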
