1
- from typing import Any , Optional
1
+ from typing import Any , Optional , Union
2
2
3
3
import timm
4
4
import torch
@@ -15,17 +15,17 @@ class TimmViTEncoder(nn.Module):
15
15
- Ensures consistent multi-level feature extraction across all ViT models.
16
16
"""
17
17
18
- _is_torch_scriptable = True
18
+ _is_torch_scriptable = False
19
19
_is_torch_exportable = True
20
- _is_torch_compilable = True
20
+ _is_torch_compilable = False
21
21
22
22
def __init__ (
23
23
self ,
24
24
name : str ,
25
25
pretrained : bool = True ,
26
26
in_channels : int = 3 ,
27
27
depth : int = 4 ,
28
- output_indices : Optional [list [int ] | int ] = None ,
28
+ output_indices : Optional [Union [ list [int ], int ] ] = None ,
29
29
** kwargs : dict [str , Any ],
30
30
):
31
31
"""
@@ -49,16 +49,14 @@ def __init__(
49
49
super ().__init__ ()
50
50
self .name = name
51
51
52
- output_stride = kwargs .pop ("output_stride" ,None )
52
+ output_stride = kwargs .pop ("output_stride" , None )
53
53
if output_stride is not None :
54
- raise ValueError (
55
- "Dilated mode not supported, set output stride to None"
56
- )
54
+ raise ValueError ("Dilated mode not supported, set output stride to None" )
57
55
58
56
# Default model configuration for feature extraction
59
57
common_kwargs = dict (
60
58
in_chans = in_channels ,
61
- features_only = True ,
59
+ features_only = False ,
62
60
pretrained = pretrained ,
63
61
out_indices = tuple (range (depth )),
64
62
)
@@ -76,6 +74,23 @@ def __init__(
76
74
feature_info = tmp_model .feature_info
77
75
model_num_blocks = len (feature_info )
78
76
77
+ if output_indices is not None :
78
+ if isinstance (output_indices , int ):
79
+ output_indices = list (output_indices )
80
+
81
+ for output_index in output_indices :
82
+ if output_indices < 0 or output_indices > model_num_blocks :
83
+ raise ValueError (
84
+ f"Output indices for feature extraction should be greater than 0 and less \
85
+ than the number of blocks in the model ({ model_num_blocks } ), got { output_index } "
86
+ )
87
+
88
+ if len (output_indices ) != depth :
89
+ raise ValueError (
90
+ f"Length of output indices for feature extraction should be equal to the depth of the encoder\
91
+ architecture, got output indices length - { len (output_indices )} , encoder depth - { depth } "
92
+ )
93
+
79
94
if depth > model_num_blocks :
80
95
raise ValueError (
81
96
f"Depth of the encoder cannot exceed the number of blocks in the model \
@@ -87,9 +102,6 @@ def __init__(
87
102
int ((model_num_blocks / 4 ) * index ) - 1 for index in range (1 , depth + 1 )
88
103
]
89
104
90
- if isinstance (output_indices ,int ):
91
- output_indices = list (output_indices )
92
-
93
105
common_kwargs ["out_indices" ] = self .out_indices = output_indices
94
106
feature_info_obj = timm .models .FeatureInfo (
95
107
feature_info = feature_info , out_indices = output_indices
@@ -109,18 +121,16 @@ def __init__(
109
121
self ._output_stride = reduction_scales [0 ]
110
122
111
123
if (
112
- int (self ._output_stride ).bit_count ( ) != 1
124
+ bin (self ._output_stride ).count ( "1" ) != 1
113
125
and not allow_output_stride_not_power_of_two
114
126
):
115
127
raise ValueError (
116
128
f"Models with stride which is not a power of 2 are not supported, \
117
129
got output stride { self ._output_stride } "
118
130
)
119
131
120
- self .prefix_token_supported = getattr (tmp_model , "has_class_token" , False )
132
+ self .cls_token_supported = getattr (tmp_model , "has_class_token" , False )
121
133
self .num_prefix_tokens = getattr (tmp_model , "num_prefix_tokens" , 0 )
122
- if self .prefix_token_supported :
123
- common_kwargs ["features_only" ] = False
124
134
125
135
self .model = timm .create_model (
126
136
name , ** _merge_kwargs_no_duplicates (common_kwargs , kwargs )
@@ -131,47 +141,40 @@ def __init__(
131
141
self ._depth = depth
132
142
self ._embed_dim = tmp_model .embed_dim
133
143
134
def forward(self, x: torch.Tensor) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """
    Forward pass to extract multi-stage features.

    Args:
        x (torch.Tensor): Input tensor of shape (B, C, H, W).

    Returns:
        tuple[list[torch.Tensor], list[torch.Tensor]]: A ``(features, cls_tokens)``
            pair. Both are lists with one entry per index in ``self.out_indices``.
            Every ``cls_tokens`` entry is ``None`` when the model has no prefix
            tokens or does not support a class token.
    """
    intermediate_outputs = self.model.forward_intermediates(
        x,
        indices=self.out_indices,
        return_prefix_tokens=True,
        intermediates_only=True,
    )

    # Default used whenever there is no class token to report.
    cls_tokens = [None] * len(self.out_indices)

    if self.num_prefix_tokens > 0:
        # Each intermediate output is unpacked as a (feature, prefix_tokens) pair.
        features, prefix_tokens = zip(*intermediate_outputs)
        if self.cls_token_supported:
            if self.num_prefix_tokens == 1:
                # A single prefix token is passed through unchanged.
                cls_tokens = prefix_tokens
            elif self.num_prefix_tokens > 1:
                # Keep only index 0 along dim 1 — assumed to be the class
                # token when several prefix tokens exist (TODO confirm).
                cls_tokens = [prefix_token[:, 0, :] for prefix_token in prefix_tokens]
    else:
        # Without prefix tokens the outputs are used directly as the features.
        features = intermediate_outputs

    # zip() produces tuples; convert so every branch returns the same container
    # types and the result matches the annotated list[...] contract.
    return list(features), list(cls_tokens)
175
178
176
179
@property
177
180
def embed_dim (self ) -> int :
0 commit comments