Skip to content

Commit 165b9c0

Browse files
committed
Fix encoder tests
1 parent 4eb6ec3 commit 165b9c0

File tree

2 files changed

+63
-197
lines changed

2 files changed

+63
-197
lines changed

segmentation_models_pytorch/encoders/timm_vit.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,10 @@ class TimmViTEncoder(nn.Module):
6161
- Ensures consistent multi-level feature extraction across all ViT models.
6262
"""
6363

64+
# prefix tokens are not supported for scripting
6465
_is_torch_scriptable = False
6566
_is_torch_exportable = True
66-
_is_torch_compilable = False
67+
_is_torch_compilable = True
6768

6869
def __init__(
6970
self,
@@ -87,10 +88,8 @@ def __init__(
8788
"""
8889
super().__init__()
8990

90-
if depth > 4 or depth < 1:
91-
raise ValueError(
92-
f"{self.__class__.__name__} depth should be in range [1, 4], got {depth}"
93-
)
91+
if depth < 1:
92+
raise ValueError(f"`encoder_depth` should be greater than 1, got {depth}.")
9493

9594
# Output stride validation needed for smp encoder test consistency
9695
output_stride = kwargs.pop("output_stride", None)
@@ -142,14 +141,10 @@ def __init__(
142141
self.output_stride = self.output_strides[-1]
143142
self.out_channels = [feature_info[i]["num_chs"] for i in output_indices]
144143
self.has_prefix_tokens = self._num_prefix_tokens > 0
145-
146-
@property
147-
def is_fixed_input_size(self) -> bool:
148-
return self.model.pretrained_cfg.get("fixed_input_size", False)
149-
150-
@property
151-
def input_size(self) -> int:
152-
return self.model.pretrained_cfg.get("input_size", None)
144+
self.input_size = self.model.pretrained_cfg.get("input_size", None)
145+
self.is_fixed_input_size = self.model.pretrained_cfg.get(
146+
"fixed_input_size", False
147+
)
153148

154149
def _forward_with_prefix_tokens(
155150
self, x: torch.Tensor

0 commit comments

Comments
 (0)