@@ -61,9 +61,10 @@ class TimmViTEncoder(nn.Module):
-    Ensures consistent multi-level feature extraction across all ViT models.
     """
 
+    # prefix tokens are not supported for scripting
     _is_torch_scriptable = False
     _is_torch_exportable = True
-    _is_torch_compilable = False
+    _is_torch_compilable = True
 
     def __init__(
         self,
@@ -87,10 +88,8 @@ def __init__(
         """
         super().__init__()
 
-        if depth > 4 or depth < 1:
-            raise ValueError(
-                f"{self.__class__.__name__} depth should be in range [1, 4], got {depth}"
-            )
+        if depth < 1:
+            raise ValueError(f"`encoder_depth` should be greater than 1, got {depth}.")
 
         # Output stride validation needed for smp encoder test consistency
         output_stride = kwargs.pop("output_stride", None)
@@ -142,14 +141,10 @@ def __init__(
         self.output_stride = self.output_strides[-1]
         self.out_channels = [feature_info[i]["num_chs"] for i in output_indices]
         self.has_prefix_tokens = self._num_prefix_tokens > 0
-
-    @property
-    def is_fixed_input_size(self) -> bool:
-        return self.model.pretrained_cfg.get("fixed_input_size", False)
-
-    @property
-    def input_size(self) -> int:
-        return self.model.pretrained_cfg.get("input_size", None)
+        self.input_size = self.model.pretrained_cfg.get("input_size", None)
+        self.is_fixed_input_size = self.model.pretrained_cfg.get(
+            "fixed_input_size", False
+        )
 
     def _forward_with_prefix_tokens(
         self, x: torch.Tensor
0 commit comments