Skip to content

Commit 61042ac

Browse files
committed
Add Unet++ test
1 parent 9a5cc20 commit 61042ac

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

tests/models/test_unetplusplus.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,35 @@
1+
import segmentation_models_pytorch as smp
2+
13
from tests.models import base
24

35

46
class TestUnetPlusPlusModel(base.BaseModelTester):
57
test_model_type = "unetplusplus"
68
files_for_diff = [r"decoders/unetplusplus/", r"base/"]
9+
10+
def test_interpolation(self):
11+
# test bilinear
12+
model_1 = smp.create_model(
13+
self.test_model_type,
14+
self.test_encoder_name,
15+
decoder_interpolation="bilinear",
16+
)
17+
is_tested = False
18+
for module in model_1.decoder.modules():
19+
if module.__class__.__name__ == "DecoderBlock":
20+
assert module.interpolation_mode == "bilinear"
21+
is_tested = True
22+
assert is_tested
23+
24+
# test bicubic
25+
model_2 = smp.create_model(
26+
self.test_model_type,
27+
self.test_encoder_name,
28+
decoder_interpolation="bicubic",
29+
)
30+
is_tested = False
31+
for module in model_2.decoder.modules():
32+
if module.__class__.__name__ == "DecoderBlock":
33+
assert module.interpolation_mode == "bicubic"
34+
is_tested = True
35+
assert is_tested

0 commit comments

Comments
 (0)