Adding DPT #1079

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged: 45 commits, Apr 7, 2025
Changes from 9 commits
Commits (45):
78ba0e8
Initial timm vit encoder commit
vedantdalimkar Feb 28, 2025
2c38de6
Add DPT model and update logic for TimmViTEncoder class
vedantdalimkar Mar 2, 2025
5599409
Removed redundant documentation
vedantdalimkar Mar 2, 2025
c47bdfb
Added initial test and some minor code modifications
vedantdalimkar Mar 5, 2025
71e2acb
Code refactor
vedantdalimkar Mar 8, 2025
e85836d
Added weight conversion script
vedantdalimkar Mar 22, 2025
35cb060
Moved conversion script to appropriate location
vedantdalimkar Mar 22, 2025
aa84f4e
Added logic in timm table generation for adding ViT encoders for DPT
Mar 22, 2025
67c4a75
Ruff formatting
vedantdalimkar Mar 22, 2025
85f22fb
Code revision
vedantdalimkar Mar 26, 2025
ef48032
Remove unnecessary comment
vedantdalimkar Mar 27, 2025
28204ad
Simplify ViT encoder
qubvel Apr 5, 2025
1b9a6f6
Refactor ProjectionReadout
qubvel Apr 5, 2025
334cfbb
Refactor modeling DPT
qubvel Apr 6, 2025
7e1ef3b
Support more encoders
qubvel Apr 6, 2025
d65c0f7
Refactor a bit conversion, added validation
qubvel Apr 6, 2025
0a62fe0
Fixup
qubvel Apr 6, 2025
e3238ae
Split forward for timm_vit
qubvel Apr 6, 2025
df4d087
Rename readout, remove feature_dim
qubvel Apr 6, 2025
8bcb0ed
refactor + add transform
qubvel Apr 6, 2025
6ba6746
Fixup
qubvel Apr 6, 2025
8fd8c77
Refine docs a bit
qubvel Apr 6, 2025
9bf1fd2
Refine docs
qubvel Apr 6, 2025
0e9170f
Refine model size a bit and docs
qubvel Apr 6, 2025
a0aa5a8
Add to docs
qubvel Apr 6, 2025
6cfd3be
Add note
qubvel Apr 6, 2025
d4b162d
Remove txt
qubvel Apr 6, 2025
5fe80a5
Fix doc
qubvel Apr 6, 2025
0a14972
Fix docstring
qubvel Apr 6, 2025
5b28978
Fixing list in activation
qubvel Apr 6, 2025
0ed621c
Fixing list
qubvel Apr 6, 2025
6207310
Fixing list
qubvel Apr 6, 2025
19eeebe
Fixup, fix type hint
qubvel Apr 6, 2025
f2e3f89
Merge branch 'main' into pr/vedantdalimkar/1079
qubvel Apr 6, 2025
1257c4b
Add to README
qubvel Apr 6, 2025
21a164a
Add example
qubvel Apr 6, 2025
8d3ed4f
Add decoder_readout according to initial impl
qubvel Apr 7, 2025
4eb6ec3
Tests update
vedantdalimkar Apr 7, 2025
165b9c0
Fix encoder tests
qubvel Apr 7, 2025
5603707
Fix DPT tests
qubvel Apr 7, 2025
9518964
Refactor a bit
qubvel Apr 7, 2025
38cb944
Tests
qubvel Apr 7, 2025
17d3328
Update gen test models
qubvel Apr 7, 2025
83b9655
Revert gitignore
qubvel Apr 7, 2025
343fbe0
Fix test
qubvel Apr 7, 2025
6 changes: 5 additions & 1 deletion .gitignore
@@ -75,6 +75,7 @@ target/

# Jupyter Notebook
.ipynb_checkpoints
*ipynb*

# pyenv
.python-version
@@ -109,4 +110,7 @@ venv.bak/
.mypy_cache/

# ruff
.ruff_cache/

# model weight folder
dpt_large-ade20k-b12dca68

Collaborator:
please use .git/info/exclude to ignore it locally

Suggested change (delete these lines):
# model weight folder
dpt_large-ade20k-b12dca68
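A hedged sketch of that suggestion, in Python like the rest of this PR (the exclude path is standard git; the folder name comes from the diff above):

from pathlib import Path

# Ignore the downloaded weights folder only in this local clone:
# .git/info/exclude works like .gitignore but is never committed.
exclude_file = Path(".git/info/exclude")
with exclude_file.open("a") as f:
    f.write("dpt_large-ade20k-b12dca68\n")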
2 changes: 2 additions & 0 deletions encoders_table.md
Collaborator:
We should remove this file

@@ -0,0 +1,2 @@
|Encoder |Pretrained weights |Params, M |Script |Compile |Export |
|--------------------------------|:------------------------------:|:------------------------------:|:------------------------------:|:------------------------------:|:------------------------------:|
59 changes: 51 additions & 8 deletions misc/generate_table_timm.py
@@ -17,30 +17,68 @@ def has_dilation_support(name):
return False


def valid_vit_encoder_for_dpt(name):
    if "vit" not in name:
        return False
    encoder = timm.create_model(name)
    feature_info = encoder.feature_info
    feature_info_obj = timm.models.FeatureInfo(
        feature_info=feature_info, out_indices=[0, 1, 2, 3]
    )
    reduction_scales = list(feature_info_obj.reduction())

    # DPT needs a ViT-style backbone whose feature maps all share a single
    # reduction scale (no feature pyramid).
    if len(set(reduction_scales)) > 1:
        return False

    # The shared output stride must be a power of two; a power of two has
    # exactly one set bit in its binary representation.
    output_stride = reduction_scales[0]
    if bin(output_stride).count("1") != 1:
        return False

    return True


def make_table(data):
    names = data.keys()
    max_len1 = max([len(x) for x in names]) + 2
    max_len2 = len("support dilation") + 2
    max_len3 = len("Supported for DPT") + 2

-   l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+\n"
-   l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+\n"
    l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+" + "-" * max_len3 + "+\n"
    l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+" + "-" * max_len3 + "+\n"
    top = (
        "| "
        + "Encoder name".ljust(max_len1 - 2)
        + " | "
        + "Support dilation".center(max_len2 - 2)
        + " | "
        + "Supported for DPT".center(max_len3 - 2)
        + " |\n"
    )

    table = l1 + top + l2

    for k in sorted(data.keys()):
-       support = (
-           "✅".center(max_len2 - 3)
-           if data[k]["has_dilation"]
-           else " ".center(max_len2 - 2)
-       )
        if "has_dilation" in data[k] and data[k]["has_dilation"]:
            support = "✅".center(max_len2 - 3)
        else:
            support = " ".center(max_len2 - 2)

        if "supported_only_for_dpt" in data[k]:
            supported_for_dpt = "✅".center(max_len3 - 3)
        else:
            supported_for_dpt = " ".center(max_len3 - 2)

        table += (
            "| "
            + k.ljust(max_len1 - 2)
            + " | "
            + support
            + " | "
            + supported_for_dpt
            + " |\n"
        )
-       table += "| " + k.ljust(max_len1 - 2) + " | " + support + " |\n"
        table += l1

    return table
@@ -55,8 +93,13 @@ def make_table(data):
        check_features_and_reduction(name)
        has_dilation = has_dilation_support(name)
        supported_models[name] = dict(has_dilation=has_dilation)

    except Exception:
-       continue
        try:
            if valid_vit_encoder_for_dpt(name):
                supported_models[name] = dict(supported_only_for_dpt=True)
        except Exception:
            continue
Comment on lines +96 to +102
Collaborator:
Should we check only if we got an exception here?
Would it be better to make two independent checks?

Contributor Author:
If you check the behaviour of the functions check_features_and_reduction and valid_vit_encoder_for_dpt, their outputs are mutually exclusive. In more detail:

1. check_features_and_reduction returns True only when the reduction scales of a model are equal to [2, 4, 8, 16, 32], whereas
2. valid_vit_encoder_for_dpt returns False if the encoder has multiple reduction scales.

In short, a model which satisfies the conditions of check_features_and_reduction will never satisfy the conditions of valid_vit_encoder_for_dpt, and vice versa.
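A small sketch may make this concrete (illustrative values only, not code from this PR):

# Reduction scales as reported by timm's FeatureInfo for two kinds of encoders.
cnn_reductions = [2, 4, 8, 16, 32]  # the pattern check_features_and_reduction expects
vit_reductions = [16, 16, 16, 16]   # uniform stride, typical for a ViT

# valid_vit_encoder_for_dpt rejects any model with more than one distinct
# reduction scale, so an encoder passing check_features_and_reduction fails it.
assert len(set(cnn_reductions)) > 1

# A uniform-stride ViT passes, provided the stride is a power of two,
# which is what bin(stride).count("1") == 1 tests.
assert len(set(vit_reductions)) == 1
assert bin(vit_reductions[0]).count("1") == 1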

Collaborator:
Ok, I suppose this code should be updated as well, because as far as I remember [4, 8, 16, 32] and [1, 2, 4, 8, 16, 32] reductions are also supported

Contributor Author:

> Ok, I suppose this code should be updated as well, because as far as I remember [4, 8, 16, 32] and [1, 2, 4, 8, 16, 32] reductions are also supported

Should I update this as well or will you do it from your end?
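For context, a hedged sketch of what such an update might look like; the extra reduction patterns come from the comment above, and the helper name is hypothetical rather than code from this PR:

# Hypothetical helper: accept the additional reduction patterns mentioned
# above alongside the original [2, 4, 8, 16, 32].
SUPPORTED_REDUCTIONS = (
    [2, 4, 8, 16, 32],
    [4, 8, 16, 32],
    [1, 2, 4, 8, 16, 32],
)

def has_supported_reduction(reduction_scales):
    return sorted(reduction_scales) in SUPPORTED_REDUCTIONS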


table = make_table(supported_models)
print(table)
109 changes: 109 additions & 0 deletions scripts/models-conversions/dpt-original-to-smp.py
@@ -0,0 +1,109 @@
import segmentation_models_pytorch as smp
import torch
import huggingface_hub

MODEL_WEIGHTS_PATH = r"C:\Users\vedan\Downloads\dpt_large-ade20k-b12dca68.pt"
HF_HUB_PATH = "vedantdalimkar/DPT"

if __name__ == "__main__":
    smp_model = smp.DPT(encoder_name="tu-vit_large_patch16_384", classes=150)
    dpt_model_dict = torch.load(MODEL_WEIGHTS_PATH)

    for layer_index in range(0, 4):
        for param in [
            "running_mean",
            "running_var",
            "num_batches_tracked",
            "weight",
            "bias",
        ]:
            for block_index in [1, 2]:
                for bn_index in [1, 2]:
                    # Assign the weights of the 4th fusion layer of the original
                    # model to the 1st layer of the SMP DPT model, the weights of
                    # the 3rd fusion layer to the 2nd layer, and so on, because
                    # the fusion layers are called in reverse order in the
                    # original DPT implementation.
                    dpt_model_dict[
                        f"decoder.fusion_blocks.{layer_index}.residual_conv_block{block_index}.batch_norm_{bn_index}.{param}"
                    ] = dpt_model_dict.pop(
                        f"scratch.refinenet{4 - layer_index}.resConfUnit{block_index}.bn{bn_index}.{param}"
                    )

            if param in ["weight", "bias"]:
                if param == "weight":
                    for block_index in [1, 2]:
                        for conv_index in [1, 2]:
                            dpt_model_dict[
                                f"decoder.fusion_blocks.{layer_index}.residual_conv_block{block_index}.conv_{conv_index}.{param}"
                            ] = dpt_model_dict.pop(
                                f"scratch.refinenet{4 - layer_index}.resConfUnit{block_index}.conv{conv_index}.{param}"
                            )

                    dpt_model_dict[
                        f"decoder.reassemble_blocks.{layer_index}.project_to_feature_dim.{param}"
                    ] = dpt_model_dict.pop(f"scratch.layer{layer_index + 1}_rn.{param}")

                dpt_model_dict[
                    f"decoder.fusion_blocks.{layer_index}.project.{param}"
                ] = dpt_model_dict.pop(
                    f"scratch.refinenet{4 - layer_index}.out_conv.{param}"
                )

                dpt_model_dict[
                    f"decoder.readout_blocks.{layer_index}.project.0.{param}"
                ] = dpt_model_dict.pop(
                    f"pretrained.act_postprocess{layer_index + 1}.0.project.0.{param}"
                )

                dpt_model_dict[
                    f"decoder.reassemble_blocks.{layer_index}.project_to_out_channel.{param}"
                ] = dpt_model_dict.pop(
                    f"pretrained.act_postprocess{layer_index + 1}.3.{param}"
                )

                if layer_index != 2:
                    dpt_model_dict[
                        f"decoder.reassemble_blocks.{layer_index}.upsample.{param}"
                    ] = dpt_model_dict.pop(
                        f"pretrained.act_postprocess{layer_index + 1}.4.{param}"
                    )

    # Changing state dict keys for the segmentation head
    dpt_model_dict = {
        (
            "segmentation_head.head" + name[len("scratch.output_conv") :]
            if name.startswith("scratch.output_conv")
            else name
        ): parameter
        for name, parameter in dpt_model_dict.items()
    }

    # Changing state dict keys for the encoder layers
    dpt_model_dict = {
        (
            "encoder.model" + name[len("pretrained.model") :]
            if name.startswith("pretrained.model")
            else name
        ): parameter
        for name, parameter in dpt_model_dict.items()
    }

    # Removing key-value pairs associated with the auxiliary head
    dpt_model_dict = {
        name: parameter
        for name, parameter in dpt_model_dict.items()
        if not name.startswith("auxlayer")
    }

    smp_model.load_state_dict(dpt_model_dict, strict=True)

    model_name = MODEL_WEIGHTS_PATH.split("\\")[-1].replace(".pt", "")

    smp_model.save_pretrained(model_name)

    repo_id = HF_HUB_PATH
    api = huggingface_hub.HfApi()
    api.create_repo(repo_id=repo_id, repo_type="model")
    api.upload_folder(folder_path=model_name, repo_id=repo_id)
Collaborator:
We can use smp_model.push_to_hub(...) instead
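A sketch of that suggestion, assuming the hub mixin behind save_pretrained also exposes push_to_hub as the comment implies; it would replace the three HfApi calls at the end of the script:

# Create the repo (if needed) and upload the saved model in one call.
smp_model.push_to_hub(HF_HUB_PATH)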

3 changes: 3 additions & 0 deletions segmentation_models_pytorch/__init__.py
@@ -14,6 +14,7 @@
from .decoders.pan import PAN
from .decoders.upernet import UPerNet
from .decoders.segformer import Segformer
from .decoders.dpt import DPT
from .base.hub_mixin import from_pretrained

from .__version__ import __version__
@@ -34,6 +35,7 @@
    PAN,
    UPerNet,
    Segformer,
    DPT,
]
MODEL_ARCHITECTURES_MAPPING = {a.__name__.lower(): a for a in _MODEL_ARCHITECTURES}

@@ -84,6 +86,7 @@ def create_model(
"PAN",
"UPerNet",
"Segformer",
"DPT",
"from_pretrained",
"create_model",
"__version__",
3 changes: 3 additions & 0 deletions segmentation_models_pytorch/decoders/dpt/__init__.py
@@ -0,0 +1,3 @@
from .model import DPT

__all__ = ["DPT"]