@@ -520,7 +520,7 @@ def __init__(
         # FIXME generalize this structure to ClassifierHead
         self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
         self.head = nn.Sequential(OrderedDict([
-            ('global_pool', SelectAdaptivePool2d(pool_type=global_pool, flatten=True)),
+            ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
             ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
             #('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
             ('drop', nn.Dropout(self.drop_rate)),
@@ -549,7 +549,7 @@ def get_classifier(self):

     def reset_classifier(self, num_classes, global_pool=None):
         if global_pool is not None:
-            self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
+            self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
             #self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
         self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
@@ -569,8 +569,8 @@ def forward_head(self, x, pre_logits: bool = False):
         #x = self.head.fc(x)
         #return self.head.flatten(x)
         x = self.head.global_pool(x)
-        x = self.head.norm(x)
-        # x = self.head.flatten(x)
+        x = self.head.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        x = x.squeeze()
         x = self.head.drop(x)
         return x if pre_logits else self.head.fc(x)
0 commit comments