Skip to content

Commit d71fbc4

Browse files
committed
Update davit.py
1 parent c3f13d0 commit d71fbc4

File tree

1 file changed

+1
-12
lines changed

1 file changed

+1
-12
lines changed

timm/models/davit.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -509,20 +509,15 @@ def __init__(
509509
stages.append(stage)
510510
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
511511

512-
513512
self.stages = nn.Sequential(*stages)
514513

515-
#self.norm = norm_layer(self.num_features)
516-
#self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
517-
518514
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
519515
# otherwise pool -> norm -> fc, the default DaViT order, similar to ConvNeXt
520516
# FIXME generalize this structure to ClassifierHead
521517
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
522518
self.head = nn.Sequential(OrderedDict([
523519
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
524520
('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
525-
#('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
526521
('drop', nn.Dropout(self.drop_rate)),
527522
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
528523

@@ -550,7 +545,6 @@ def get_classifier(self):
550545
def reset_classifier(self, num_classes, global_pool=None):
551546
if global_pool is not None:
552547
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
553-
#self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
554548
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
555549

556550
def forward_features(self, x):
@@ -563,11 +557,6 @@ def forward_features(self, x):
563557
return x
564558

565559
def forward_head(self, x, pre_logits: bool = False):
566-
#return self.head(x, pre_logits=pre_logits)
567-
#x = self.head.global_pool(x)
568-
#x = self.norms(x)
569-
#x = self.head.fc(x)
570-
#return self.head.flatten(x)
571560
x = self.head.global_pool(x)
572561
x = self.head.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
573562
x = x.squeeze()
@@ -624,7 +613,7 @@ def _cfg(url='', **kwargs):
624613
return {
625614
'url': url,
626615
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
627-
'crop_pct': 0.875, 'interpolation': 'bilinear',
616+
'crop_pct': 0.850, 'interpolation': 'bicubic',
628617
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
629618
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
630619
**kwargs

0 commit comments

Comments (0)