We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 61042ac commit 40d97a3Copy full SHA for 40d97a3
tests/models/test_unet.py
@@ -1,6 +1,26 @@
1
+import segmentation_models_pytorch as smp
2
from tests.models import base
3
4
5
class TestUnetModel(base.BaseModelTester):
6
test_model_type = "unet"
7
files_for_diff = [r"decoders/unet/", r"base/"]
8
+
9
+ def test_interpolation(self):
10
+ # test bilinear
11
+ model_1 = smp.create_model(
12
+ self.test_model_type,
13
+ self.test_encoder_name,
14
+ decoder_interpolation="bilinear",
15
+ )
16
+ for block in model_1.decoder.blocks:
17
+ assert block.interpolation_mode == "bilinear"
18
19
+ # test bicubic
20
+ model_2 = smp.create_model(
21
22
23
+ decoder_interpolation="bicubic",
24
25
+ for block in model_2.decoder.blocks:
26
+ assert block.interpolation_mode == "bicubic"
0 commit comments