-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
Copy pathbase.py
208 lines (178 loc) · 7.8 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import unittest
import torch
import segmentation_models_pytorch as smp
from functools import lru_cache
from tests.utils import default_device
class BaseEncoderTester(unittest.TestCase):
    """Reusable harness for testing ``smp`` encoders.

    Subclasses list the encoders to exercise in ``encoder_names`` and may
    override the other class attributes. Inherited tests cover:

      * forward/backward pass and the number of returned feature maps,
      * construction with different ``in_channels``,
      * reduced ``depth`` (fewer feature maps, matching output strides),
      * dilated mode (``output_stride`` 8/16), or — for encoders that do
        not support it — that a ``ValueError`` is raised.
    """

    # Encoder names to run every inherited test against; subclasses fill this in.
    encoder_names = []

    # standard encoder configuration
    num_output_features = 6
    output_strides = [1, 2, 4, 8, 16, 32]
    supports_dilated = True

    # test sample configuration
    default_batch_size = 1
    default_num_channels = 3
    default_height = 64
    default_width = 64

    # test configurations
    in_channels_to_test = [1, 3, 4]
    depth_to_test = [3, 4, 5]
    strides_to_test = [8, 16]  # 32 is a default one

    @staticmethod
    @lru_cache
    def _get_sample(batch_size=1, num_channels=3, height=32, width=32):
        """Return a cached random input tensor of the requested shape.

        NOTE: this was previously an *instance* method decorated with
        ``lru_cache``, which keys the cache on ``self`` and keeps every
        TestCase instance alive for the lifetime of the cache (ruff B019).
        A cached staticmethod caches on the shape arguments only; existing
        ``self._get_sample(...)`` call sites keep working because ``self``
        is not passed to a staticmethod.
        """
        return torch.rand(batch_size, num_channels, height, width)

    def get_features_output_strides(self, sample, features):
        """Return per-feature (height, width) downsampling factors vs. ``sample``.

        Assumes NCHW tensors: stride is ``sample_extent // feature_extent``
        along each spatial axis.
        """
        height, width = sample.shape[2:]
        height_strides = [height // f.shape[2] for f in features]
        width_strides = [width // f.shape[3] for f in features]
        return height_strides, width_strides

    def test_forward_backward(self):
        """Forward yields ``num_output_features`` maps and backward runs."""
        sample = self._get_sample(
            batch_size=self.default_batch_size,
            num_channels=self.default_num_channels,
            height=self.default_height,
            width=self.default_width,
        ).to(default_device)
        for encoder_name in self.encoder_names:
            with self.subTest(encoder_name=encoder_name):
                # init encoder
                encoder = smp.encoders.get_encoder(
                    encoder_name, in_channels=3, encoder_weights=None
                ).to(default_device)

                # forward
                features = encoder.forward(sample)
                self.assertEqual(
                    len(features),
                    self.num_output_features,
                    f"Encoder `{encoder_name}` should have {self.num_output_features} output feature maps, but has {len(features)}",
                )

                # backward: reduce the deepest feature map to a scalar so
                # gradients flow through the whole encoder
                features[-1].mean().backward()

    def test_in_channels(self):
        """Encoders accept each ``in_channels`` value in ``in_channels_to_test``."""
        cases = [
            (encoder_name, in_channels)
            for encoder_name in self.encoder_names
            for in_channels in self.in_channels_to_test
        ]

        for encoder_name, in_channels in cases:
            sample = self._get_sample(
                batch_size=self.default_batch_size,
                num_channels=in_channels,
                height=self.default_height,
                width=self.default_width,
            ).to(default_device)

            with self.subTest(encoder_name=encoder_name, in_channels=in_channels):
                encoder = smp.encoders.get_encoder(
                    encoder_name, in_channels=in_channels, encoder_weights=None
                ).to(default_device)
                encoder.eval()

                # forward: only checks the pass does not raise
                with torch.inference_mode():
                    encoder.forward(sample)

    def test_depth(self):
        """Reduced ``depth`` yields ``depth + 1`` features with matching strides."""
        sample = self._get_sample(
            batch_size=self.default_batch_size,
            num_channels=self.default_num_channels,
            height=self.default_height,
            width=self.default_width,
        ).to(default_device)

        cases = [
            (encoder_name, depth)
            for encoder_name in self.encoder_names
            for depth in self.depth_to_test
        ]

        for encoder_name, depth in cases:
            with self.subTest(encoder_name=encoder_name, depth=depth):
                encoder = smp.encoders.get_encoder(
                    encoder_name,
                    in_channels=self.default_num_channels,
                    encoder_weights=None,
                    depth=depth,
                ).to(default_device)
                encoder.eval()

                # forward
                with torch.inference_mode():
                    features = encoder.forward(sample)

                # check number of features
                self.assertEqual(
                    len(features),
                    depth + 1,
                    f"Encoder `{encoder_name}` should have {depth + 1} output feature maps, but has {len(features)}",
                )

                # check feature strides
                height_strides, width_strides = self.get_features_output_strides(
                    sample, features
                )
                self.assertEqual(
                    height_strides,
                    self.output_strides[: depth + 1],
                    f"Encoder `{encoder_name}` should have output strides {self.output_strides[: depth + 1]}, but has {height_strides}",
                )
                self.assertEqual(
                    width_strides,
                    self.output_strides[: depth + 1],
                    f"Encoder `{encoder_name}` should have output strides {self.output_strides[: depth + 1]}, but has {width_strides}",
                )

                # check encoder output stride property
                self.assertEqual(
                    encoder.output_stride,
                    self.output_strides[depth],
                    f"Encoder `{encoder_name}` last feature map should have output stride {self.output_strides[depth]}, but has {encoder.output_stride}",
                )

                # check out channels also have proper length
                self.assertEqual(
                    len(encoder.out_channels),
                    depth + 1,
                    f"Encoder `{encoder_name}` should have {depth + 1} out_channels, but has {len(encoder.out_channels)}",
                )

    def test_dilated(self):
        """Dilated mode caps feature strides at the requested ``output_stride``."""
        sample = self._get_sample(
            batch_size=self.default_batch_size,
            num_channels=self.default_num_channels,
            height=self.default_height,
            width=self.default_width,
        ).to(default_device)

        cases = [
            (encoder_name, stride)
            for encoder_name in self.encoder_names
            for stride in self.strides_to_test
        ]

        # special case for encoders that do not support dilated model
        # just check proper error is raised
        if not self.supports_dilated:
            with self.assertRaises(ValueError, msg="not support dilated mode"):
                encoder_name, stride = cases[0]
                encoder = smp.encoders.get_encoder(
                    encoder_name,
                    in_channels=self.default_num_channels,
                    encoder_weights=None,
                    output_stride=stride,
                ).to(default_device)
            return

        for encoder_name, stride in cases:
            with self.subTest(encoder_name=encoder_name, stride=stride):
                encoder = smp.encoders.get_encoder(
                    encoder_name,
                    in_channels=self.default_num_channels,
                    encoder_weights=None,
                    output_stride=stride,
                ).to(default_device)
                encoder.eval()

                # forward
                with torch.inference_mode():
                    features = encoder.forward(sample)

                height_strides, width_strides = self.get_features_output_strides(
                    sample, features
                )
                # Dilation caps the configured output strides at `stride`.
                # Derive the expectation from the class contract
                # (`self.output_strides`) rather than from the actual
                # strides: the previous `[min(stride, s) for s in
                # height_strides]` only detected strides *exceeding* the
                # cap, never strides that were simply wrong.
                expected_strides = [
                    min(stride, s) for s in self.output_strides[: len(features)]
                ]
                self.assertEqual(
                    height_strides,
                    expected_strides,
                    f"Encoder `{encoder_name}` should have height output strides {expected_strides}, but has {height_strides}",
                )
                self.assertEqual(
                    width_strides,
                    expected_strides,
                    f"Encoder `{encoder_name}` should have width output strides {expected_strides}, but has {width_strides}",
                )