-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
Copy pathmodel.py
88 lines (67 loc) · 2.73 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
from typing import TypeVar, Type
from . import initialization as init
from .hub_mixin import SMPHubMixin
from .utils import is_torch_compiling
# Type variable bound to SegmentationModel; used to annotate `__new__` so that
# constructing a subclass is typed as returning that subclass.
T = TypeVar("T", bound="SegmentationModel")
class SegmentationModel(torch.nn.Module, SMPHubMixin):
    """Base class for all segmentation models.

    The methods here read ``self.encoder``, ``self.decoder``,
    ``self.segmentation_head`` and ``self.classification_head``; none of them
    is assigned in this class, so they have to be provided elsewhere
    (presumably by subclasses -- confirm against the concrete models).
    """

    # Set to False if the model supports input shapes that are not
    # divisible by 2 ^ n; doing so disables `check_input_shape`.
    requires_divisible_input_shape = True

    # Fix type-hint for models, to avoid HubMixin signature
    def __new__(cls: Type[T], *args, **kwargs) -> T:
        return super().__new__(cls, *args, **kwargs)

    def initialize(self):
        """Initialize the decoder and the head(s) via the `initialization` module.

        The classification head is initialized only when it is not None.
        """
        init.initialize_decoder(self.decoder)
        init.initialize_head(self.segmentation_head)

        aux_head = self.classification_head
        if aux_head is not None:
            init.initialize_head(aux_head)

    def check_input_shape(self, x):
        """Verify that the spatial size of `x` is divisible by the output stride.

        Height and width are taken from the last two entries of `x.shape`
        and compared against `self.encoder.output_stride`. The check is
        skipped entirely when `requires_divisible_input_shape` is False.

        Raises:
            RuntimeError: if height or width is not a multiple of the output
                stride. The message suggests the next larger valid shape.
        """
        if not self.requires_divisible_input_shape:
            return

        h, w = x.shape[-2:]
        output_stride = self.encoder.output_stride

        h_fits = h % output_stride == 0
        w_fits = w % output_stride == 0
        if h_fits and w_fits:
            return

        # Round each offending dimension up to the next multiple of the
        # stride; a dimension that already fits is reported unchanged.
        new_h = h if h_fits else (h // output_stride + 1) * output_stride
        new_w = w if w_fits else (w // output_stride + 1) * output_stride
        raise RuntimeError(
            f"Wrong input shape height={h}, width={w}. Expected image height and width "
            f"divisible by {output_stride}. Consider pad your images to shape ({new_h}, {new_w})."
        )

    def forward(self, x):
        """Sequentially pass `x` through the model's encoder, decoder and heads.

        Returns:
            The segmentation head output, or a `(masks, labels)` tuple when
            `classification_head` is not None. The classification head is
            fed the last element of the encoder output.
        """
        # The shape check is skipped while scripting, tracing or compiling.
        # NOTE(review): the condition and the `is not None` branch below are
        # deliberately kept in their original inline form -- presumably
        # TorchScript relies on it to drop the dead branches; confirm before
        # restructuring.
        if not (
            torch.jit.is_scripting() or torch.jit.is_tracing() or is_torch_compiling()
        ):
            self.check_input_shape(x)

        features = self.encoder(x)
        decoder_output = self.decoder(features)
        masks = self.segmentation_head(decoder_output)

        if self.classification_head is not None:
            labels = self.classification_head(features[-1])
            return masks, labels

        return masks

    @torch.no_grad()
    def predict(self, x):
        """Inference method: switch to `eval` mode and run `.forward(x)` under `torch.no_grad()`.

        The model is left in `eval` mode afterwards; the previous training
        mode is not restored.

        Args:
            x: 4D torch tensor with shape (batch_size, channels, height, width)

        Return:
            prediction: 4D torch tensor with shape (batch_size, classes, height, width).
            When `classification_head` is not None, `forward` returns a
            `(masks, labels)` tuple and that tuple is returned as-is.
        """
        if self.training:
            self.eval()
        return self.forward(x)