4
4
5
5
import torch
6
6
from segment_anything .modeling import ImageEncoderViT
7
+ from torch import nn
8
+ from segment_anything .modeling .common import LayerNorm2d
7
9
8
10
from segmentation_models_pytorch .encoders ._base import EncoderMixin
9
11
def __init__(self, **kwargs):
    """SAM ViT image encoder adapted to the EncoderMixin interface.

    All keyword args are forwarded to ``ImageEncoderViT``; ``out_chans``,
    ``patch_size`` and ``embed_dim`` are also cached locally (the defaults
    presumably mirror ImageEncoderViT's own — confirm against upstream).
    """
    super().__init__(**kwargs)
    self._out_chans = kwargs.get("out_chans", 256)
    self._patch_size = kwargs.get("patch_size", 16)
    self._embed_dim = kwargs.get("embed_dim", 768)
    self._validate()
    # One lightweight projection neck per intermediate stage; the deepest
    # stage uses the original SAM neck in forward() instead.
    stage_channels = self.out_channels[:-1]
    self.intermediate_necks = nn.ModuleList(
        self.init_neck(self._embed_dim, ch) for ch in stage_channels
    )
27
@staticmethod
def init_neck(embed_dim: int, out_chans: int) -> nn.Module:
    """Build a projection neck: 1x1 conv -> LayerNorm2d -> 3x3 conv -> LayerNorm2d.

    Mirrors the neck structure used inside ``ImageEncoderViT``, mapping
    ``embed_dim`` channels down to ``out_chans``.
    """
    project = nn.Conv2d(embed_dim, out_chans, kernel_size=1, bias=False)
    refine = nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False)
    return nn.Sequential(
        project,
        LayerNorm2d(out_chans),
        refine,
        LayerNorm2d(out_chans),
    )
47
+
48
+ @staticmethod
49
+ def neck_forward (neck : nn .Module , x : torch .Tensor , scale_factor : float = 1 ) -> torch .Tensor :
50
+ x = x .permute (0 , 3 , 1 , 2 )
51
+ if scale_factor != 1.0 :
52
+ x = nn .functional .interpolate (x , scale_factor = scale_factor , mode = "bilinear" )
53
+ return neck (x )
54
+
55
def requires_grad_(self, requires_grad: bool = True):
    """Toggle gradients on every parameter, then force the intermediate
    necks back to trainable regardless of the flag.

    This lets callers freeze the pretrained ViT trunk while keeping the
    randomly-initialized necks learnable. Returns ``self`` for chaining.
    """
    for p in self.parameters():
        p.requires_grad = requires_grad
    # Necks are always trainable, even when the trunk is frozen.
    for p in self.intermediate_necks.parameters():
        p.requires_grad = True
    return self
20
62
21
63
@property
def output_stride(self):
    # Total downsampling factor reported to segmentation heads.
    # NOTE(review): hard-coded to 32 regardless of patch size — confirm
    # this matches the deepest feature map's actual stride.
    return 32
24
66
25
- def _get_scale_factor ( self ) -> float :
26
- """Input image will be downscale by this factor"""
27
- return int ( math . log ( self ._patch_size , 2 ))
67
@property
def out_channels(self):
    """Per-stage channel counts, shallowest stage first.

    Yields ``_encoder_depth + 1`` values, halving ``_out_chans`` once per
    step away from the deepest stage (e.g. [16, 32, 64, 128, 256] for the
    defaults), so the last entry is ``_out_chans`` itself.
    """
    depth = self._encoder_depth
    return [self._out_chans // (2 ** (depth - i)) for i in range(depth + 1)]
28
70
29
71
def _validate (self ):
30
72
# check vit depth
@@ -39,15 +81,30 @@ def _validate(self):
39
81
"It is recommended to set encoder depth=4 with default vit patch_size=16."
40
82
)
41
83
42
- @property
43
- def out_channels (self ):
44
- # Fill up with leading zeros to be used in Unet
45
- scale_factor = self ._get_scale_factor ()
46
- return [0 ] * scale_factor + [self ._out_chans ]
84
def _get_scale_factor(self) -> int:
    """Return ``log2(patch_size)``, the number of 2x downscales the patch
    embedding applies (the input is downscaled by ``2 ** factor``, i.e. by
    ``patch_size`` — e.g. factor 4 for the default patch size of 16).

    Fixes: the annotation said ``-> float`` although ``int(...)`` is
    returned, and ``math.log2`` is more accurate than ``math.log(x, 2)``
    for exact powers of two.
    """
    return int(math.log2(self._patch_size))
47
87
48
88
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
    """Run the ViT trunk and return one feature map per encoder stage.

    Produces ``_encoder_depth + 1`` tensors: one from an intermediate neck
    every ``skip_steps`` transformer blocks, plus the final SAM neck
    output, matching ``self.out_channels`` (assumes the blocks preserve
    spatial size — TODO confirm for windowed-attention configs).
    """
    x = self.patch_embed(x)
    if self.pos_embed is not None:
        x = x + self.pos_embed

    features = []
    # Tap an intermediate feature after every `skip_steps` blocks
    # (includes i == 0, i.e. right after the first block).
    skip_steps = self._vit_depth // self._encoder_depth
    scale_factor = self._get_scale_factor()
    for i, blk in enumerate(self.blocks):
        x = blk(x)
        if i % skip_steps == 0:
            # Upsample by 2**scale_factor; the factor shrinks by one per
            # tap, so each successive feature is half the spatial size of
            # the previous one while the neck sets the channel count per
            # self.out_channels.
            neck = self.intermediate_necks[i // skip_steps]
            features.append(self.neck_forward(neck, x, scale_factor=2 ** scale_factor))
            scale_factor -= 1

    # Final SAM neck expects channels-first input, hence the permute.
    x = self.neck(x.permute(0, 3, 1, 2))
    features.append(x)

    return features
51
108
52
109
def load_state_dict (self , state_dict : Mapping [str , Any ], strict : bool = True ) -> None :
53
110
# Exclude mask_decoder and prompt encoder weights
@@ -58,6 +115,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True) ->
58
115
if not k .startswith ("mask_decoder" ) and not k .startswith ("prompt_encoder" )
59
116
}
60
117
missing , unused = super ().load_state_dict (state_dict , strict = False )
118
+ missing = list (filter (lambda x : not x .startswith ("intermediate_necks" ), missing ))
61
119
if len (missing ) + len (unused ) > 0 :
62
120
n_loaded = len (state_dict ) - len (missing ) - len (unused )
63
121
warnings .warn (
0 commit comments