 import torch.nn as nn
 import torch.nn.functional as F
 
+from typing import List, Literal
+
 
 class Conv3x3GNReLU(nn.Module):
-    def __init__(self, in_channels, out_channels, upsample=False):
+    def __init__(self, in_channels: int, out_channels: int, upsample: bool = False):
         super().__init__()
         self.upsample = upsample
         self.block = nn.Sequential(
@@ -15,27 +17,27 @@ def __init__(self, in_channels, out_channels, upsample=False):
             nn.ReLU(inplace=True),
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.block(x)
         if self.upsample:
-            x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
+            x = F.interpolate(x, scale_factor=2.0, mode="bilinear", align_corners=True)
         return x
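A note on the `scale_factor=2` → `scale_factor=2.0` change: `F.interpolate` declares `scale_factor` as a float, so the explicit float matches the documented signature instead of relying on int-to-float coercion, presumably to keep the call friendly to stricter tooling such as TorchScript. A minimal smoke test, assuming the collapsed `nn.Sequential` body is the padded 3x3 convolution plus GroupNorm that the class name suggests (channel sizes below are illustrative):

```python
import torch

# Hypothetical shape check for Conv3x3GNReLU.
block = Conv3x3GNReLU(in_channels=256, out_channels=128, upsample=True)
x = torch.randn(1, 256, 32, 32)
y = block(x)
# A padded 3x3 conv preserves H and W, so only the 2x upsample changes them.
print(y.shape)  # expected: torch.Size([1, 128, 64, 64])
```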
 class FPNBlock(nn.Module):
-    def __init__(self, pyramid_channels, skip_channels):
+    def __init__(self, pyramid_channels: int, skip_channels: int):
         super().__init__()
         self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1)
 
-    def forward(self, x, skip=None):
-        x = F.interpolate(x, scale_factor=2, mode="nearest")
+    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
+        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
         skip = self.skip_conv(skip)
         x = x + skip
         return x
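`skip` also loses its `skip=None` default: the annotated signature makes the skip connection mandatory, which matches how the decoder below always calls the block with both arguments. A hedged shape sketch of one top-down step (channel counts are illustrative):

```python
import torch

# One top-down FPN step: upsample the coarser map 2x (nearest), project the
# encoder skip to pyramid width with the 1x1 conv, then add the two.
fpn = FPNBlock(pyramid_channels=256, skip_channels=512)
x = torch.randn(1, 256, 16, 16)     # coarser pyramid level
skip = torch.randn(1, 512, 32, 32)  # encoder feature at 2x the resolution
p = fpn(x, skip)
print(p.shape)  # torch.Size([1, 256, 32, 32])
```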
 class SegmentationBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, n_upsamples=0):
+    def __init__(self, in_channels: int, out_channels: int, n_upsamples: int = 0):
         super().__init__()
 
         blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))]
@@ -51,36 +53,37 @@ def forward(self, x):
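The visible line builds the first block with `upsample=bool(n_upsamples)`; the collapsed lines presumably append further upsampling `Conv3x3GNReLU` blocks and wrap everything in a `Sequential`, so that each pyramid level can be brought to a common resolution. Under that assumption, a block with `n_upsamples=3` scales its input by 2^3 = 8:

```python
import torch

# Assumes the elided loop appends (n_upsamples - 1) more upsampling blocks,
# giving an overall spatial scale factor of 2 ** n_upsamples.
seg = SegmentationBlock(in_channels=256, out_channels=128, n_upsamples=3)
x = torch.randn(1, 256, 8, 8)
print(seg(x).shape)  # expected: torch.Size([1, 128, 64, 64])
```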
 class MergeBlock(nn.Module):
-    def __init__(self, policy):
+    def __init__(self, policy: Literal["add", "cat"]):
         super().__init__()
         if policy not in ["add", "cat"]:
             raise ValueError(
                 "`merge_policy` must be one of: ['add', 'cat'], got {}".format(policy)
             )
         self.policy = policy
 
-    def forward(self, x):
+    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
         if self.policy == "add":
-            return sum(x)
+            output = torch.stack(x).sum(dim=0)
         elif self.policy == "cat":
-            return torch.cat(x, dim=1)
+            output = torch.cat(x, dim=1)
         else:
             raise ValueError(
                 "`merge_policy` must be one of: ['add', 'cat'], got {}".format(
                     self.policy
                 )
             )
+        return output
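The `add` branch now builds the sum as a single tensor op: Python's built-in `sum` folds with the integer `0` as its start value, which works in eager mode but mixes a Python int into the tensor arithmetic, whereas `torch.stack(x).sum(dim=0)` stays tensor-typed throughout (at the cost of materializing one stacked tensor). The single `return output` likewise replaces the per-branch early returns. Both `add` variants agree numerically for same-shaped inputs:

```python
import torch

feats = [torch.randn(1, 128, 64, 64) for _ in range(4)]
old = sum(feats)                     # folds: (((0 + f0) + f1) + f2) + f3
new = torch.stack(feats).sum(dim=0)  # stacks to (4, 1, 128, 64, 64), reduces dim 0
print(torch.allclose(old, new, atol=1e-6))  # True up to float reduction order
```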
 class FPNDecoder(nn.Module):
     def __init__(
         self,
-        encoder_channels,
-        encoder_depth=5,
-        pyramid_channels=256,
-        segmentation_channels=128,
-        dropout=0.2,
-        merge_policy="add",
+        encoder_channels: List[int],
+        encoder_depth: int = 5,
+        pyramid_channels: int = 256,
+        segmentation_channels: int = 128,
+        dropout: float = 0.2,
+        merge_policy: Literal["add", "cat"] = "add",
     ):
         super().__init__()
 
@@ -116,17 +119,20 @@ def __init__(
         self.merge = MergeBlock(merge_policy)
         self.dropout = nn.Dropout2d(p=dropout, inplace=True)
 
-    def forward(self, *features):
+    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
         c2, c3, c4, c5 = features[-4:]
 
         p5 = self.p5(c5)
         p4 = self.p4(p5, c4)
         p3 = self.p3(p4, c3)
         p2 = self.p2(p3, c2)
 
-        feature_pyramid = [
-            seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])
-        ]
+        s5 = self.seg_blocks[0](p5)
+        s4 = self.seg_blocks[1](p4)
+        s3 = self.seg_blocks[2](p3)
+        s2 = self.seg_blocks[3](p2)
+
+        feature_pyramid = [s5, s4, s3, s2]
         x = self.merge(feature_pyramid)
         x = self.dropout(x)
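With `forward(self, features: List[torch.Tensor])`, callers now pass a single list instead of unpacked positional arguments, and the `zip` over `self.seg_blocks` is unrolled into explicit indexing, a pattern that is easier for static analysis of `nn.ModuleList` access. A hedged usage sketch, assuming the collapsed `__init__` wires `p5`..`p2` and four `SegmentationBlock`s to the last four encoder stages (channel counts below are illustrative, e.g. a ResNet-50-style encoder):

```python
import torch

decoder = FPNDecoder(encoder_channels=[3, 64, 256, 512, 1024, 2048])
features = [
    torch.randn(1, 256, 56, 56),   # c2
    torch.randn(1, 512, 28, 28),   # c3
    torch.randn(1, 1024, 14, 14),  # c4
    torch.randn(1, 2048, 7, 7),    # c5
]
out = decoder(features)  # note: one list argument now, not decoder(*features)
# With merge_policy="add", each level is upsampled to c2 resolution and summed.
print(out.shape)  # expected: torch.Size([1, 128, 56, 56])
```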