Adding DPT #1079
qubvel marked this conversation as resolved.

Review comment: We should remove this file
@@ -0,0 +1,2 @@
+|Encoder |Pretrained weights |Params, M |Script |Compile |Export |
+|--------------------------------|:------------------------------:|:------------------------------:|:------------------------------:|:------------------------------:|:------------------------------:|
@@ -17,30 +17,68 @@ def has_dilation_support(name):
     return False


+def valid_vit_encoder_for_dpt(name):
+    if "vit" not in name:
+        return False
+    encoder = timm.create_model(name)
+    feature_info = encoder.feature_info
+    feature_info_obj = timm.models.FeatureInfo(
+        feature_info=feature_info, out_indices=[0, 1, 2, 3]
+    )
+    reduction_scales = list(feature_info_obj.reduction())
+
+    if len(set(reduction_scales)) > 1:
+        return False
+
+    output_stride = reduction_scales[0]
+    if bin(output_stride).count("1") != 1:
+        return False
+
+    return True
+
+
 def make_table(data):
     names = data.keys()
     max_len1 = max([len(x) for x in names]) + 2
     max_len2 = len("support dilation") + 2
+    max_len3 = len("Supported for DPT") + 2

-    l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+\n"
-    l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+\n"
+    l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+" + "-" * max_len3 + "+\n"
+    l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+" + "-" * max_len3 + "+\n"
     top = (
         "| "
         + "Encoder name".ljust(max_len1 - 2)
         + " | "
         + "Support dilation".center(max_len2 - 2)
+        + " | "
+        + "Supported for DPT".center(max_len3 - 2)
         + " |\n"
     )

     table = l1 + top + l2

     for k in sorted(data.keys()):
-        support = (
-            "✅".center(max_len2 - 3)
-            if data[k]["has_dilation"]
-            else " ".center(max_len2 - 2)
-        )
+        if "has_dilation" in data[k] and data[k]["has_dilation"]:
+            support = "✅".center(max_len2 - 3)
+        else:
+            support = " ".center(max_len2 - 2)
+
+        if "supported_only_for_dpt" in data[k]:
+            supported_for_dpt = "✅".center(max_len3 - 3)
+        else:
+            supported_for_dpt = " ".center(max_len3 - 2)
+
-        table += "| " + k.ljust(max_len1 - 2) + " | " + support + " |\n"
+        table += (
+            "| "
+            + k.ljust(max_len1 - 2)
+            + " | "
+            + support
+            + " | "
+            + supported_for_dpt
+            + " |\n"
+        )
         table += l1

     return table
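For context, a minimal sketch of how the new helper behaves (not part of the PR; the encoder names are illustrative and assume a timm version where plain ViTs expose feature_info, as this script already requires):

```python
# Illustrative only: vit_base_patch16_224 is an isotropic ViT, so every feature
# stage reports the same power-of-two reduction (16) and the check should pass.
print(valid_vit_encoder_for_dpt("vit_base_patch16_224"))  # expected: True

# Names without "vit" are rejected before any timm model is created.
print(valid_vit_encoder_for_dpt("resnet50"))  # expected: False
```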
@@ -55,8 +93,13 @@ def make_table(data):
         check_features_and_reduction(name)
         has_dilation = has_dilation_support(name)
         supported_models[name] = dict(has_dilation=has_dilation)

     except Exception:
         continue
+    try:
+        if valid_vit_encoder_for_dpt(name):
+            supported_models[name] = dict(supported_only_for_dpt=True)
+    except Exception:
+        continue
Comment on lines +96 to +102

Review comment: Should we check only if we got an exception here?

Reply: If you check the behaviour of functions … In short, a model which satisfies the conditions specified by …

Review comment: Ok, I suppose this code should be updated as well, because as far as I remember [4, 8, 16, 32] and [1, 2, 4, 8, 16, 32] reductions are also supported

Reply: Should I update this as well or will you do it from your end?
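For reference, a minimal sketch of how the reduction check could be relaxed to also accept the [4, 8, 16, 32] and [1, 2, 4, 8, 16, 32] patterns mentioned above; the helper name and the exact accepted patterns are assumptions, not the final implementation:

```python
# Hypothetical relaxed check (assumption): accept either a single power-of-two
# stride (plain ViT) or one of the hierarchical reduction patterns noted above.
ALLOWED_HIERARCHICAL_REDUCTIONS = [
    [4, 8, 16, 32],
    [1, 2, 4, 8, 16, 32],
]


def reductions_supported_for_dpt(reduction_scales):
    unique = sorted(set(reduction_scales))
    if len(unique) == 1:
        # Isotropic ViT case: a single, power-of-two output stride.
        return bin(unique[0]).count("1") == 1
    return unique in ALLOWED_HIERARCHICAL_REDUCTIONS
```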
 table = make_table(supported_models)
 print(table)
@@ -0,0 +1,109 @@
import segmentation_models_pytorch as smp
import torch
import huggingface_hub

MODEL_WEIGHTS_PATH = r"C:\Users\vedan\Downloads\dpt_large-ade20k-b12dca68.pt"
HF_HUB_PATH = "vedantdalimkar/DPT"

if __name__ == "__main__":
    smp_model = smp.DPT(encoder_name="tu-vit_large_patch16_384", classes=150)
    dpt_model_dict = torch.load(MODEL_WEIGHTS_PATH)

    for layer_index in range(0, 4):
        for param in [
            "running_mean",
            "running_var",
            "num_batches_tracked",
            "weight",
            "bias",
        ]:
            for block_index in [1, 2]:
                for bn_index in [1, 2]:
                    # Assigning weights of 4th fusion layer of original model to 1st layer of SMP DPT model,
                    # assigning weights of 3rd fusion layer of original model to 2nd layer of SMP DPT model,
                    # and so on ...

                    # This is because the order of calling fusion layers is reversed in the original DPT implementation

                    dpt_model_dict[
                        f"decoder.fusion_blocks.{layer_index}.residual_conv_block{block_index}.batch_norm_{bn_index}.{param}"
                    ] = dpt_model_dict.pop(
                        f"scratch.refinenet{4 - layer_index}.resConfUnit{block_index}.bn{bn_index}.{param}"
                    )

            if param in ["weight", "bias"]:
                if param == "weight":
                    for block_index in [1, 2]:
                        for conv_index in [1, 2]:
                            dpt_model_dict[
                                f"decoder.fusion_blocks.{layer_index}.residual_conv_block{block_index}.conv_{conv_index}.{param}"
                            ] = dpt_model_dict.pop(
                                f"scratch.refinenet{4 - layer_index}.resConfUnit{block_index}.conv{conv_index}.{param}"
                            )

                    dpt_model_dict[
                        f"decoder.reassemble_blocks.{layer_index}.project_to_feature_dim.{param}"
                    ] = dpt_model_dict.pop(f"scratch.layer{layer_index + 1}_rn.{param}")

                dpt_model_dict[
                    f"decoder.fusion_blocks.{layer_index}.project.{param}"
                ] = dpt_model_dict.pop(
                    f"scratch.refinenet{4 - layer_index}.out_conv.{param}"
                )

                dpt_model_dict[
                    f"decoder.readout_blocks.{layer_index}.project.0.{param}"
                ] = dpt_model_dict.pop(
                    f"pretrained.act_postprocess{layer_index + 1}.0.project.0.{param}"
                )

                dpt_model_dict[
                    f"decoder.reassemble_blocks.{layer_index}.project_to_out_channel.{param}"
                ] = dpt_model_dict.pop(
                    f"pretrained.act_postprocess{layer_index + 1}.3.{param}"
                )

                if layer_index != 2:
                    dpt_model_dict[
                        f"decoder.reassemble_blocks.{layer_index}.upsample.{param}"
                    ] = dpt_model_dict.pop(
                        f"pretrained.act_postprocess{layer_index + 1}.4.{param}"
                    )

    # Changing state dict keys for segmentation head
    dpt_model_dict = {
        (
            "segmentation_head.head" + name[len("scratch.output_conv") :]
            if name.startswith("scratch.output_conv")
            else name
        ): parameter
        for name, parameter in dpt_model_dict.items()
    }

    # Changing state dict keys for encoder layers
    dpt_model_dict = {
        (
            "encoder.model" + name[len("pretrained.model") :]
            if name.startswith("pretrained.model")
            else name
        ): parameter
        for name, parameter in dpt_model_dict.items()
    }

    # Removing key-value pairs associated with the auxiliary head
    dpt_model_dict = {
        name: parameter
        for name, parameter in dpt_model_dict.items()
        if not name.startswith("auxlayer")
    }

    smp_model.load_state_dict(dpt_model_dict, strict=True)

    model_name = MODEL_WEIGHTS_PATH.split("\\")[-1].replace(".pt", "")

    smp_model.save_pretrained(model_name)

    repo_id = HF_HUB_PATH
    api = huggingface_hub.HfApi()
    api.create_repo(repo_id=repo_id, repo_type="model")
    api.upload_folder(folder_path=model_name, repo_id=repo_id)
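A quick post-conversion sanity check could be appended to the script above before uploading; this is only a sketch, and the 384x384 input size is an assumption based on the patch16_384 encoder:

```python
# Sketch (not part of the PR): run one forward pass on the converted model and
# confirm the logits look like ADE20K segmentation output (150 classes).
smp_model.eval()
with torch.no_grad():
    logits = smp_model(torch.rand(1, 3, 384, 384))
print(logits.shape)  # expected: torch.Size([1, 150, 384, 384])
```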
Review comment: We can use smp_model.push_to_hub(...) instead
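For reference, a sketch of that suggestion using the names already defined in the script; push_to_hub comes from the same Hugging Face mixin that provides save_pretrained, and the exact keyword arguments are left out as an assumption:

```python
# Sketch: push the converted model directly instead of creating the repo and
# uploading a folder by hand. Extra kwargs (private, commit message) omitted.
smp_model.push_to_hub(HF_HUB_PATH)
```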
@@ -0,0 +1,3 @@
+from .model import DPT
+
+__all__ = ["DPT"]
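For context, a minimal sketch of what this export enables; the encoder name and class count mirror the conversion script above, and the from_pretrained line assumes the checkpoint uploaded there plus the Hub mixin behaviour, so treat it as illustrative:

```python
import segmentation_models_pytorch as smp

# DPT is now exposed through this __init__.py and the package root.
model = smp.DPT(encoder_name="tu-vit_large_patch16_384", classes=150)

# Illustrative: load the converted checkpoint uploaded earlier in this PR
# (assumes the Hub repo exists and the save_pretrained/from_pretrained mixin).
pretrained = smp.DPT.from_pretrained("vedantdalimkar/DPT")
```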
Review comment: please use .git/info/exclude to ignore it locally