 import torch
 from segment_anything.modeling import MaskDecoder, TwoWayTransformer, PromptEncoder
+from segment_anything.modeling.prompt_encoder import PositionEmbeddingRandom
+from torch import nn
 from torch.nn import functional as F
 from torch.utils import model_zoo
@@ -86,16 +88,12 @@ def __init__(
             out_chans=decoder_channels,
         )
 
+        # these params are used instead of the prompt_encoder
         image_embedding_size = image_size // vit_patch_size
-        self.prompt_encoder = PromptEncoder(
-            embed_dim=decoder_channels,
-            image_embedding_size=(image_embedding_size, image_embedding_size),
-            input_image_size=(image_size, image_size),
-            mask_in_chans=16,
-        )
-        self.prompt_encoder.point_embeddings = None
-        self.prompt_encoder.mask_downscaling = None
-        self.not_a_point_embed = None
+        self.embed_dim = decoder_channels
+        self.image_embedding_size = (image_embedding_size, image_embedding_size)
+        self.pe_layer = PositionEmbeddingRandom(decoder_channels // 2)
+        self.no_mask_embed = nn.Embedding(1, decoder_channels)
 
         self.decoder = MaskDecoder(
             num_multimask_outputs=3,
@@ -188,10 +186,11 @@ def forward(self, x):
         img_size = x.shape[-2:]
         x = torch.stack([self.preprocess(img) for img in x])
         features = self.encoder(x)
-        sparse_embeddings, dense_embeddings = self.prompt_encoder(points=None, boxes=None, masks=None)
+        # dummy prompt embeddings: the model is trained without prompts
+        sparse_embeddings, dense_embeddings = self._get_dummy_prompt_encoder_output(x.size(0))
         low_res_masks, iou_predictions = self.decoder(
             image_embeddings=features,
-            image_pe=self.prompt_encoder.get_dense_pe(),
+            image_pe=self._get_dense_pe(),
             sparse_prompt_embeddings=sparse_embeddings,
             dense_prompt_embeddings=dense_embeddings,
             multimask_output=self._decoder_multiclass_output,
@@ -201,3 +200,14 @@ def forward(self, x):
         masks = masks * iou_predictions.view(-1, masks.size(1), 1, 1)
         output = self.segmentation_head(masks)
         return output
+
+    def _get_dummy_prompt_encoder_output(self, bs):
+        """Use this dummy output as we're training without prompts."""
+        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self.no_mask_embed.weight.device)
+        dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+            bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+        )
+        return sparse_embeddings, dense_embeddings
+
+    def _get_dense_pe(self):
+        return self.pe_layer(self.image_embedding_size).unsqueeze(0)
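For context (not part of the commit): the two helpers reproduce what SAM's PromptEncoder returns when called with points=None, boxes=None, masks=None, minus the point/box/mask machinery a promptless model never uses; they also tie the dummy batch size to x.size(0), whereas a promptless PromptEncoder call returns batch size 1. Below is a minimal sanity-check sketch, assuming segment_anything is installed and using SAM's default sizes (embed_dim=256, a 64x64 embedding grid, 1024px input) purely for illustration:

import torch
from segment_anything.modeling import PromptEncoder

embed_dim, grid, image_size = 256, 64, 1024  # assumed SAM defaults, for illustration

reference = PromptEncoder(
    embed_dim=embed_dim,
    image_embedding_size=(grid, grid),
    input_image_size=(image_size, image_size),
    mask_in_chans=16,
)

# Promptless call: empty sparse embeddings plus no_mask_embed broadcast
# over the spatial grid.
sparse_ref, dense_ref = reference(points=None, boxes=None, masks=None)

# The replacement path from this commit, reusing the reference module's
# weights so the tensors are directly comparable.
sparse = torch.empty((1, 0, embed_dim))
dense = reference.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(1, -1, grid, grid)
assert sparse.shape == sparse_ref.shape
assert torch.equal(dense, dense_ref)

# get_dense_pe() is the positional-encoding layer applied to the grid,
# which _get_dense_pe() replicates.
assert torch.equal(reference.pe_layer((grid, grid)).unsqueeze(0), reference.get_dense_pe())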