Load model with mismatched sizes #1107

Merged: 5 commits, Mar 28, 2025

8 changes: 8 additions & 0 deletions docs/save_load.rst
@@ -40,6 +40,14 @@ For example:
     # Alternatively, load the model directly from the Hugging Face Hub
     model = smp.from_pretrained('username/my-model')

+Loading pre-trained model with different number of classes for fine-tuning:
+
+.. code:: python
+
+    import segmentation_models_pytorch as smp
+
+    model = smp.from_pretrained('<path-or-repo-name>', classes=5, strict=False)
+
 Saving model Metrics and Dataset Name
 -------------------------------------

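The snippet added to the docs above can be exercised end to end. A minimal sketch, assuming a checkpoint saved with a different number of classes ('username/my-model' is a placeholder repo name, not a real checkpoint):

    import warnings

    import segmentation_models_pytorch as smp

    # Loading with strict=False skips shape-mismatched tensors (the segmentation
    # head here) and emits a UserWarning listing the affected keys.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        model = smp.from_pretrained("username/my-model", classes=5, strict=False)

    for w in caught:
        print(w.message)
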
6 changes: 3 additions & 3 deletions examples/segformer_inference_pretrained.ipynb
@@ -13,9 +13,9 @@
"metadata": {},
"outputs": [],
"source": [
"# fix for HF hub download\n",
"# see PR https://github.com/albumentations-team/albumentations/pull/2171\n",
"!pip install -U git+https://github.com/qubvel/albumentations@patch-2"
"# make sure you have the latest version of the libraries\n",
"!pip install -U segmentation-models-pytorch\n",
"!pip install albumentations matplotlib requests pillow"
]
},
{
61 changes: 42 additions & 19 deletions segmentation_models_pytorch/base/model.py
@@ -1,6 +1,7 @@
 import torch
-from typing import TypeVar, Type
+import warnings
+
+from typing import TypeVar, Type
 from . import initialization as init
 from .hub_mixin import SMPHubMixin
 from .utils import is_torch_compiling
@@ -96,23 +97,45 @@
         # timm- ported encoders with TimmUniversalEncoder
         from segmentation_models_pytorch.encoders import TimmUniversalEncoder

-        if not isinstance(self.encoder, TimmUniversalEncoder):
-            return super().load_state_dict(state_dict, **kwargs)
-
-        patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"]
-
-        is_deprecated_encoder = any(
-            self.encoder.name.startswith(pattern) for pattern in patterns
-        )
-
-        if is_deprecated_encoder:
-            keys = list(state_dict.keys())
-            for key in keys:
-                new_key = key
-                if key.startswith("encoder.") and not key.startswith("encoder.model."):
-                    new_key = "encoder.model." + key.removeprefix("encoder.")
-                if "gernet" in self.encoder.name:
-                    new_key = new_key.replace(".stages.", ".stages_")
-                state_dict[new_key] = state_dict.pop(key)
+        if isinstance(self.encoder, TimmUniversalEncoder):
+            patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"]
+            is_deprecated_encoder = any(
+                self.encoder.name.startswith(pattern) for pattern in patterns
+            )
+            if is_deprecated_encoder:
+                keys = list(state_dict.keys())
+                for key in keys:
+                    new_key = key
+                    if key.startswith("encoder.") and not key.startswith(
"encoder.model."
):
new_key = "encoder.model." + key.removeprefix("encoder.")
if "gernet" in self.encoder.name:
new_key = new_key.replace(".stages.", ".stages_")
state_dict[new_key] = state_dict.pop(key)

+        # To be able to load weights with mismatched sizes,
+        # we also filter out mismatched keys when strict=False
+        strict = kwargs.get("strict", True)
+        if not strict:
+            mismatched_keys = []
+            model_state_dict = self.state_dict()
+            common_keys = set(model_state_dict.keys()) & set(state_dict.keys())
+            for key in common_keys:
+                if model_state_dict[key].shape != state_dict[key].shape:
+                    mismatched_keys.append(
+                        (key, model_state_dict[key].shape, state_dict[key].shape)
+                    )
+                    state_dict.pop(key)
+
+            if mismatched_keys:
+                str_keys = "\n".join(
+                    [
+                        f"  - {key}: {s} (weights) -> {m} (model)"
+                        for key, m, s in mismatched_keys
+                    ]
+                )
+                text = f"\n\n !!!!!! Mismatched keys !!!!!!\n\nYou should TRAIN the model to use it:\n{str_keys}\n"
+                warnings.warn(text, stacklevel=-1)

         return super().load_state_dict(state_dict, **kwargs)
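The heart of the change above is shape-based filtering of the incoming state dict before delegating to torch.nn.Module.load_state_dict. A stand-alone sketch of the same technique for an arbitrary module (the function name and warning format are illustrative, not part of the library):

    import warnings

    import torch


    def load_state_dict_forgiving(model: torch.nn.Module, state_dict: dict) -> None:
        # Compare shapes on the keys both sides share; drop mismatches so the
        # model keeps its freshly initialized tensors for those entries.
        model_state = model.state_dict()
        mismatched = []
        for key in set(model_state) & set(state_dict):
            if model_state[key].shape != state_dict[key].shape:
                mismatched.append((key, state_dict[key].shape, model_state[key].shape))
                state_dict.pop(key)
        if mismatched:
            listing = "\n".join(
                f"  - {key}: {s} (weights) -> {m} (model)" for key, s, m in mismatched
            )
            warnings.warn(f"Skipped mismatched keys:\n{listing}", stacklevel=2)
        # strict=False also tolerates keys missing from either side.
        model.load_state_dict(state_dict, strict=False)
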
36 changes: 36 additions & 0 deletions tests/test_base.py
@@ -0,0 +1,36 @@
import torch
import tempfile
import segmentation_models_pytorch as smp

import pytest


def test_from_pretrained_with_mismatched_keys():
    original_model = smp.Unet(classes=1)

    with tempfile.TemporaryDirectory() as temp_dir:
        original_model.save_pretrained(temp_dir)

        # we should catch the warning here and check that the specific keys are reported
        with pytest.warns(UserWarning):
            restored_model = smp.from_pretrained(temp_dir, classes=2, strict=False)

        assert restored_model.segmentation_head[0].out_channels == 2

        # verify all the weights are the same except the mismatched ones
        original_state_dict = original_model.state_dict()
        restored_state_dict = restored_model.state_dict()

        expected_mismatched_keys = [
            "segmentation_head.0.weight",
            "segmentation_head.0.bias",
        ]
        mismatched_keys = []
        for key in original_state_dict:
            if key not in expected_mismatched_keys:
                assert torch.allclose(original_state_dict[key], restored_state_dict[key])
            else:
                mismatched_keys.append(key)

        assert len(mismatched_keys) == 2
        assert sorted(mismatched_keys) == sorted(expected_mismatched_keys)
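
As the warning added in model.py stresses, the reinitialized head has to be trained before the restored model is usable. A hypothetical fine-tuning setup (the freezing policy and optimizer choice are assumptions, not part of this PR):

    import torch

    import segmentation_models_pytorch as smp

    model = smp.from_pretrained("<path-or-repo-name>", classes=2, strict=False)

    # Train only the new segmentation head; keep pretrained weights frozen.
    for param in model.parameters():
        param.requires_grad = False
    for param in model.segmentation_head.parameters():
        param.requires_grad = True

    optimizer = torch.optim.AdamW(model.segmentation_head.parameters(), lr=1e-3)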