File tree Expand file tree Collapse file tree 1 file changed +36
-0
lines changed Expand file tree Collapse file tree 1 file changed +36
-0
lines changed Original file line number Diff line number Diff line change
1
+ import torch
2
+ import tempfile
3
+ import segmentation_models_pytorch as smp
4
+
5
+ import pytest
6
+
7
+
8
+ def test_from_pretrained_with_mismatched_keys ():
9
+ orginal_model = smp .Unet (classes = 1 )
10
+
11
+ with tempfile .TemporaryDirectory () as temp_dir :
12
+ orginal_model .save_pretrained (temp_dir )
13
+
14
+ # we should catch warning here and check if there specific keys there
15
+ with pytest .warns (UserWarning ):
16
+ restored_model = smp .from_pretrained (temp_dir , classes = 2 , strict = False )
17
+
18
+ assert restored_model .segmentation_head [0 ].out_channels == 2
19
+
20
+ # verify all the weight are the same expect mismatched ones
21
+ original_state_dict = orginal_model .state_dict ()
22
+ restored_state_dict = restored_model .state_dict ()
23
+
24
+ expected_mismatched_keys = [
25
+ "segmentation_head.0.weight" ,
26
+ "segmentation_head.0.bias" ,
27
+ ]
28
+ mismatched_keys = []
29
+ for key in original_state_dict :
30
+ if key not in expected_mismatched_keys :
31
+ assert torch .allclose (original_state_dict [key ], restored_state_dict [key ])
32
+ else :
33
+ mismatched_keys .append (key )
34
+
35
+ assert len (mismatched_keys ) == 2
36
+ assert sorted (mismatched_keys ) == sorted (expected_mismatched_keys )
You can’t perform that action at this time.
0 commit comments