@@ -15,24 +15,25 @@ def has_dilation_support(name):
15
15
return True
16
16
except Exception :
17
17
return False
18
-
18
+
19
+
19
20
def valid_vit_encoder_for_dpt (name ):
20
21
if "vit" not in name :
21
22
return False
22
23
encoder = timm .create_model (name )
23
24
feature_info = encoder .feature_info
24
25
feature_info_obj = timm .models .FeatureInfo (
25
- feature_info = feature_info , out_indices = [0 ,1 , 2 , 3 ]
26
- )
26
+ feature_info = feature_info , out_indices = [0 , 1 , 2 , 3 ]
27
+ )
27
28
reduction_scales = list (feature_info_obj .reduction ())
28
29
29
30
if len (set (reduction_scales )) > 1 :
30
31
return False
31
-
32
+
32
33
output_stride = reduction_scales [0 ]
33
34
if bin (output_stride ).count ("1" ) != 1 :
34
35
return False
35
-
36
+
36
37
return True
37
38
38
39
@@ -57,20 +58,27 @@ def make_table(data):
57
58
table = l1 + top + l2
58
59
59
60
for k in sorted (data .keys ()):
60
-
61
61
if "has_dilation" in data [k ] and data [k ]["has_dilation" ]:
62
- support = ( "✅" .center (max_len2 - 3 ) )
62
+ support = "✅" .center (max_len2 - 3 )
63
63
64
64
else :
65
- support = ( " " .center (max_len2 - 2 ) )
65
+ support = " " .center (max_len2 - 2 )
66
66
67
67
if "supported_only_for_dpt" in data [k ]:
68
- supported_for_dpt = ( "✅" .center (max_len3 - 3 ) )
68
+ supported_for_dpt = "✅" .center (max_len3 - 3 )
69
69
70
70
else :
71
- supported_for_dpt = (" " .center (max_len3 - 2 ))
72
-
73
- table += "| " + k .ljust (max_len1 - 2 ) + " | " + support + " | " + supported_for_dpt + " |\n "
71
+ supported_for_dpt = " " .center (max_len3 - 2 )
72
+
73
+ table += (
74
+ "| "
75
+ + k .ljust (max_len1 - 2 )
76
+ + " | "
77
+ + support
78
+ + " | "
79
+ + supported_for_dpt
80
+ + " |\n "
81
+ )
74
82
table += l1
75
83
76
84
return table
@@ -89,11 +97,9 @@ def make_table(data):
89
97
except Exception :
90
98
try :
91
99
if valid_vit_encoder_for_dpt (name ):
92
- supported_models [name ] = dict (supported_only_for_dpt = True )
93
- except :
100
+ supported_models [name ] = dict (supported_only_for_dpt = True )
101
+ except Exception :
94
102
continue
95
-
96
-
97
103
98
104
table = make_table (supported_models )
99
105
print (table )
0 commit comments