Skip to content

Commit c3f13d0

Browse files
committed
Update davit.py
1 parent 9f79f02 commit c3f13d0

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

timm/models/davit.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ def __init__(
520520
# FIXME generalize this structure to ClassifierHead
521521
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
522522
self.head = nn.Sequential(OrderedDict([
523-
('global_pool', SelectAdaptivePool2d(pool_type=global_pool, flatten=True)),
523+
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
524524
('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
525525
#('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
526526
('drop', nn.Dropout(self.drop_rate)),
@@ -549,7 +549,7 @@ def get_classifier(self):
549549

550550
def reset_classifier(self, num_classes, global_pool=None):
551551
if global_pool is not None:
552-
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
552+
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
553553
#self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
554554
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
555555

@@ -569,8 +569,8 @@ def forward_head(self, x, pre_logits: bool = False):
569569
#x = self.head.fc(x)
570570
#return self.head.flatten(x)
571571
x = self.head.global_pool(x)
572-
x = self.head.norm(x)
573-
#x = self.head.flatten(x)
572+
x = self.head.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
573+
x = x.squeeze()
574574
x = self.head.drop(x)
575575
return x if pre_logits else self.head.fc(x)
576576

0 commit comments

Comments
 (0)