@@ -15,32 +15,62 @@ def has_dilation_support(name):
15
15
return True
16
16
except Exception :
17
17
return False
18
+
19
+ def valid_vit_encoder_for_dpt (name ):
20
+ if "vit" not in name :
21
+ return False
22
+ encoder = timm .create_model (name )
23
+ feature_info = encoder .feature_info
24
+ feature_info_obj = timm .models .FeatureInfo (
25
+ feature_info = feature_info , out_indices = [0 ,1 ,2 ,3 ]
26
+ )
27
+ reduction_scales = list (feature_info_obj .reduction ())
28
+
29
+ if len (set (reduction_scales )) > 1 :
30
+ return False
31
+
32
+ output_stride = reduction_scales [0 ]
33
+ if bin (output_stride ).count ("1" ) != 1 :
34
+ return False
35
+
36
+ return True
18
37
19
38
20
39
def make_table (data ):
21
40
names = data .keys ()
22
41
max_len1 = max ([len (x ) for x in names ]) + 2
23
42
max_len2 = len ("support dilation" ) + 2
43
+ max_len3 = len ("Supported for DPT" ) + 2
24
44
25
- l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+\n "
26
- l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+\n "
45
+ l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+" + "-" * max_len3 + "+ \n "
46
+ l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+" + "-" * max_len3 + "+ \n "
27
47
top = (
28
48
"| "
29
49
+ "Encoder name" .ljust (max_len1 - 2 )
30
50
+ " | "
31
51
+ "Support dilation" .center (max_len2 - 2 )
52
+ + " | "
53
+ + "Supported for DPT" .center (max_len3 - 2 )
32
54
+ " |\n "
33
55
)
34
56
35
57
table = l1 + top + l2
36
58
37
59
for k in sorted (data .keys ()):
38
- support = (
39
- "✅" .center (max_len2 - 3 )
40
- if data [k ]["has_dilation" ]
41
- else " " .center (max_len2 - 2 )
42
- )
43
- table += "| " + k .ljust (max_len1 - 2 ) + " | " + support + " |\n "
60
+
61
+ if "has_dilation" in data [k ] and data [k ]["has_dilation" ]:
62
+ support = ("✅" .center (max_len2 - 3 ))
63
+
64
+ else :
65
+ support = (" " .center (max_len2 - 2 ))
66
+
67
+ if "supported_only_for_dpt" in data [k ]:
68
+ supported_for_dpt = ("✅" .center (max_len3 - 3 ))
69
+
70
+ else :
71
+ supported_for_dpt = (" " .center (max_len3 - 2 ))
72
+
73
+ table += "| " + k .ljust (max_len1 - 2 ) + " | " + support + " | " + supported_for_dpt + " |\n "
44
74
table += l1
45
75
46
76
return table
@@ -55,8 +85,15 @@ def make_table(data):
55
85
check_features_and_reduction (name )
56
86
has_dilation = has_dilation_support (name )
57
87
supported_models [name ] = dict (has_dilation = has_dilation )
88
+
58
89
except Exception :
59
- continue
90
+ try :
91
+ if valid_vit_encoder_for_dpt (name ):
92
+ supported_models [name ] = dict (supported_only_for_dpt = True )
93
+ except :
94
+ continue
95
+
96
+
60
97
61
98
table = make_table (supported_models )
62
99
print (table )
0 commit comments