
Commit e85836d

Added weight conversion script
1 parent: 71e2acb

File tree

6 files changed: +224 -277 lines changed

.gitignore

Lines changed: 5 additions & 1 deletion
@@ -75,6 +75,7 @@ target/
 
 # Jupyter Notebook
 .ipynb_checkpoints
+*ipynb*
 
 # pyenv
 .python-version
@@ -109,4 +110,7 @@ venv.bak/
 .mypy_cache/
 
 # ruff
-.ruff_cache/
+.ruff_cache/
+
+# model weight folder
+dpt_large-ade20k-b12dca68

segmentation_models_pytorch/decoders/dpt/decoder.py

Lines changed: 71 additions & 26 deletions
@@ -1,5 +1,7 @@
 import torch
 import torch.nn as nn
+from segmentation_models_pytorch.base.modules import Activation
+from typing import Optional
 
 
 def _get_feature_processing_out_channels(encoder_name: str) -> list[int]:
@@ -71,7 +73,7 @@ def forward(self, feature: torch.Tensor, cls_token: torch.Tensor):
         return feature
 
 
-class FeatureProcessBlock(nn.Module):
+class ReassembleBlock(nn.Module):
     """
     Processes the features such that they have progressively increasing embedding size and progressively decreasing
     spatial dimension
@@ -107,7 +109,11 @@ def __init__(
         )
 
         self.project_to_feature_dim = nn.Conv2d(
-            in_channels=out_channel, out_channels=feature_dim, kernel_size=3, padding=1
+            in_channels=out_channel,
+            out_channels=feature_dim,
+            kernel_size=3,
+            padding=1,
+            bias=False,
         )
 
     def forward(self, x: torch.Tensor):
@@ -121,29 +127,34 @@ def forward(self, x: torch.Tensor):
 class ResidualConvBlock(nn.Module):
     def __init__(self, feature_dim: int):
         super().__init__()
-        self.conv_block = nn.Sequential(
-            nn.ReLU(),
-            nn.Conv2d(
-                in_channels=feature_dim,
-                out_channels=feature_dim,
-                kernel_size=3,
-                padding=1,
-                bias=False,
-            ),
-            nn.BatchNorm2d(num_features=feature_dim),
-            nn.ReLU(),
-            nn.Conv2d(
-                in_channels=feature_dim,
-                out_channels=feature_dim,
-                kernel_size=3,
-                padding=1,
-                bias=False,
-            ),
-            nn.BatchNorm2d(num_features=feature_dim)
+
+        self.conv_1 = nn.Conv2d(
+            in_channels=feature_dim,
+            out_channels=feature_dim,
+            kernel_size=3,
+            padding=1,
+            bias=False,
         )
+        self.batch_norm_1 = nn.BatchNorm2d(num_features=feature_dim)
+        self.conv_2 = nn.Conv2d(
+            in_channels=feature_dim,
+            out_channels=feature_dim,
+            kernel_size=3,
+            padding=1,
+            bias=False,
+        )
+        self.batch_norm_2 = nn.BatchNorm2d(num_features=feature_dim)
+        self.activation = nn.ReLU()
 
     def forward(self, x: torch.Tensor):
-        return x + self.conv_block(x)
+        activated_x_1 = self.activation(x)
+        conv_1_out = self.conv_1(activated_x_1)
+        batch_norm_1_out = self.batch_norm_1(conv_1_out)
+        activated_x_2 = self.activation(batch_norm_1_out)
+        conv_2_out = self.conv_2(activated_x_2)
+        batch_norm_2_out = self.batch_norm_2(conv_2_out)
+
+        return x + batch_norm_2_out
 
 
 class FusionBlock(nn.Module):
@@ -172,7 +183,6 @@ def forward(self, feature: torch.Tensor, preceding_layer_feature: torch.Tensor):
             feature, scale_factor=2, align_corners=True, mode="bilinear"
         )
         feature = self.project(feature)
-        feature = self.activation(feature)
 
         return feature
 
@@ -230,9 +240,9 @@ def __init__(
             :encoder_depth
         ]
 
-        self.feature_processing_blocks = nn.ModuleList(
+        self.reassemble_blocks = nn.ModuleList(
             [
-                FeatureProcessBlock(
+                ReassembleBlock(
                     transformer_embed_dim, feature_dim, out_channel, upsample_factor
                 )
                 for upsample_factor, out_channel in zip(
@@ -253,7 +263,7 @@ def forward(
         # Process the encoder features to scale of [1/32,1/16,1/8,1/4]
         for index, (feature, cls_token) in enumerate(zip(features, cls_tokens)):
            readout_feature = self.readout_blocks[index](feature, cls_token)
-            processed_feature = self.feature_processing_blocks[index](readout_feature)
+            processed_feature = self.reassemble_blocks[index](readout_feature)
             processed_features.append(processed_feature)
 
         preceding_layer_feature = None
@@ -265,3 +275,38 @@
             preceding_layer_feature = out
 
         return out
+
+
+class DPTSegmentationHead(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        activation: Optional[str] = None,
+        kernel_size: int = 3,
+        upsampling: float = 2.0,
+    ):
+        super().__init__()
+
+        self.head = nn.Sequential(
+            nn.Conv2d(
+                in_channels, in_channels, kernel_size=kernel_size, padding=1, bias=False
+            ),
+            nn.BatchNorm2d(in_channels),
+            nn.ReLU(True),
+            nn.Dropout(0.1, False),
+            nn.Conv2d(in_channels, out_channels, kernel_size=1),
+        )
+        self.activation = Activation(activation)
+        self.upsampling_factor = upsampling
+
+    def forward(self, x):
+        head_output = self.head(x)
+        resized_output = nn.functional.interpolate(
+            head_output,
+            scale_factor=self.upsampling_factor,
+            mode="bilinear",
+            align_corners=True,
+        )
+        activation_output = self.activation(resized_output)
+        return activation_output
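The ResidualConvBlock rewrite unrolls the former nn.Sequential into named submodules (conv_1, batch_norm_1, conv_2, batch_norm_2) arranged as a pre-activation residual unit, which gives every parameter a stable state-dict key that the conversion script in this commit can address directly. A minimal shape sanity check, assuming both classes are imported from the decoder module exactly as defined in this diff:

import torch

from segmentation_models_pytorch.decoders.dpt.decoder import (
    DPTSegmentationHead,
    ResidualConvBlock,
)

x = torch.randn(1, 256, 32, 32)  # (batch, feature_dim, height, width)

# 3x3 convs with padding=1 preserve the spatial size, so the skip connection lines up.
block = ResidualConvBlock(feature_dim=256)
assert block(x).shape == x.shape

# 1x1 classifier conv maps to out_channels, then 2x bilinear upsampling doubles H and W.
head = DPTSegmentationHead(in_channels=256, out_channels=150, upsampling=2.0)
assert head(x).shape == (1, 150, 64, 64)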

segmentation_models_pytorch/decoders/dpt/model.py

Lines changed: 5 additions & 3 deletions
@@ -9,7 +9,7 @@
 from segmentation_models_pytorch.encoders import get_encoder
 from segmentation_models_pytorch.base.utils import is_torch_compiling
 from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
-from .decoder import DPTDecoder
+from .decoder import DPTDecoder, DPTSegmentationHead
 
 
 class DPT(SegmentationModel):
@@ -75,6 +75,7 @@ def __init__(
         classes: int = 1,
         activation: Optional[Union[str, Callable]] = None,
         aux_params: Optional[dict] = None,
+        output_stride: Optional[int] = None,
         **kwargs: dict[str, Any],
     ):
         super().__init__()
@@ -86,6 +87,7 @@ def __init__(
             weights=encoder_weights,
             use_vit_encoder=True,
             allow_downsampling=False,
+            output_stride=output_stride,
             allow_output_stride_not_power_of_two=False,
             **kwargs,
         )
@@ -103,11 +105,11 @@ def __init__(
             cls_token_supported=self.cls_token_supported,
         )
 
-        self.segmentation_head = SegmentationHead(
+        self.segmentation_head = DPTSegmentationHead(
             in_channels=feature_dim,
             out_channels=classes,
             activation=activation,
-            kernel_size=1,
+            kernel_size=3,
             upsampling=2,
         )

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+import segmentation_models_pytorch as smp
+import torch
+import huggingface_hub
+
+MODEL_WEIGHTS_PATH = r"C:\Users\vedan\Downloads\dpt_large-ade20k-b12dca68.pt"
+HF_HUB_PATH = "vedantdalimkar/DPT"
+
+if __name__ == "__main__":
+    smp_model = smp.DPT(encoder_name="tu-vit_large_patch16_384", classes=150)
+    dpt_model_dict = torch.load(MODEL_WEIGHTS_PATH)
+
+    for layer_index in range(0, 4):
+        for param in [
+            "running_mean",
+            "running_var",
+            "num_batches_tracked",
+            "weight",
+            "bias",
+        ]:
+            for block_index in [1, 2]:
+                for bn_index in [1, 2]:
+                    # Assign the weights of the 4th fusion layer of the original model
+                    # to the 1st layer of the SMP DPT model, the 3rd fusion layer to
+                    # the 2nd layer, and so on,
+
+                    # because the order in which the fusion layers are called is
+                    # reversed in the original DPT implementation.
+
+                    dpt_model_dict[
+                        f"decoder.fusion_blocks.{layer_index}.residual_conv_block{block_index}.batch_norm_{bn_index}.{param}"
+                    ] = dpt_model_dict.pop(
+                        f"scratch.refinenet{4 - layer_index}.resConfUnit{block_index}.bn{bn_index}.{param}"
+                    )
+
+            if param in ["weight", "bias"]:
+                if param == "weight":
+                    for block_index in [1, 2]:
+                        for conv_index in [1, 2]:
+                            dpt_model_dict[
+                                f"decoder.fusion_blocks.{layer_index}.residual_conv_block{block_index}.conv_{conv_index}.{param}"
+                            ] = dpt_model_dict.pop(
+                                f"scratch.refinenet{4 - layer_index}.resConfUnit{block_index}.conv{conv_index}.{param}"
+                            )
+
+                    dpt_model_dict[
+                        f"decoder.reassemble_blocks.{layer_index}.project_to_feature_dim.{param}"
+                    ] = dpt_model_dict.pop(f"scratch.layer{layer_index + 1}_rn.{param}")
+
+                dpt_model_dict[
+                    f"decoder.fusion_blocks.{layer_index}.project.{param}"
+                ] = dpt_model_dict.pop(
+                    f"scratch.refinenet{4 - layer_index}.out_conv.{param}"
+                )
+
+                dpt_model_dict[
+                    f"decoder.readout_blocks.{layer_index}.project.0.{param}"
+                ] = dpt_model_dict.pop(
+                    f"pretrained.act_postprocess{layer_index + 1}.0.project.0.{param}"
+                )
+
+                dpt_model_dict[
+                    f"decoder.reassemble_blocks.{layer_index}.project_to_out_channel.{param}"
+                ] = dpt_model_dict.pop(
+                    f"pretrained.act_postprocess{layer_index + 1}.3.{param}"
+                )
+
+                if layer_index != 2:
+                    dpt_model_dict[
+                        f"decoder.reassemble_blocks.{layer_index}.upsample.{param}"
+                    ] = dpt_model_dict.pop(
+                        f"pretrained.act_postprocess{layer_index + 1}.4.{param}"
+                    )
+
+    # Changing state dict keys for the segmentation head
+    dpt_model_dict = {
+        (
+            "segmentation_head.head" + name[len("scratch.output_conv") :]
+            if name.startswith("scratch.output_conv")
+            else name
+        ): parameter
+        for name, parameter in dpt_model_dict.items()
+    }
+
+    # Changing state dict keys for the encoder layers
+    dpt_model_dict = {
+        (
+            "encoder.model" + name[len("pretrained.model") :]
+            if name.startswith("pretrained.model")
+            else name
+        ): parameter
+        for name, parameter in dpt_model_dict.items()
+    }
+
+    # Removing key-value pairs associated with the auxiliary head
+    dpt_model_dict = {
+        name: parameter
+        for name, parameter in dpt_model_dict.items()
+        if not name.startswith("auxlayer")
+    }
+
+    smp_model.load_state_dict(dpt_model_dict, strict=True)
+
+    model_name = MODEL_WEIGHTS_PATH.split("\\")[-1].replace(".pt", "")
+
+    smp_model.save_pretrained(model_name)
+
+    repo_id = HF_HUB_PATH
+    api = huggingface_hub.HfApi()
+    api.create_repo(repo_id=repo_id, repo_type="model")
+    api.upload_folder(folder_path=model_name, repo_id=repo_id)
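Once the folder is uploaded, the converted checkpoint should be loadable back through the hub mixin referenced in model.py; a sketch assuming the standard from_pretrained entry point provided by that mixin and the repo id used above:

import segmentation_models_pytorch as smp

# Pull the converted DPT weights back from the Hugging Face Hub.
model = smp.DPT.from_pretrained("vedantdalimkar/DPT")
model.eval()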

segmentation_models_pytorch/encoders/timm_vit.py

Lines changed: 5 additions & 40 deletions
@@ -4,6 +4,8 @@
 import torch
 import torch.nn as nn
 
+from .timm_universal import _merge_kwargs_no_duplicates
+
 
 class TimmViTEncoder(nn.Module):
     """
@@ -26,6 +28,7 @@ def __init__(
         in_channels: int = 3,
         depth: int = 4,
         output_indices: Optional[Union[list[int], int]] = None,
+        output_stride: Optional[int] = None,
         **kwargs: dict[str, Any],
     ):
         """
@@ -49,7 +52,6 @@ def __init__(
         super().__init__()
         self.name = name
 
-        output_stride = kwargs.pop("output_stride", None)
         if output_stride is not None:
             raise ValueError("Dilated mode not supported, set output stride to None")
 
@@ -160,6 +162,8 @@ def forward(self, x: torch.Tensor) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
 
         cls_tokens = [None] * len(self.out_indices)
 
+        # If there are multiple prefix tokens, discard the register tokens if they are present and
+        # return the CLS token if it exists. Only patch features are retrieved if the CLS token is not supported.
         if self.num_prefix_tokens > 0:
             features, prefix_tokens = zip(*intermediate_outputs)
             if self.cls_token_supported:
@@ -205,42 +209,3 @@ def output_stride(self) -> int:
             int: The effective output stride.
         """
         return self._output_stride
-
-    def load_state_dict(self, state_dict, **kwargs):
-        # for compatibility of weights for
-        # timm-ported encoders with TimmUniversalEncoder
-        patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"]
-
-        is_deprecated_encoder = any(
-            self.name.startswith(pattern) for pattern in patterns
-        )
-
-        if is_deprecated_encoder:
-            keys = list(state_dict.keys())
-            for key in keys:
-                new_key = key
-                if not key.startswith("model."):
-                    new_key = "model." + key
-                if "gernet" in self.name:
-                    new_key = new_key.replace(".stages.", ".stages_")
-                state_dict[new_key] = state_dict.pop(key)
-
-        return super().load_state_dict(state_dict, **kwargs)
-
-
-def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
-    """
-    Merge two dictionaries, ensuring no duplicate keys exist.
-
-    Args:
-        a (dict): Base dictionary.
-        b (dict): Additional parameters to merge.
-
-    Returns:
-        dict: A merged dictionary.
-    """
-    duplicates = a.keys() & b.keys()
-    if duplicates:
-        raise ValueError(f"'{duplicates}' already specified internally")
-
-    return a | b
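The removed _merge_kwargs_no_duplicates now lives in timm_universal (see the new import at the top of this file); its behavior, shown verbatim in the deleted lines above, is unchanged. A small sketch of the guard it provides, assuming the import path used in this diff:

from segmentation_models_pytorch.encoders.timm_universal import (
    _merge_kwargs_no_duplicates,
)

# Disjoint keys merge cleanly.
merged = _merge_kwargs_no_duplicates({"img_size": 384}, {"in_chans": 3})
assert merged == {"img_size": 384, "in_chans": 3}

# A duplicate key raises, so callers cannot silently override internal kwargs.
try:
    _merge_kwargs_no_duplicates({"img_size": 384}, {"img_size": 512})
except ValueError as exc:
    print(exc)  # "{'img_size'}" already specified internally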
