Skip to content

Commit 456871a

Browse files
qubveladamjstewart
andauthored
Fix torch compile, script, export (#1031)
* Move tests * Add compile test for encoders (to be optimized) * densnet * dpn * efficientnet * inceptionresnetv2 * inceptionv4 * mix-transformer * mobilenet * mobileone * resnet * senet * vgg * xception * Deprecate `timm-` encoders, remap to `tu-` most of them * Add tiny encoders and compile mark * Add conftest * Fix features * Add triggering compile tests on diff * Remove marks * Add test_compile stage to CI * Update requirements * Update makefile * Update get_stages * Fix weight loading for deprecate encoders * Fix weight loading for mobilenetv3 * Format * Add compile test for models * Add torch.export test * Disable export tests for dpn and inceptionv4 * Disable export for timm-eff-net * Huge fix for torch scripting (except Unet++ and UperNet) * Fix scripting * Add test for torch script * Add torch_script test to CI * Fix * Fix timm-effnet encoders * Make from_pretrained strict by default * Fix DeepLabV3 BC * Fix scripting for encoders * Refactor test do not skip * Fix encoders (mobilenet, inceptionv4) * Update encoders table * Fix export test * Fix docs * Update warning * Move pretrained settings * Add BC for timm- encoders * Fixing table * Update compile test * Change compile backend to eager * Update docs * Fixup * Fix batchnorm typo * Add depth validation * Update segmentation_models_pytorch/encoders/__init__.py Co-authored-by: Adam J. Stewart <[email protected]> * Style --------- Co-authored-by: Adam J. Stewart <[email protected]>
1 parent 93b19d3 commit 456871a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+2418
-2289
lines changed

.github/workflows/tests.yml

+48-3
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ jobs:
5151
run: uv pip list
5252

5353
- name: Test with PyTest
54-
run: uv run pytest -v -rsx -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml -k "not logits_match"
54+
run: uv run pytest -v -rsx -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml --non-marked-only
5555

5656
- name: Upload coverage reports to Codecov
5757
uses: codecov/codecov-action@v5
@@ -73,7 +73,52 @@ jobs:
7373
- name: Show installed packages
7474
run: uv pip list
7575
- name: Test with PyTest
76-
run: RUN_SLOW=1 uv run pytest -v -rsx -n 2 -k "logits_match"
76+
run: RUN_SLOW=1 uv run pytest -v -rsx -n 2 -m "logits_match"
77+
78+
test_torch_compile:
79+
runs-on: ubuntu-latest
80+
steps:
81+
- uses: actions/checkout@v4
82+
- name: Set up Python
83+
uses: astral-sh/setup-uv@v5
84+
with:
85+
python-version: "3.10"
86+
- name: Install dependencies
87+
run: uv pip install -r requirements/required.txt -r requirements/test.txt
88+
- name: Show installed packages
89+
run: uv pip list
90+
- name: Test with PyTest
91+
run: uv run pytest -v -rsx -n 2 -m "compile"
92+
93+
test_torch_export:
94+
runs-on: ubuntu-latest
95+
steps:
96+
- uses: actions/checkout@v4
97+
- name: Set up Python
98+
uses: astral-sh/setup-uv@v5
99+
with:
100+
python-version: "3.10"
101+
- name: Install dependencies
102+
run: uv pip install -r requirements/required.txt -r requirements/test.txt
103+
- name: Show installed packages
104+
run: uv pip list
105+
- name: Test with PyTest
106+
run: uv run pytest -v -rsx -n 2 -m "torch_export"
107+
108+
test_torch_script:
109+
runs-on: ubuntu-latest
110+
steps:
111+
- uses: actions/checkout@v4
112+
- name: Set up Python
113+
uses: astral-sh/setup-uv@v5
114+
with:
115+
python-version: "3.10"
116+
- name: Install dependencies
117+
run: uv pip install -r requirements/required.txt -r requirements/test.txt
118+
- name: Show installed packages
119+
run: uv pip list
120+
- name: Test with PyTest
121+
run: uv run pytest -v -rsx -n 2 -m "torch_script"
77122

78123
minimum:
79124
runs-on: ubuntu-latest
@@ -88,4 +133,4 @@ jobs:
88133
- name: Show installed packages
89134
run: uv pip list
90135
- name: Test with pytest
91-
run: uv run pytest -v -rsx -n 2 -k "not logits_match"
136+
run: uv run pytest -v -rsx -n 2 --non-marked-only

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ install_dev: .venv
77
.venv/bin/pip install -e ".[test]"
88

99
test: .venv
10-
.venv/bin/pytest -v -rsx -n 2 tests/ -k "not logits_match"
10+
.venv/bin/pytest -v -rsx -n 2 tests/ --non-marked-only
1111

1212
test_all: .venv
1313
RUN_SLOW=1 .venv/bin/pytest -v -rsx -n 2 tests/

docs/encoders.rst

+138-360
Large diffs are not rendered by default.

misc/generate_table.py

+21-9
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
1+
import os
12
import segmentation_models_pytorch as smp
23

4+
from tqdm import tqdm
5+
36
encoders = smp.encoders.encoders
47

58

69
WIDTH = 32
7-
COLUMNS = ["Encoder", "Weights", "Params, M"]
10+
COLUMNS = ["Encoder", "Pretrained weights", "Params, M", "Script", "Compile", "Export"]
11+
FILE = "encoders_table.md"
12+
13+
if os.path.exists(FILE):
14+
os.remove(FILE)
815

916

1017
def wrap_row(r):
@@ -16,18 +23,23 @@ def wrap_row(r):
1623
["-" * WIDTH] + [":" + "-" * (WIDTH - 2) + ":"] * (len(COLUMNS) - 1)
1724
)
1825

19-
print(wrap_row(header))
20-
print(wrap_row(separator))
26+
print(wrap_row(header), file=open(FILE, "a"))
27+
print(wrap_row(separator), file=open(FILE, "a"))
2128

22-
for encoder_name, encoder in encoders.items():
29+
for encoder_name, encoder in tqdm(encoders.items()):
2330
weights = "<br>".join(encoder["pretrained_settings"].keys())
24-
encoder_name = encoder_name.ljust(WIDTH, " ")
25-
weights = weights.ljust(WIDTH, " ")
2631

2732
model = encoder["encoder"](**encoder["params"], depth=5)
33+
34+
script = "✅" if model._is_torch_scriptable else "❌"
35+
compile = "✅" if model._is_torch_compilable else "❌"
36+
export = "✅" if model._is_torch_exportable else "❌"
37+
2838
params = sum(p.numel() for p in model.parameters())
2939
params = str(params // 1000000) + "M"
30-
params = params.ljust(WIDTH, " ")
3140

32-
row = "|".join([encoder_name, weights, params])
33-
print(wrap_row(row))
41+
row = [encoder_name, weights, params, script, compile, export]
42+
row = [str(r).ljust(WIDTH, " ") for r in row]
43+
row = "|".join(row)
44+
45+
print(wrap_row(row), file=open(FILE, "a"))

pyproject.toml

+5-11
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,13 @@ docs = [
3939
'sphinx-book-theme',
4040
]
4141
test = [
42+
'gitpython',
4243
'packaging',
4344
'pytest',
4445
'pytest-cov',
4546
'pytest-xdist',
4647
'ruff>=0.9',
48+
'setuptools',
4749
]
4850

4951
[project.urls]
@@ -61,18 +63,10 @@ include = ['segmentation_models_pytorch*']
6163

6264
[tool.pytest.ini_options]
6365
markers = [
64-
"deeplabv3",
65-
"deeplabv3plus",
66-
"fpn",
67-
"linknet",
68-
"manet",
69-
"pan",
70-
"psp",
71-
"segformer",
72-
"unet",
73-
"unetplusplus",
74-
"upernet",
7566
"logits_match",
67+
"compile",
68+
"torch_export",
69+
"torch_script",
7670
]
7771

7872
[tool.coverage.run]

requirements/test.txt

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
gitpython==3.1.44
12
packaging==24.2
23
pytest==8.3.4
34
pytest-xdist==3.6.1
45
pytest-cov==6.0.0
56
ruff==0.9.1
7+
setuptools==75.8.0

segmentation_models_pytorch/base/hub_mixin.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import torch
12
import json
23
from pathlib import Path
34
from typing import Optional, Union
@@ -114,12 +115,15 @@ def save_pretrained(
114115
return result
115116

116117
@property
118+
@torch.jit.unused
117119
def config(self) -> dict:
118120
return self._hub_mixin_config
119121

120122

121123
@wraps(PyTorchModelHubMixin.from_pretrained)
122-
def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs):
124+
def from_pretrained(
125+
pretrained_model_name_or_path: str, *args, strict: bool = True, **kwargs
126+
):
123127
config_path = Path(pretrained_model_name_or_path) / "config.json"
124128
if not config_path.exists():
125129
config_path = hf_hub_download(
@@ -135,7 +139,9 @@ def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs):
135139
import segmentation_models_pytorch as smp
136140

137141
model_class = getattr(smp, model_class_name)
138-
return model_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
142+
return model_class.from_pretrained(
143+
pretrained_model_name_or_path, *args, **kwargs, strict=strict
144+
)
139145

140146

141147
def supports_config_loading(func):

segmentation_models_pytorch/base/model.py

+39-4
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,19 @@
33

44
from . import initialization as init
55
from .hub_mixin import SMPHubMixin
6+
from .utils import is_torch_compiling
67

78
T = TypeVar("T", bound="SegmentationModel")
89

910

1011
class SegmentationModel(torch.nn.Module, SMPHubMixin):
1112
"""Base class for all segmentation models."""
1213

13-
# if model supports shape not divisible by 2 ^ n
14-
# set to False
14+
_is_torch_scriptable = True
15+
_is_torch_exportable = True
16+
_is_torch_compilable = True
17+
18+
# if model supports shape not divisible by 2 ^ n set to False
1519
requires_divisible_input_shape = True
1620

1721
# Fix type-hint for models, to avoid HubMixin signature
@@ -29,6 +33,9 @@ def check_input_shape(self, x):
2933
"""Check if the input shape is divisible by the output stride.
3034
If not, raise a RuntimeError.
3135
"""
36+
if not self.requires_divisible_input_shape:
37+
return
38+
3239
h, w = x.shape[-2:]
3340
output_stride = self.encoder.output_stride
3441
if h % output_stride != 0 or w % output_stride != 0:
@@ -50,11 +57,13 @@ def check_input_shape(self, x):
5057
def forward(self, x):
5158
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""
5259

53-
if not torch.jit.is_tracing() and self.requires_divisible_input_shape:
60+
if not (
61+
torch.jit.is_scripting() or torch.jit.is_tracing() or is_torch_compiling()
62+
):
5463
self.check_input_shape(x)
5564

5665
features = self.encoder(x)
57-
decoder_output = self.decoder(*features)
66+
decoder_output = self.decoder(features)
5867

5968
masks = self.segmentation_head(decoder_output)
6069

@@ -81,3 +90,29 @@ def predict(self, x):
8190
x = self.forward(x)
8291

8392
return x
93+
94+
def load_state_dict(self, state_dict, **kwargs):
95+
# for compatibility of weights for
96+
# timm- ported encoders with TimmUniversalEncoder
97+
from segmentation_models_pytorch.encoders import TimmUniversalEncoder
98+
99+
if not isinstance(self.encoder, TimmUniversalEncoder):
100+
return super().load_state_dict(state_dict, **kwargs)
101+
102+
patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"]
103+
104+
is_deprecated_encoder = any(
105+
self.encoder.name.startswith(pattern) for pattern in patterns
106+
)
107+
108+
if is_deprecated_encoder:
109+
keys = list(state_dict.keys())
110+
for key in keys:
111+
new_key = key
112+
if key.startswith("encoder.") and not key.startswith("encoder.model."):
113+
new_key = "encoder.model." + key.removeprefix("encoder.")
114+
if "gernet" in self.encoder.name:
115+
new_key = new_key.replace(".stages.", ".stages_")
116+
state_dict[new_key] = state_dict.pop(key)
117+
118+
return super().load_state_dict(state_dict, **kwargs)
+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import torch
2+
3+
4+
@torch.jit.unused
5+
def is_torch_compiling():
6+
try:
7+
return torch.compiler.is_compiling()
8+
except Exception:
9+
try:
10+
import torch._dynamo as dynamo # noqa: F401
11+
12+
return dynamo.is_compiling()
13+
except Exception:
14+
return False

segmentation_models_pytorch/decoders/deeplabv3/decoder.py

+22-18
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
"""
3232

3333
from collections.abc import Iterable, Sequence
34-
from typing import Literal
34+
from typing import Literal, List
3535

3636
import torch
3737
from torch import nn
@@ -40,7 +40,7 @@
4040
__all__ = ["DeepLabV3Decoder", "DeepLabV3PlusDecoder"]
4141

4242

43-
class DeepLabV3Decoder(nn.Sequential):
43+
class DeepLabV3Decoder(nn.Module):
4444
def __init__(
4545
self,
4646
in_channels: int,
@@ -49,21 +49,25 @@ def __init__(
4949
aspp_separable: bool,
5050
aspp_dropout: float,
5151
):
52-
super().__init__(
53-
ASPP(
54-
in_channels,
55-
out_channels,
56-
atrous_rates,
57-
separable=aspp_separable,
58-
dropout=aspp_dropout,
59-
),
60-
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
61-
nn.BatchNorm2d(out_channels),
62-
nn.ReLU(),
52+
super().__init__()
53+
self.aspp = ASPP(
54+
in_channels,
55+
out_channels,
56+
atrous_rates,
57+
separable=aspp_separable,
58+
dropout=aspp_dropout,
6359
)
60+
self.conv = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
61+
self.bn = nn.BatchNorm2d(out_channels)
62+
self.relu = nn.ReLU()
6463

65-
def forward(self, *features):
66-
return super().forward(features[-1])
64+
def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
65+
x = features[-1]
66+
x = self.aspp(x)
67+
x = self.conv(x)
68+
x = self.bn(x)
69+
x = self.relu(x)
70+
return x
6771

6872

6973
class DeepLabV3PlusDecoder(nn.Module):
@@ -124,7 +128,7 @@ def __init__(
124128
nn.ReLU(),
125129
)
126130

127-
def forward(self, *features):
131+
def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
128132
aspp_features = self.aspp(features[-1])
129133
aspp_features = self.up(aspp_features)
130134
high_res_features = self.block1(features[2])
@@ -174,7 +178,7 @@ def __init__(self, in_channels: int, out_channels: int):
174178
nn.ReLU(),
175179
)
176180

177-
def forward(self, x):
181+
def forward(self, x: torch.Tensor) -> torch.Tensor:
178182
size = x.shape[-2:]
179183
for mod in self:
180184
x = mod(x)
@@ -216,7 +220,7 @@ def __init__(
216220
nn.Dropout(dropout),
217221
)
218222

219-
def forward(self, x):
223+
def forward(self, x: torch.Tensor) -> torch.Tensor:
220224
res = []
221225
for conv in self.convs:
222226
res.append(conv(x))

0 commit comments

Comments
 (0)