2
2
import torch .nn as nn
3
3
import torch .nn .functional as F
4
4
5
+ from typing import Optional , Sequence
5
6
from segmentation_models_pytorch .base import modules as md
6
7
7
8
8
- class DecoderBlock (nn .Module ):
9
+ class UnetDecoderBlock (nn .Module ):
10
+ """A decoder block in the U-Net architecture that performs upsampling and feature fusion."""
11
+
9
12
def __init__ (
10
13
self ,
11
- in_channels ,
12
- skip_channels ,
13
- out_channels ,
14
- use_batchnorm = True ,
15
- attention_type = None ,
14
+ in_channels : int ,
15
+ skip_channels : int ,
16
+ out_channels : int ,
17
+ use_batchnorm : bool = True ,
18
+ attention_type : Optional [str ] = None ,
19
+ interpolation_mode : str = "nearest" ,
16
20
):
17
21
super ().__init__ ()
22
+ self .interpolation_mode = interpolation_mode
18
23
self .conv1 = md .Conv2dReLU (
19
24
in_channels + skip_channels ,
20
25
out_channels ,
@@ -34,19 +39,31 @@ def __init__(
34
39
)
35
40
self .attention2 = md .Attention (attention_type , in_channels = out_channels )
36
41
37
def forward(
    self,
    feature_map: torch.Tensor,
    target_height: int,
    target_width: int,
    skip_connection: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Upsample ``feature_map`` to an explicit target size and fuse it with a skip connection.

    Args:
        feature_map: Decoder feature tensor of shape ``(batch, in_channels, h, w)``.
        target_height: Height to upsample to (the next encoder stage's height).
        target_width: Width to upsample to.
        skip_connection: Optional encoder feature tensor concatenated on the
            channel dimension after upsampling; ``None`` for blocks without a skip.

    Returns:
        Feature tensor of spatial size ``(target_height, target_width)`` after
        attention and the two convolution blocks.
    """
    # Resize to an explicit spatial size rather than a fixed scale_factor so
    # odd-sized inputs still line up exactly with the skip connection.
    feature_map = F.interpolate(
        feature_map,
        size=(target_height, target_width),
        mode=self.interpolation_mode,
    )
    if skip_connection is not None:
        feature_map = torch.cat([feature_map, skip_connection], dim=1)
        # Attention over the concatenated features only makes sense when a
        # skip connection is present (matches the pre-refactor behavior).
        feature_map = self.attention1(feature_map)
    feature_map = self.conv1(feature_map)
    feature_map = self.conv2(feature_map)
    feature_map = self.attention2(feature_map)
    return feature_map
61
+
46
62
63
+ class UnetCenterBlock (nn .Sequential ):
64
+ """Center block of the Unet decoder. Applied to the last feature map of the encoder."""
47
65
48
- class CenterBlock (nn .Sequential ):
49
- def __init__ (self , in_channels , out_channels , use_batchnorm = True ):
66
+ def __init__ (self , in_channels : int , out_channels : int , use_batchnorm : bool = True ):
50
67
conv1 = md .Conv2dReLU (
51
68
in_channels ,
52
69
out_channels ,
@@ -65,14 +82,21 @@ def __init__(self, in_channels, out_channels, use_batchnorm=True):
65
82
66
83
67
84
class UnetDecoder (nn .Module ):
85
+ """The decoder part of the U-Net architecture.
86
+
87
+ Takes encoded features from different stages of the encoder and progressively upsamples them while
88
+ combining with skip connections. This helps preserve fine-grained details in the final segmentation.
89
+ """
90
+
68
91
def __init__ (
69
92
self ,
70
- encoder_channels ,
71
- decoder_channels ,
72
- n_blocks = 5 ,
73
- use_batchnorm = True ,
74
- attention_type = None ,
75
- center = False ,
93
+ encoder_channels : Sequence [int ],
94
+ decoder_channels : Sequence [int ],
95
+ n_blocks : int = 5 ,
96
+ use_batchnorm : bool = True ,
97
+ attention_type : Optional [str ] = None ,
98
+ add_center_block : bool = False ,
99
+ interpolation_mode : str = "nearest" ,
76
100
):
77
101
super ().__init__ ()
78
102
@@ -94,31 +118,45 @@ def __init__(
94
118
skip_channels = list (encoder_channels [1 :]) + [0 ]
95
119
out_channels = decoder_channels
96
120
97
- if center :
98
- self .center = CenterBlock (
121
+ if add_center_block :
122
+ self .center = UnetCenterBlock (
99
123
head_channels , head_channels , use_batchnorm = use_batchnorm
100
124
)
101
125
else :
102
126
self .center = nn .Identity ()
103
127
104
128
# combine decoder keyword arguments
105
- kwargs = dict (use_batchnorm = use_batchnorm , attention_type = attention_type )
106
- blocks = [
107
- DecoderBlock (in_ch , skip_ch , out_ch , ** kwargs )
108
- for in_ch , skip_ch , out_ch in zip (in_channels , skip_channels , out_channels )
109
- ]
110
- self .blocks = nn .ModuleList (blocks )
111
-
112
- def forward (self , * features ):
129
+ self .blocks = nn .ModuleList ()
130
+ for block_in_channels , block_skip_channels , block_out_channels in zip (
131
+ in_channels , skip_channels , out_channels
132
+ ):
133
+ block = UnetDecoderBlock (
134
+ block_in_channels ,
135
+ block_skip_channels ,
136
+ block_out_channels ,
137
+ use_batchnorm = use_batchnorm ,
138
+ attention_type = attention_type ,
139
+ interpolation_mode = interpolation_mode ,
140
+ )
141
+ self .blocks .append (block )
142
+
143
def forward(self, *features: torch.Tensor) -> torch.Tensor:
    """Decode encoder feature maps into a single high-resolution feature tensor.

    Args:
        *features: Encoder feature maps ordered from highest to lowest spatial
            resolution, i.e. spatial sizes ``[hw, hw/2, hw/4, hw/8, ...]``.

    Returns:
        The decoded feature tensor produced by running the center block on the
        deepest feature and then each decoder block with its skip connection.
    """
    # Record every stage's spatial shape, deepest first, so each decoder block
    # can be told the exact size to upsample to (robust to odd input sizes).
    spatial_shapes = [feature.shape[2:] for feature in features]
    spatial_shapes = spatial_shapes[::-1]

    features = features[1:]  # remove first skip with same spatial resolution
    features = features[::-1]  # reverse channels to start from head of encoder

    head = features[0]
    skip_connections = features[1:]

    x = self.center(head)

    for i, decoder_block in enumerate(self.blocks):
        # Upsample to the spatial shape of the next (shallower) encoder stage.
        height, width = spatial_shapes[i + 1]
        # The last block(s) may have no corresponding skip connection.
        skip_connection = skip_connections[i] if i < len(skip_connections) else None
        x = decoder_block(x, height, width, skip_connection=skip_connection)

    return x
0 commit comments