From 22af91776a969a6f65ed65a005766a98c8676ca0 Mon Sep 17 00:00:00 2001 From: qubvel Date: Wed, 26 Mar 2025 17:29:35 +0000 Subject: [PATCH 1/5] Add a way to load model with mismatched sizes --- segmentation_models_pytorch/base/model.py | 61 ++++++++++++++++------- 1 file changed, 42 insertions(+), 19 deletions(-) diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py index 29820840..67f9422a 100644 --- a/segmentation_models_pytorch/base/model.py +++ b/segmentation_models_pytorch/base/model.py @@ -1,6 +1,7 @@ import torch -from typing import TypeVar, Type +import warnings +from typing import TypeVar, Type from . import initialization as init from .hub_mixin import SMPHubMixin from .utils import is_torch_compiling @@ -96,23 +97,45 @@ def load_state_dict(self, state_dict, **kwargs): # timm- ported encoders with TimmUniversalEncoder from segmentation_models_pytorch.encoders import TimmUniversalEncoder - if not isinstance(self.encoder, TimmUniversalEncoder): - return super().load_state_dict(state_dict, **kwargs) - - patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"] - - is_deprecated_encoder = any( - self.encoder.name.startswith(pattern) for pattern in patterns - ) - - if is_deprecated_encoder: - keys = list(state_dict.keys()) - for key in keys: - new_key = key - if key.startswith("encoder.") and not key.startswith("encoder.model."): - new_key = "encoder.model." + key.removeprefix("encoder.") - if "gernet" in self.encoder.name: - new_key = new_key.replace(".stages.", ".stages_") - state_dict[new_key] = state_dict.pop(key) + if isinstance(self.encoder, TimmUniversalEncoder): + patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"] + is_deprecated_encoder = any( + self.encoder.name.startswith(pattern) for pattern in patterns + ) + if is_deprecated_encoder: + keys = list(state_dict.keys()) + for key in keys: + new_key = key + if key.startswith("encoder.") and not key.startswith( + "encoder.model." 
+ ): + new_key = "encoder.model." + key.removeprefix("encoder.") + if "gernet" in self.encoder.name: + new_key = new_key.replace(".stages.", ".stages_") + state_dict[new_key] = state_dict.pop(key) + + # To be able to load weight with mismatched sizes + # We are going to filter mismatched sizes as well if strict=False + strict = kwargs.get("strict", True) + if not strict: + mismatched_keys = [] + model_state_dict = self.state_dict() + common_keys = set(model_state_dict.keys()) & set(state_dict.keys()) + for key in common_keys: + if model_state_dict[key].shape != state_dict[key].shape: + mismatched_keys.append( + (key, model_state_dict[key].shape, state_dict[key].shape) + ) + state_dict.pop(key) + + if mismatched_keys: + str_keys = "\n".join( + [ + f" - {key}: {s} (weights) -> {m} (model)" + for key, m, s in mismatched_keys + ] + ) + text = f"\n\n !!!!!! Mismatched keys !!!!!!\n\nYou should TRAIN the model to use it:\n{str_keys}\n" + warnings.warn(text, stacklevel=-1) return super().load_state_dict(state_dict, **kwargs) From 93a832ce10548c2b4e90e138b075848718629bba Mon Sep 17 00:00:00 2001 From: qubvel Date: Wed, 26 Mar 2025 17:29:49 +0000 Subject: [PATCH 2/5] Add test --- tests/test_base.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 tests/test_base.py diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 00000000..41f21625 --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,36 @@ +import torch +import tempfile +import segmentation_models_pytorch as smp + +import pytest + + +def test_from_pretrained_with_mismatched_keys(): + orginal_model = smp.Unet(classes=1) + + with tempfile.TemporaryDirectory() as temp_dir: + orginal_model.save_pretrained(temp_dir) + + # we should catch warning here and check if there specific keys there + with pytest.warns(UserWarning): + restored_model = smp.from_pretrained(temp_dir, classes=2, strict=False) + + assert restored_model.segmentation_head[0].out_channels == 
2 + + # verify all the weight are the same expect mismatched ones + original_state_dict = orginal_model.state_dict() + restored_state_dict = restored_model.state_dict() + + expected_mismatched_keys = [ + "segmentation_head.0.weight", + "segmentation_head.0.bias", + ] + mismatched_keys = [] + for key in original_state_dict: + if key not in expected_mismatched_keys: + assert torch.allclose(original_state_dict[key], restored_state_dict[key]) + else: + mismatched_keys.append(key) + + assert len(mismatched_keys) == 2 + assert sorted(mismatched_keys) == sorted(expected_mismatched_keys) From 28c672375bff52ed2815d6dd89f48de7176fb9fe Mon Sep 17 00:00:00 2001 From: qubvel Date: Wed, 26 Mar 2025 17:29:56 +0000 Subject: [PATCH 3/5] Update docs --- docs/save_load.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/save_load.rst b/docs/save_load.rst index e90e4eba..15434eb6 100644 --- a/docs/save_load.rst +++ b/docs/save_load.rst @@ -40,6 +40,14 @@ For example: # Alternatively, load the model directly from the Hugging Face Hub model = smp.from_pretrained('username/my-model') +Loading a pre-trained model with a different number of classes for fine-tuning: + +..
code:: python + + import segmentation_models_pytorch as smp + + model = smp.from_pretrained('username/my-model', classes=5, strict=False) + Saving model Metrics and Dataset Name ------------------------------------- From 5fd265b976d4779628b584c37873d342ad9c9b49 Mon Sep 17 00:00:00 2001 From: qubvel Date: Wed, 26 Mar 2025 17:31:22 +0000 Subject: [PATCH 4/5] (unrelated) update packages in example --- examples/segformer_inference_pretrained.ipynb | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/segformer_inference_pretrained.ipynb b/examples/segformer_inference_pretrained.ipynb index a0dda7d4..d2d195fd 100644 --- a/examples/segformer_inference_pretrained.ipynb +++ b/examples/segformer_inference_pretrained.ipynb @@ -13,9 +13,9 @@ "metadata": {}, "outputs": [], "source": [ - "# fix for HF hub download\n", - "# see PR https://github.com/albumentations-team/albumentations/pull/2171\n", - "!pip install -U git+https://github.com/qubvel/albumentations@patch-2" + "# make sure you have the latest version of the libraries\n", + "!pip install -U segmentation-models-pytorch\n", + "!pip install albumentations matplotlib requests pillow" ] }, { From d43109db3f085b037232364c6c9291800e7f6299 Mon Sep 17 00:00:00 2001 From: qubvel Date: Wed, 26 Mar 2025 17:42:43 +0000 Subject: [PATCH 5/5] Fix typo --- tests/test_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_base.py b/tests/test_base.py index 41f21625..1078c493 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -6,10 +6,10 @@ def test_from_pretrained_with_mismatched_keys(): - orginal_model = smp.Unet(classes=1) + original_model = smp.Unet(classes=1) with tempfile.TemporaryDirectory() as temp_dir: - orginal_model.save_pretrained(temp_dir) + original_model.save_pretrained(temp_dir) # we should catch warning here and check if there specific keys there with pytest.warns(UserWarning): @@ -18,7 +18,7 @@ def test_from_pretrained_with_mismatched_keys(): assert
restored_model.segmentation_head[0].out_channels == 2 # verify all the weight are the same expect mismatched ones - original_state_dict = orginal_model.state_dict() + original_state_dict = original_model.state_dict() restored_state_dict = restored_model.state_dict() expected_mismatched_keys = [