3
3
4
4
import torch
5
5
from segment_anything .modeling import MaskDecoder , TwoWayTransformer , PromptEncoder
6
+ from segment_anything .modeling .prompt_encoder import PositionEmbeddingRandom
7
+ from torch import nn
6
8
from torch .nn import functional as F
7
9
from torch .utils import model_zoo
8
10
@@ -86,13 +88,12 @@ def __init__(
86
88
out_chans = decoder_channels ,
87
89
)
88
90
91
+ # these parameters are used instead of the prompt_encoder
89
92
image_embedding_size = image_size // vit_patch_size
90
- self .prompt_encoder = PromptEncoder (
91
- embed_dim = decoder_channels ,
92
- image_embedding_size = (image_embedding_size , image_embedding_size ),
93
- input_image_size = (image_size , image_size ),
94
- mask_in_chans = 16 ,
95
- )
93
+ self .embed_dim = decoder_channels
94
+ self .image_embedding_size = (image_embedding_size , image_embedding_size )
95
+ self .pe_layer = PositionEmbeddingRandom (decoder_channels // 2 )
96
+ self .no_mask_embed = nn .Embedding (1 , decoder_channels )
96
97
97
98
self .decoder = MaskDecoder (
98
99
num_multimask_outputs = 3 ,
@@ -185,10 +186,11 @@ def forward(self, x):
185
186
img_size = x .shape [- 2 :]
186
187
x = torch .stack ([self .preprocess (img ) for img in x ])
187
188
features = self .encoder (x )
188
- sparse_embeddings , dense_embeddings = self .prompt_encoder (points = None , boxes = None , masks = None )
189
+ # sparse_embeddings, dense_embeddings = self.prompt_encoder(points=None, boxes=None, masks=None)
190
+ sparse_embeddings , dense_embeddings = self ._get_dummy_promp_encoder_output (x .size (0 ))
189
191
low_res_masks , iou_predictions = self .decoder (
190
192
image_embeddings = features ,
191
- image_pe = self .prompt_encoder . get_dense_pe (),
193
+ image_pe = self ._get_dense_pe (),
192
194
sparse_prompt_embeddings = sparse_embeddings ,
193
195
dense_prompt_embeddings = dense_embeddings ,
194
196
multimask_output = self ._decoder_multiclass_output ,
@@ -198,3 +200,14 @@ def forward(self, x):
198
200
masks = masks * iou_predictions .view (- 1 , masks .size (1 ), 1 , 1 )
199
201
output = self .segmentation_head (masks )
200
202
return output
203
+
204
+ def _get_dummy_promp_encoder_output (self , bs ):
205
+ """Use this dummy output as we're training without prompts."""
206
+ sparse_embeddings = torch .empty ((bs , 0 , self .embed_dim ), device = self .no_mask_embed .weight .device )
207
+ dense_embeddings = self .no_mask_embed .weight .reshape (1 , - 1 , 1 , 1 ).expand (
208
+ bs , - 1 , self .image_embedding_size [0 ], self .image_embedding_size [1 ]
209
+ )
210
+ return sparse_embeddings , dense_embeddings
211
+
212
+ def _get_dense_pe (self ):
213
+ return self .pe_layer (self .image_embedding_size ).unsqueeze (0 )
0 commit comments