Skip to content

Commit 93a832c

Browse files
committed
Add test
1 parent 22af917 commit 93a832c

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

tests/test_base.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import torch
2+
import tempfile
3+
import segmentation_models_pytorch as smp
4+
5+
import pytest
6+
7+
8+
def test_from_pretrained_with_mismatched_keys():
9+
orginal_model = smp.Unet(classes=1)
10+
11+
with tempfile.TemporaryDirectory() as temp_dir:
12+
orginal_model.save_pretrained(temp_dir)
13+
14+
# we should catch warning here and check if there specific keys there
15+
with pytest.warns(UserWarning):
16+
restored_model = smp.from_pretrained(temp_dir, classes=2, strict=False)
17+
18+
assert restored_model.segmentation_head[0].out_channels == 2
19+
20+
# verify all the weight are the same expect mismatched ones
21+
original_state_dict = orginal_model.state_dict()
22+
restored_state_dict = restored_model.state_dict()
23+
24+
expected_mismatched_keys = [
25+
"segmentation_head.0.weight",
26+
"segmentation_head.0.bias",
27+
]
28+
mismatched_keys = []
29+
for key in original_state_dict:
30+
if key not in expected_mismatched_keys:
31+
assert torch.allclose(original_state_dict[key], restored_state_dict[key])
32+
else:
33+
mismatched_keys.append(key)
34+
35+
assert len(mismatched_keys) == 2
36+
assert sorted(mismatched_keys) == sorted(expected_mismatched_keys)

0 commit comments

Comments
 (0)