-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
Copy pathmodel.py
213 lines (186 loc) · 8.99 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import logging
from typing import Optional, Union, List, Tuple
import torch
from segment_anything.modeling import MaskDecoder, TwoWayTransformer, PromptEncoder
from segment_anything.modeling.prompt_encoder import PositionEmbeddingRandom
from torch import nn
from torch.nn import functional as F
from torch.utils import model_zoo
from segmentation_models_pytorch.base import (
SegmentationModel,
SegmentationHead,
)
from segmentation_models_pytorch.encoders import get_encoder, sam_vit_encoders, get_pretrained_settings
# Dedicated "sam" logger used for pretrained-weight loading diagnostics.
# It gets its own stream handler and does not propagate to ancestor loggers,
# so its records are handled only by the handler attached here.
stream = logging.StreamHandler()
logger = logging.getLogger("sam")
logger.propagate = False
logger.addHandler(stream)
logger.setLevel(logging.WARNING)
class SAM(SegmentationModel):
    """SAM_ (Segment Anything Model) is a vision-transformer based encoder-decoder segmentation
    model that can be used to produce high quality segmentation masks from images and prompts.

    Consists of *image encoder*, *prompt encoder* and *mask decoder*. This implementation runs
    without prompts: instead of a full prompt encoder it keeps only a random positional
    encoding layer and a learned "no mask" embedding, and feeds an empty set of sparse prompt
    embeddings to the decoder. A *segmentation head* is added after the *mask decoder* to
    define the final number of classes for the output mask.

    Args:
        encoder_name: Name of the SAM image encoder (backbone) to build. Default is ``"sam-vit_h"``.
        encoder_depth: Forwarded to ``get_encoder`` as ``depth``. Default is **None**.
        encoder_weights: Forwarded to ``get_encoder`` as ``weights``. One of **None**
            (random initialization), **"sa-1b"** (pre-training on SA-1B dataset).
        decoder_channels: How many output channels the image encoder will have. The same value
            is used as the embedding width of the mask decoder. Default is 256.
        decoder_multimask_output: Passed to the mask decoder as ``multimask_output`` on every
            forward call. The segmentation head is built for 3 input channels when **True**
            and for 1 input channel when **False**. Default is **True**.
        in_channels: A number of input channels for the model, default is 3 (RGB images).
            NOTE(review): ``preprocess`` normalizes with 3-element mean/std buffers, so other
            channel counts may not broadcast as intended - confirm before using.
        image_size: Forwarded to ``get_encoder`` as ``img_size``. Together with
            ``vit_patch_size`` it fixes the image embedding grid to
            ``image_size // vit_patch_size`` per side. Default is 1024.
        vit_patch_size: Forwarded to ``get_encoder`` as ``patch_size``. Default is 16.
        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
        weights: Name of the pretrained settings to load into the whole model (e.g. **"sa-1b"**),
            or **None** to skip loading. Default is **"sa-1b"**.
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
            **callable** and **None**.
            Default is **None**
        aux_params: Auxiliary (classification) output is not supported yet; must be **None**.

    Returns:
        ``torch.nn.Module``: SAM

    Raises:
        NotImplementedError: If ``aux_params`` is not **None**.

    .. _SAM:
        https://github.com/facebookresearch/segment-anything

    """

    def __init__(
        self,
        encoder_name: str = "sam-vit_h",
        encoder_depth: Optional[int] = None,
        encoder_weights: Optional[str] = None,
        decoder_channels: int = 256,
        decoder_multimask_output: bool = True,
        in_channels: int = 3,
        image_size: int = 1024,
        vit_patch_size: int = 16,
        classes: int = 1,
        weights: Optional[str] = "sa-1b",
        activation: Optional[Union[str, callable]] = None,
        aux_params: Optional[dict] = None,
    ):
        super().__init__()

        # Fail fast: reject unsupported options before building the encoder
        # or downloading any checkpoint.
        if aux_params is not None:
            raise NotImplementedError("Auxiliary output is not supported yet")

        # Per-channel normalization constants; non-persistent, so they are not
        # part of the state dict.
        self.register_buffer(
            "pixel_mean",
            torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "pixel_std",
            torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
            persistent=False,
        )

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
            img_size=image_size,
            patch_size=vit_patch_size,
            out_chans=decoder_channels,
        )

        # These attributes are used instead of a full prompt encoder.
        # NOTE(review): keys under "prompt_encoder." in the pretrained checkpoint
        # are not remapped onto `pe_layer` / `no_mask_embed`, so these two keep
        # their random initialization - confirm whether that is intended.
        embedding_side = image_size // vit_patch_size
        self.embed_dim = decoder_channels
        self.image_embedding_size = (embedding_side, embedding_side)
        self.pe_layer = PositionEmbeddingRandom(decoder_channels // 2)
        self.no_mask_embed = nn.Embedding(1, decoder_channels)

        # Shared by the decoder and the segmentation head so the two cannot drift apart.
        num_multimask_outputs = 3
        self.decoder = MaskDecoder(
            num_multimask_outputs=num_multimask_outputs,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=decoder_channels,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=decoder_channels,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        )
        self._decoder_multiclass_output = decoder_multimask_output

        self.segmentation_head = SegmentationHead(
            in_channels=num_multimask_outputs if decoder_multimask_output else 1,
            out_channels=classes,
            activation=activation,
            kernel_size=3,
        )
        self.classification_head = None

        self.name = encoder_name
        self.initialize()

        # Load pretrained weights only AFTER `initialize()`. The base-class
        # initializer is not defined in this file; if it re-initializes decoder
        # parameters, loading first would silently discard the pretrained values.
        # NOTE(review): confirm against SegmentationModel.initialize.
        if weights is not None:
            self._load_pretrained_weights(encoder_name, weights)

    def _load_pretrained_weights(self, encoder_name: str, weights: str):
        """Download a pretrained SAM checkpoint and load the matching parameters.

        Checkpoint keys containing ``image_encoder`` / ``mask_decoder`` are renamed to
        ``encoder`` / ``decoder`` to match this model's attribute names. Loading is
        non-strict; a warning listing the mismatched keys is logged when any key is
        missing from, or unexpected in, the checkpoint.

        Args:
            encoder_name: Encoder name used to look up the pretrained settings.
            weights: Name of the pretrained settings (e.g. ``"sa-1b"``).
        """
        settings = get_pretrained_settings(sam_vit_encoders, encoder_name, weights)
        # map_location="cpu" keeps loading independent of the device the
        # checkpoint was serialized from.
        checkpoint = model_zoo.load_url(settings["url"], map_location="cpu")
        state_dict = {
            key.replace("image_encoder", "encoder").replace("mask_decoder", "decoder"): value
            for key, value in checkpoint.items()
        }
        missing, unused = self.load_state_dict(state_dict, strict=False)
        if missing or unused:
            # Missing keys belong to the model, not to the checkpoint, so only
            # the unexpected (unused) checkpoint entries reduce the loaded count.
            n_loaded = len(state_dict) - len(unused)
            logger.warning(
                f"Only {n_loaded} out of pretrained {len(state_dict)} SAM modules are loaded. "
                f"Missing modules: {missing}. Unused modules: {unused}."
            )

    def preprocess(self, x):
        """Normalize pixel values and pad to a square input.

        Works on a single image (C, H, W) or a batch (..., C, H, W): normalization
        broadcasts over leading dimensions and padding touches only the last two.

        Args:
            x (torch.Tensor): Image tensor whose last two dimensions are (H, W).

        Returns:
            (torch.Tensor): Normalized tensor zero-padded on the bottom/right to
              ``encoder.img_size`` x ``encoder.img_size``.

        Raises:
            ValueError: If H or W exceeds ``encoder.img_size``.
        """
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        target = self.encoder.img_size
        if h > target or w > target:
            # A negative pad would silently crop the image and the mask would
            # later be stretched back, producing a misaligned prediction.
            raise ValueError(
                f"Input spatial size ({h}, {w}) exceeds the encoder image size {target}; "
                "resize the input before calling the model."
            )
        return F.pad(x, (0, target - w, 0, target - h))

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        padded_side = self.encoder.img_size
        # Upsample to the padded encoder resolution first ...
        masks = F.interpolate(
            masks,
            (padded_side, padded_side),
            mode="bilinear",
            align_corners=False,
        )
        # ... then drop the bottom/right padding added by `preprocess` ...
        masks = masks[..., : input_size[0], : input_size[1]]
        # ... and finally resample to the requested output size.
        return F.interpolate(masks, original_size, mode="bilinear", align_corners=False)

    def forward(self, x):
        """Predict segmentation output for a batch of images without prompts.

        Args:
            x (torch.Tensor): Batch of images; the last two dimensions are (H, W)
              and must not exceed ``encoder.img_size``.

        Returns:
            (torch.Tensor): Output of the segmentation head, computed from masks
              resampled back to the input (H, W).
        """
        img_size = x.shape[-2:]
        # `preprocess` broadcasts over the batch dimension, so one call handles
        # the whole batch; no per-image loop and re-stacking is needed.
        x = self.preprocess(x)
        features = self.encoder(x)

        # No prompts are available, so substitute fixed "empty prompt" embeddings
        # for the output of a prompt encoder.
        sparse_embeddings, dense_embeddings = self._get_dummy_promp_encoder_output(x.size(0))
        low_res_masks, iou_predictions = self.decoder(
            image_embeddings=features,
            image_pe=self._get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=self._decoder_multiclass_output,
        )
        masks = self.postprocess_masks(low_res_masks, input_size=img_size, original_size=img_size)
        # use scaling below in order to make it work with torch DDP
        # (presumably so the IoU-prediction branch contributes to the output and
        # its parameters receive gradients - TODO confirm)
        masks = masks * iou_predictions.view(-1, masks.size(1), 1, 1)
        return self.segmentation_head(masks)

    def _get_dummy_promp_encoder_output(self, bs):
        """Use this dummy output as we're training without prompts.

        Args:
            bs (int): Batch size.

        Returns:
            Tuple of ``(sparse_embeddings, dense_embeddings)``: an empty
            ``(bs, 0, embed_dim)`` tensor, and the "no mask" embedding expanded to
            ``(bs, embed_dim, H_emb, W_emb)`` where ``(H_emb, W_emb)`` is
            ``image_embedding_size``.
        """
        no_mask_weight = self.no_mask_embed.weight
        # Match both device and dtype of the model parameters so the empty tensor
        # does not introduce a dtype mismatch when the model is cast (e.g. half).
        sparse_embeddings = torch.empty(
            (bs, 0, self.embed_dim),
            device=no_mask_weight.device,
            dtype=no_mask_weight.dtype,
        )
        emb_h, emb_w = self.image_embedding_size
        dense_embeddings = no_mask_weight.reshape(1, -1, 1, 1).expand(bs, -1, emb_h, emb_w)
        return sparse_embeddings, dense_embeddings

    def _get_dense_pe(self):
        """Return the positional encoding for the image embedding grid with a leading batch dimension of 1."""
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)