Skip to content

Commit 22af917

Browse files
committed
Add a way to load a model with mismatched sizes
1 parent 4c7829b commit 22af917

File tree

1 file changed

+42
-19
lines changed
  • segmentation_models_pytorch/base

1 file changed

+42
-19
lines changed

segmentation_models_pytorch/base/model.py

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
2-
from typing import TypeVar, Type
2+
import warnings
33

4+
from typing import TypeVar, Type
45
from . import initialization as init
56
from .hub_mixin import SMPHubMixin
67
from .utils import is_torch_compiling
@@ -96,23 +97,45 @@ def load_state_dict(self, state_dict, **kwargs):
9697
# timm-ported encoders with TimmUniversalEncoder
9798
from segmentation_models_pytorch.encoders import TimmUniversalEncoder
9899

99-
if not isinstance(self.encoder, TimmUniversalEncoder):
100-
return super().load_state_dict(state_dict, **kwargs)
101-
102-
patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"]
103-
104-
is_deprecated_encoder = any(
105-
self.encoder.name.startswith(pattern) for pattern in patterns
106-
)
107-
108-
if is_deprecated_encoder:
109-
keys = list(state_dict.keys())
110-
for key in keys:
111-
new_key = key
112-
if key.startswith("encoder.") and not key.startswith("encoder.model."):
113-
new_key = "encoder.model." + key.removeprefix("encoder.")
114-
if "gernet" in self.encoder.name:
115-
new_key = new_key.replace(".stages.", ".stages_")
116-
state_dict[new_key] = state_dict.pop(key)
100+
if isinstance(self.encoder, TimmUniversalEncoder):
101+
patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"]
102+
is_deprecated_encoder = any(
103+
self.encoder.name.startswith(pattern) for pattern in patterns
104+
)
105+
if is_deprecated_encoder:
106+
keys = list(state_dict.keys())
107+
for key in keys:
108+
new_key = key
109+
if key.startswith("encoder.") and not key.startswith(
110+
"encoder.model."
111+
):
112+
new_key = "encoder.model." + key.removeprefix("encoder.")
113+
if "gernet" in self.encoder.name:
114+
new_key = new_key.replace(".stages.", ".stages_")
115+
state_dict[new_key] = state_dict.pop(key)
116+
117+
# To be able to load weights with mismatched sizes
118+
# We are going to filter mismatched sizes as well if strict=False
119+
strict = kwargs.get("strict", True)
120+
if not strict:
121+
mismatched_keys = []
122+
model_state_dict = self.state_dict()
123+
common_keys = set(model_state_dict.keys()) & set(state_dict.keys())
124+
for key in common_keys:
125+
if model_state_dict[key].shape != state_dict[key].shape:
126+
mismatched_keys.append(
127+
(key, model_state_dict[key].shape, state_dict[key].shape)
128+
)
129+
state_dict.pop(key)
130+
131+
if mismatched_keys:
132+
str_keys = "\n".join(
133+
[
134+
f" - {key}: {s} (weights) -> {m} (model)"
135+
for key, m, s in mismatched_keys
136+
]
137+
)
138+
text = f"\n\n !!!!!! Mismatched keys !!!!!!\n\nYou should TRAIN the model to use it:\n{str_keys}\n"
139+
warnings.warn(text, stacklevel=-1)
117140

118141
return super().load_state_dict(state_dict, **kwargs)

0 commit comments

Comments
 (0)