Skip to content

Commit c5830cc

Browse files
committed
Update davit.py
1 parent d71fbc4 commit c5830cc

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

timm/models/davit.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,7 @@ def __init__(
         self.head = nn.Sequential(OrderedDict([
             ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
             ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
+            ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
             ('drop', nn.Dropout(self.drop_rate)),
             ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))

@@ -545,6 +546,7 @@ def get_classifier(self):
     def reset_classifier(self, num_classes, global_pool=None):
         if global_pool is not None:
             self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+            self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
         self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

     def forward_features(self, x):
@@ -559,7 +561,7 @@ def forward_features(self, x):
     def forward_head(self, x, pre_logits: bool = False):
         x = self.head.global_pool(x)
         x = self.head.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        x = x.squeeze()
+        x = self.head.flatten(x)
         x = self.head.drop(x)
         return x if pre_logits else self.head.fc(x)
565567

0 commit comments

Comments (0)