File tree Expand file tree Collapse file tree 1 file changed +29
-0
lines changed Expand file tree Collapse file tree 1 file changed +29
-0
lines changed Original file line number Diff line number Diff line change
1
+ import segmentation_models_pytorch as smp
2
+
1
3
from tests .models import base
2
4
3
5
4
6
class TestUnetPlusPlusModel (base .BaseModelTester ):
5
7
test_model_type = "unetplusplus"
6
8
files_for_diff = [r"decoders/unetplusplus/" , r"base/" ]
9
+
10
+ def test_interpolation (self ):
11
+ # test bilinear
12
+ model_1 = smp .create_model (
13
+ self .test_model_type ,
14
+ self .test_encoder_name ,
15
+ decoder_interpolation = "bilinear" ,
16
+ )
17
+ is_tested = False
18
+ for module in model_1 .decoder .modules ():
19
+ if module .__class__ .__name__ == "DecoderBlock" :
20
+ assert module .interpolation_mode == "bilinear"
21
+ is_tested = True
22
+ assert is_tested
23
+
24
+ # test bicubic
25
+ model_2 = smp .create_model (
26
+ self .test_model_type ,
27
+ self .test_encoder_name ,
28
+ decoder_interpolation = "bicubic" ,
29
+ )
30
+ is_tested = False
31
+ for module in model_2 .decoder .modules ():
32
+ if module .__class__ .__name__ == "DecoderBlock" :
33
+ assert module .interpolation_mode == "bicubic"
34
+ is_tested = True
35
+ assert is_tested
You can’t perform that action at this time.
0 commit comments