Skip to content

Commit aa84f4e

Browse files
author
ved
committed
Added logic in timm table generation for adding ViT encoders for DPT
1 parent 35cb060 commit aa84f4e

File tree

2 files changed

+1520
-9
lines changed

2 files changed

+1520
-9
lines changed

misc/generate_table_timm.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,32 +15,62 @@ def has_dilation_support(name):
1515
return True
1616
except Exception:
1717
return False
18+
19+
def valid_vit_encoder_for_dpt(name):
20+
if "vit" not in name:
21+
return False
22+
encoder = timm.create_model(name)
23+
feature_info = encoder.feature_info
24+
feature_info_obj = timm.models.FeatureInfo(
25+
feature_info=feature_info, out_indices=[0,1,2,3]
26+
)
27+
reduction_scales = list(feature_info_obj.reduction())
28+
29+
if len(set(reduction_scales)) > 1:
30+
return False
31+
32+
output_stride = reduction_scales[0]
33+
if bin(output_stride).count("1") != 1:
34+
return False
35+
36+
return True
1837

1938

2039
def make_table(data):
2140
names = data.keys()
2241
max_len1 = max([len(x) for x in names]) + 2
2342
max_len2 = len("support dilation") + 2
43+
max_len3 = len("Supported for DPT") + 2
2444

25-
l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+\n"
26-
l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+\n"
45+
l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+" + "-" * max_len3 + "+\n"
46+
l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+" + "-" * max_len3 + "+\n"
2747
top = (
2848
"| "
2949
+ "Encoder name".ljust(max_len1 - 2)
3050
+ " | "
3151
+ "Support dilation".center(max_len2 - 2)
52+
+ " | "
53+
+ "Supported for DPT".center(max_len3 - 2)
3254
+ " |\n"
3355
)
3456

3557
table = l1 + top + l2
3658

3759
for k in sorted(data.keys()):
38-
support = (
39-
"✅".center(max_len2 - 3)
40-
if data[k]["has_dilation"]
41-
else " ".center(max_len2 - 2)
42-
)
43-
table += "| " + k.ljust(max_len1 - 2) + " | " + support + " |\n"
60+
61+
if "has_dilation" in data[k] and data[k]["has_dilation"]:
62+
support = ("✅".center(max_len2 - 3))
63+
64+
else:
65+
support = (" ".center(max_len2 - 2))
66+
67+
if "supported_only_for_dpt" in data[k]:
68+
supported_for_dpt = ("✅".center(max_len3 - 3))
69+
70+
else:
71+
supported_for_dpt = (" ".center(max_len3 - 2))
72+
73+
table += "| " + k.ljust(max_len1 - 2) + " | " + support + " | " + supported_for_dpt + " |\n"
4474
table += l1
4575

4676
return table
@@ -55,8 +85,15 @@ def make_table(data):
5585
check_features_and_reduction(name)
5686
has_dilation = has_dilation_support(name)
5787
supported_models[name] = dict(has_dilation=has_dilation)
88+
5889
except Exception:
59-
continue
90+
try:
91+
if valid_vit_encoder_for_dpt(name):
92+
supported_models[name] = dict(supported_only_for_dpt = True)
93+
except:
94+
continue
95+
96+
6097

6198
table = make_table(supported_models)
6299
print(table)

0 commit comments

Comments
 (0)