8
8
9
9
@pytest .mark .parametrize ("encoder_name" , ["sam-vit_b" , "sam-vit_l" ])
10
10
@pytest .mark .parametrize ("img_size" , [64 , 128 ])
11
- @pytest .mark .parametrize ("patch_size" , [8 , 16 ])
12
- @pytest .mark .parametrize ("depth" , [6 , 24 , None ])
13
- def test_sam_encoder (encoder_name , img_size , patch_size , depth ):
14
- encoder = get_encoder (encoder_name , img_size = img_size , patch_size = patch_size , depth = depth )
15
- assert encoder ._name == encoder_name [4 :]
11
+ @pytest .mark .parametrize ("patch_size,depth" , [(8 , 3 ), (16 , 4 )])
12
+ @pytest .mark .parametrize ("vit_depth" , [12 , 24 ])
13
+ def test_sam_encoder (encoder_name , img_size , patch_size , depth , vit_depth ):
14
+ encoder = get_encoder (encoder_name , img_size = img_size , patch_size = patch_size , depth = depth , vit_depth = vit_depth )
16
15
assert encoder .output_stride == 32
17
16
18
17
sample = torch .ones (1 , 3 , img_size , img_size )
@@ -23,6 +22,13 @@ def test_sam_encoder(encoder_name, img_size, patch_size, depth):
23
22
assert out [- 1 ].size () == torch .Size ([1 , 256 , expected_patches , expected_patches ])
24
23
25
24
25
def test_sam_encoder_validation_error():
    """Invalid SAM encoder configurations must raise ``ValueError``.

    Each invalid combination gets its own ``pytest.raises`` block. With all
    three calls inside a single block (the previous form), only the first call
    ever executes — the raised ValueError exits the context immediately, so
    the remaining combinations are silently untested.
    """
    # presumably depth=5 is incompatible with vit_depth=12 — TODO confirm
    # against get_encoder's validation logic
    with pytest.raises(ValueError):
        get_encoder("sam-vit_b", img_size=64, patch_size=16, depth=5, vit_depth=12)
    # vit_depth=None looks like it should be rejected for SAM encoders
    with pytest.raises(ValueError):
        get_encoder("sam-vit_b", img_size=64, patch_size=16, depth=4, vit_depth=None)
    # presumably vit_depth=6 is not a valid ViT depth for sam-vit_b — verify
    with pytest.raises(ValueError):
        get_encoder("sam-vit_b", img_size=64, patch_size=16, depth=4, vit_depth=6)
30
+
31
+
26
32
@pytest .mark .skip (reason = "Decoder has been removed, keeping this for future integration" )
27
33
@pytest .mark .parametrize ("decoder_multiclass_output" , [True , False ])
28
34
@pytest .mark .parametrize ("n_classes" , [1 , 3 ])
@@ -43,14 +49,13 @@ def test_sam(decoder_multiclass_output, n_classes):
43
49
44
50
45
51
@pytest .mark .parametrize ("model_class" , [smp .Unet ])
46
- @pytest .mark .parametrize ("decoder_channels,patch_size " , [([64 , 32 , 16 , 8 ], 16 ), ([64 , 32 , 16 ] , 8 )])
47
- def test_sam_as_encoder_only (model_class , decoder_channels , patch_size ):
48
- img_size = 64
52
+ @pytest .mark .parametrize ("decoder_channels,encoder_depth " , [([64 , 32 , 16 , 8 ], 4 ), ([64 , 32 , 16 , 8 ], 4 )])
53
+ def test_sam_encoder_arch (model_class , decoder_channels , encoder_depth ):
54
+ img_size = 1024
49
55
model = model_class (
50
56
"sam-vit_b" ,
51
57
encoder_weights = None ,
52
- encoder_depth = 3 ,
53
- encoder_kwargs = dict (img_size = img_size , out_chans = decoder_channels [0 ], patch_size = patch_size ),
58
+ encoder_depth = encoder_depth ,
54
59
decoder_channels = decoder_channels ,
55
60
)
56
61
smp = torch .ones (1 , 3 , img_size , img_size )
@@ -65,5 +70,5 @@ def test_sam_weights():
65
70
@pytest.mark.skip(reason="Run this test manually as it needs to download weights")
def test_sam_encoder_weights():
    """Smoke-test model creation with pretrained SAM encoder weights.

    Builds a Unet with the ``sam-vit_b`` encoder and its ``sa-1b`` checkpoint.
    Skipped by default because it downloads the pretrained weights.
    """
    smp.create_model(
        "unet",
        encoder_name="sam-vit_b",
        encoder_depth=4,
        encoder_weights="sa-1b",
        decoder_channels=[64, 32, 16, 8],
    )
0 commit comments