diff --git a/docs/save_load.rst b/docs/save_load.rst index e90e4eba..15434eb6 100644 --- a/docs/save_load.rst +++ b/docs/save_load.rst @@ -40,6 +40,14 @@ For example: # Alternatively, load the model directly from the Hugging Face Hub model = smp.from_pretrained('username/my-model') +Loading pre-trained model with different number of classes for fine-tuning: + +.. code:: python + + import segmentation_models_pytorch as smp + + model = smp.from_pretrained('', classes=5, strict=False) + Saving model Metrics and Dataset Name ------------------------------------- diff --git a/examples/segformer_inference_pretrained.ipynb b/examples/segformer_inference_pretrained.ipynb index a0dda7d4..d2d195fd 100644 --- a/examples/segformer_inference_pretrained.ipynb +++ b/examples/segformer_inference_pretrained.ipynb @@ -13,9 +13,9 @@ "metadata": {}, "outputs": [], "source": [ - "# fix for HF hub download\n", - "# see PR https://github.com/albumentations-team/albumentations/pull/2171\n", - "!pip install -U git+https://github.com/qubvel/albumentations@patch-2" + "# make sure you have the latest version of the libraries\n", + "!pip install -U segmentation-models-pytorch\n", + "!pip install albumentations matplotlib requests pillow" ] }, { diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py index 29820840..67f9422a 100644 --- a/segmentation_models_pytorch/base/model.py +++ b/segmentation_models_pytorch/base/model.py @@ -1,6 +1,7 @@ import torch -from typing import TypeVar, Type +import warnings +from typing import TypeVar, Type from . import initialization as init from .hub_mixin import SMPHubMixin from .utils import is_torch_compiling @@ -96,23 +97,45 @@ def load_state_dict(self, state_dict, **kwargs): # timm- ported encoders with TimmUniversalEncoder from segmentation_models_pytorch.encoders import TimmUniversalEncoder - if not isinstance(self.encoder, TimmUniversalEncoder): - return super().load_state_dict(state_dict, **kwargs) - - patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"] - - is_deprecated_encoder = any( - self.encoder.name.startswith(pattern) for pattern in patterns - ) - - if is_deprecated_encoder: - keys = list(state_dict.keys()) - for key in keys: - new_key = key - if key.startswith("encoder.") and not key.startswith("encoder.model."): - new_key = "encoder.model." + key.removeprefix("encoder.") - if "gernet" in self.encoder.name: - new_key = new_key.replace(".stages.", ".stages_") - state_dict[new_key] = state_dict.pop(key) + if isinstance(self.encoder, TimmUniversalEncoder): + patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"] + is_deprecated_encoder = any( + self.encoder.name.startswith(pattern) for pattern in patterns + ) + if is_deprecated_encoder: + keys = list(state_dict.keys()) + for key in keys: + new_key = key + if key.startswith("encoder.") and not key.startswith( + "encoder.model." + ): + new_key = "encoder.model." + key.removeprefix("encoder.") + if "gernet" in self.encoder.name: + new_key = new_key.replace(".stages.", ".stages_") + state_dict[new_key] = state_dict.pop(key) + + # To be able to load weight with mismatched sizes + # We are going to filter mismatched sizes as well if strict=False + strict = kwargs.get("strict", True) + if not strict: + mismatched_keys = [] + model_state_dict = self.state_dict() + common_keys = set(model_state_dict.keys()) & set(state_dict.keys()) + for key in common_keys: + if model_state_dict[key].shape != state_dict[key].shape: + mismatched_keys.append( + (key, model_state_dict[key].shape, state_dict[key].shape) + ) + state_dict.pop(key) + + if mismatched_keys: + str_keys = "\n".join( + [ + f" - {key}: {s} (weights) -> {m} (model)" + for key, m, s in mismatched_keys + ] + ) + text = f"\n\n !!!!!! Mismatched keys !!!!!!\n\nYou should TRAIN the model to use it:\n{str_keys}\n" + warnings.warn(text, stacklevel=-1) return super().load_state_dict(state_dict, **kwargs) diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 00000000..1078c493 --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,36 @@ +import torch +import tempfile +import segmentation_models_pytorch as smp + +import pytest + + +def test_from_pretrained_with_mismatched_keys(): + original_model = smp.Unet(classes=1) + + with tempfile.TemporaryDirectory() as temp_dir: + original_model.save_pretrained(temp_dir) + + # we should catch warning here and check if there specific keys there + with pytest.warns(UserWarning): + restored_model = smp.from_pretrained(temp_dir, classes=2, strict=False) + + assert restored_model.segmentation_head[0].out_channels == 2 + + # verify all the weight are the same expect mismatched ones + original_state_dict = original_model.state_dict() + restored_state_dict = restored_model.state_dict() + + expected_mismatched_keys = [ + "segmentation_head.0.weight", + "segmentation_head.0.bias", + ] + mismatched_keys = [] + for key in original_state_dict: + if key not in expected_mismatched_keys: + assert torch.allclose(original_state_dict[key], restored_state_dict[key]) + else: + mismatched_keys.append(key) + + assert len(mismatched_keys) == 2 + assert sorted(mismatched_keys) == sorted(expected_mismatched_keys)