@@ -57,16 +57,21 @@ def decoder_channels(self):
57
57
def _get_sample (self , batch_size = 1 , num_channels = 3 , height = 32 , width = 32 ):
58
58
return torch .rand (batch_size , num_channels , height , width )
59
59
60
+ @lru_cache
61
+ def get_default_model (self ):
62
+ model = smp .create_model (self .model_type , self .test_encoder_name )
63
+ model = model .to (default_device )
64
+ return model
65
+
60
66
def test_forward_backward (self ):
61
67
sample = self ._get_sample (
62
68
batch_size = self .default_batch_size ,
63
69
num_channels = self .default_num_channels ,
64
70
height = self .default_height ,
65
71
width = self .default_width ,
66
72
).to (default_device )
67
- model = smp .create_model (
68
- arch = self .model_type , encoder_name = self .test_encoder_name
69
- ).to (default_device )
73
+
74
+ model = self .get_default_model ()
70
75
71
76
# check default in_channels=3
72
77
output = model (sample )
@@ -91,14 +96,19 @@ def test_in_channels_and_depth_and_out_classes(
91
96
if self .model_type in ["unet" , "unetplusplus" , "manet" ]:
92
97
kwargs = {"decoder_channels" : self .decoder_channels [:depth ]}
93
98
94
- model = smp .create_model (
95
- arch = self .model_type ,
96
- encoder_name = self .test_encoder_name ,
97
- encoder_depth = depth ,
98
- in_channels = in_channels ,
99
- classes = classes ,
100
- ** kwargs ,
101
- ).to (default_device )
99
+ model = (
100
+ smp .create_model (
101
+ arch = self .model_type ,
102
+ encoder_name = self .test_encoder_name ,
103
+ encoder_depth = depth ,
104
+ in_channels = in_channels ,
105
+ classes = classes ,
106
+ ** kwargs ,
107
+ )
108
+ .to (default_device )
109
+ .eval ()
110
+ )
111
+
102
112
sample = self ._get_sample (
103
113
batch_size = self .default_batch_size ,
104
114
num_channels = in_channels ,
@@ -107,7 +117,7 @@ def test_in_channels_and_depth_and_out_classes(
107
117
).to (default_device )
108
118
109
119
# check in channels correctly set
110
- with torch .no_grad ():
120
+ with torch .inference_mode ():
111
121
output = model (sample )
112
122
113
123
self .assertEqual (output .shape [1 ], classes )
@@ -122,7 +132,8 @@ def test_classification_head(self):
122
132
"dropout" : 0.5 ,
123
133
"activation" : "sigmoid" ,
124
134
},
125
- ).to (default_device )
135
+ )
136
+ model = model .to (default_device ).eval ()
126
137
127
138
self .assertIsNotNone (model .classification_head )
128
139
self .assertIsInstance (model .classification_head [0 ], torch .nn .AdaptiveAvgPool2d )
@@ -139,24 +150,43 @@ def test_classification_head(self):
139
150
width = self .default_width ,
140
151
).to (default_device )
141
152
142
- with torch .no_grad ():
153
+ with torch .inference_mode ():
143
154
_ , cls_probs = model (sample )
144
155
145
156
self .assertEqual (cls_probs .shape [1 ], 10 )
146
157
158
+ def test_any_resolution (self ):
159
+ model = self .get_default_model ()
160
+ if model .requires_divisible_input_shape :
161
+ self .skipTest ("Model requires divisible input shape" )
162
+
163
+ sample = self ._get_sample (
164
+ batch_size = self .default_batch_size ,
165
+ num_channels = self .default_num_channels ,
166
+ height = self .default_height + 3 ,
167
+ width = self .default_width + 7 ,
168
+ ).to (default_device )
169
+
170
+ with torch .inference_mode ():
171
+ output = model (sample )
172
+
173
+ self .assertEqual (output .shape [2 ], self .default_height + 3 )
174
+ self .assertEqual (output .shape [3 ], self .default_width + 7 )
175
+
147
176
@requires_torch_greater_or_equal ("2.0.1" )
148
177
def test_save_load_with_hub_mixin (self ):
149
178
# instantiate model
150
- model = smp .create_model (
151
- arch = self .model_type , encoder_name = self .test_encoder_name
152
- ).to (default_device )
179
+ model = self .get_default_model ()
180
+ model .eval ()
153
181
154
182
# save model
155
183
with tempfile .TemporaryDirectory () as tmpdir :
156
184
model .save_pretrained (
157
185
tmpdir , dataset = "test_dataset" , metrics = {"my_awesome_metric" : 0.99 }
158
186
)
159
187
restored_model = smp .from_pretrained (tmpdir ).to (default_device )
188
+ restored_model .eval ()
189
+
160
190
with open (os .path .join (tmpdir , "README.md" ), "r" ) as f :
161
191
readme = f .read ()
162
192
@@ -168,7 +198,7 @@ def test_save_load_with_hub_mixin(self):
168
198
width = self .default_width ,
169
199
).to (default_device )
170
200
171
- with torch .no_grad ():
201
+ with torch .inference_mode ():
172
202
output = model (sample )
173
203
restored_output = restored_model (sample )
174
204
@@ -197,7 +227,7 @@ def test_preserve_forward_output(self):
197
227
output_tensor = torch .load (output_tensor_path , weights_only = True )
198
228
output_tensor = output_tensor .to (default_device )
199
229
200
- with torch .no_grad ():
230
+ with torch .inference_mode ():
201
231
output = model (input_tensor )
202
232
203
233
self .assertEqual (output .shape , output_tensor .shape )
0 commit comments