@@ -518,6 +518,7 @@ def __init__(
518
518
self .head = nn .Sequential (OrderedDict ([
519
519
('global_pool' , SelectAdaptivePool2d (pool_type = global_pool )),
520
520
('norm' , nn .Identity () if head_norm_first else norm_layer (self .num_features )),
521
+ ('flatten' , nn .Flatten (1 ) if global_pool else nn .Identity ()),
521
522
('drop' , nn .Dropout (self .drop_rate )),
522
523
('fc' , nn .Linear (self .num_features , num_classes ) if num_classes > 0 else nn .Identity ())]))
523
524
@@ -545,6 +546,7 @@ def get_classifier(self):
545
546
def reset_classifier(self, num_classes, global_pool=None):
    """Replace the classification head for a new number of classes.

    Args:
        num_classes: Size of the new classifier output; 0 swaps the final
            linear layer for an ``nn.Identity`` (feature-extraction mode).
        global_pool: Optional new pooling type. ``None`` leaves the existing
            pooling (and flatten) modules untouched.
    """
    if global_pool is not None:
        self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        # An empty pool_type string keeps the spatial dims, so flattening
        # must be skipped in that case.
        self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
    if num_classes > 0:
        self.head.fc = nn.Linear(self.num_features, num_classes)
    else:
        self.head.fc = nn.Identity()
549
551
550
552
def forward_features (self , x ):
@@ -559,7 +561,7 @@ def forward_features(self, x):
559
561
def forward_head(self, x, pre_logits: bool = False):
    """Run the pooled features through the classification head.

    Args:
        x: Feature map in NCHW layout (as produced by ``forward_features``
            — assumed 4D; TODO confirm against callers).
        pre_logits: If True, return the pooled/normalized features without
            applying the final linear classifier.

    Returns:
        Class logits, or the pre-logit features when ``pre_logits`` is set.
    """
    head = self.head
    x = head.global_pool(x)
    # The norm layer expects channels-last, so shuffle C to the end and back.
    x = head.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    x = head.flatten(x)
    x = head.drop(x)
    if pre_logits:
        return x
    return head.fc(x)
565
567
0 commit comments