Skip to content

Commit 40d97a3

Browse files
committed
Add test unet
1 parent 61042ac commit 40d97a3

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

tests/models/test_unet.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,26 @@
1+
import segmentation_models_pytorch as smp
12
from tests.models import base
23

34

45
class TestUnetModel(base.BaseModelTester):
56
test_model_type = "unet"
67
files_for_diff = [r"decoders/unet/", r"base/"]
8+
9+
def test_interpolation(self):
10+
# test bilinear
11+
model_1 = smp.create_model(
12+
self.test_model_type,
13+
self.test_encoder_name,
14+
decoder_interpolation="bilinear",
15+
)
16+
for block in model_1.decoder.blocks:
17+
assert block.interpolation_mode == "bilinear"
18+
19+
# test bicubic
20+
model_2 = smp.create_model(
21+
self.test_model_type,
22+
self.test_encoder_name,
23+
decoder_interpolation="bicubic",
24+
)
25+
for block in model_2.decoder.blocks:
26+
assert block.interpolation_mode == "bicubic"

0 commit comments

Comments
 (0)