Skip to content

Commit 26725de

Browse files
committed
Add test for any resolution (not divisible by 32)
1 parent b2166ea commit 26725de

File tree

1 file changed

+49
-19
lines changed

1 file changed

+49
-19
lines changed

tests/models/base.py

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,21 @@ def decoder_channels(self):
5757
def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32):
5858
return torch.rand(batch_size, num_channels, height, width)
5959

60+
@lru_cache
61+
def get_default_model(self):
62+
model = smp.create_model(self.model_type, self.test_encoder_name)
63+
model = model.to(default_device)
64+
return model
65+
6066
def test_forward_backward(self):
6167
sample = self._get_sample(
6268
batch_size=self.default_batch_size,
6369
num_channels=self.default_num_channels,
6470
height=self.default_height,
6571
width=self.default_width,
6672
).to(default_device)
67-
model = smp.create_model(
68-
arch=self.model_type, encoder_name=self.test_encoder_name
69-
).to(default_device)
73+
74+
model = self.get_default_model()
7075

7176
# check default in_channels=3
7277
output = model(sample)
@@ -91,14 +96,19 @@ def test_in_channels_and_depth_and_out_classes(
9196
if self.model_type in ["unet", "unetplusplus", "manet"]:
9297
kwargs = {"decoder_channels": self.decoder_channels[:depth]}
9398

94-
model = smp.create_model(
95-
arch=self.model_type,
96-
encoder_name=self.test_encoder_name,
97-
encoder_depth=depth,
98-
in_channels=in_channels,
99-
classes=classes,
100-
**kwargs,
101-
).to(default_device)
99+
model = (
100+
smp.create_model(
101+
arch=self.model_type,
102+
encoder_name=self.test_encoder_name,
103+
encoder_depth=depth,
104+
in_channels=in_channels,
105+
classes=classes,
106+
**kwargs,
107+
)
108+
.to(default_device)
109+
.eval()
110+
)
111+
102112
sample = self._get_sample(
103113
batch_size=self.default_batch_size,
104114
num_channels=in_channels,
@@ -107,7 +117,7 @@ def test_in_channels_and_depth_and_out_classes(
107117
).to(default_device)
108118

109119
# check in channels correctly set
110-
with torch.no_grad():
120+
with torch.inference_mode():
111121
output = model(sample)
112122

113123
self.assertEqual(output.shape[1], classes)
@@ -122,7 +132,8 @@ def test_classification_head(self):
122132
"dropout": 0.5,
123133
"activation": "sigmoid",
124134
},
125-
).to(default_device)
135+
)
136+
model = model.to(default_device).eval()
126137

127138
self.assertIsNotNone(model.classification_head)
128139
self.assertIsInstance(model.classification_head[0], torch.nn.AdaptiveAvgPool2d)
@@ -139,24 +150,43 @@ def test_classification_head(self):
139150
width=self.default_width,
140151
).to(default_device)
141152

142-
with torch.no_grad():
153+
with torch.inference_mode():
143154
_, cls_probs = model(sample)
144155

145156
self.assertEqual(cls_probs.shape[1], 10)
146157

158+
def test_any_resolution(self):
159+
model = self.get_default_model()
160+
if model.requires_divisible_input_shape:
161+
self.skipTest("Model requires divisible input shape")
162+
163+
sample = self._get_sample(
164+
batch_size=self.default_batch_size,
165+
num_channels=self.default_num_channels,
166+
height=self.default_height + 3,
167+
width=self.default_width + 7,
168+
).to(default_device)
169+
170+
with torch.inference_mode():
171+
output = model(sample)
172+
173+
self.assertEqual(output.shape[2], self.default_height + 3)
174+
self.assertEqual(output.shape[3], self.default_width + 7)
175+
147176
@requires_torch_greater_or_equal("2.0.1")
148177
def test_save_load_with_hub_mixin(self):
149178
# instantiate model
150-
model = smp.create_model(
151-
arch=self.model_type, encoder_name=self.test_encoder_name
152-
).to(default_device)
179+
model = self.get_default_model()
180+
model.eval()
153181

154182
# save model
155183
with tempfile.TemporaryDirectory() as tmpdir:
156184
model.save_pretrained(
157185
tmpdir, dataset="test_dataset", metrics={"my_awesome_metric": 0.99}
158186
)
159187
restored_model = smp.from_pretrained(tmpdir).to(default_device)
188+
restored_model.eval()
189+
160190
with open(os.path.join(tmpdir, "README.md"), "r") as f:
161191
readme = f.read()
162192

@@ -168,7 +198,7 @@ def test_save_load_with_hub_mixin(self):
168198
width=self.default_width,
169199
).to(default_device)
170200

171-
with torch.no_grad():
201+
with torch.inference_mode():
172202
output = model(sample)
173203
restored_output = restored_model(sample)
174204

@@ -197,7 +227,7 @@ def test_preserve_forward_output(self):
197227
output_tensor = torch.load(output_tensor_path, weights_only=True)
198228
output_tensor = output_tensor.to(default_device)
199229

200-
with torch.no_grad():
230+
with torch.inference_mode():
201231
output = model(input_tensor)
202232

203233
self.assertEqual(output.shape, output_tensor.shape)

0 commit comments

Comments (0)