Skip to content

Commit 67c4a75

Browse files
Ruff formatting
1 parent aa84f4e commit 67c4a75

File tree

2 files changed

+23
-22
lines changed

2 files changed

+23
-22
lines changed

misc/generate_table_timm.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,25 @@ def has_dilation_support(name):
1515
return True
1616
except Exception:
1717
return False
18-
18+
19+
1920
def valid_vit_encoder_for_dpt(name):
2021
if "vit" not in name:
2122
return False
2223
encoder = timm.create_model(name)
2324
feature_info = encoder.feature_info
2425
feature_info_obj = timm.models.FeatureInfo(
25-
feature_info=feature_info, out_indices=[0,1,2,3]
26-
)
26+
feature_info=feature_info, out_indices=[0, 1, 2, 3]
27+
)
2728
reduction_scales = list(feature_info_obj.reduction())
2829

2930
if len(set(reduction_scales)) > 1:
3031
return False
31-
32+
3233
output_stride = reduction_scales[0]
3334
if bin(output_stride).count("1") != 1:
3435
return False
35-
36+
3637
return True
3738

3839

@@ -57,20 +58,27 @@ def make_table(data):
5758
table = l1 + top + l2
5859

5960
for k in sorted(data.keys()):
60-
6161
if "has_dilation" in data[k] and data[k]["has_dilation"]:
62-
support = ("✅".center(max_len2 - 3))
62+
support = "✅".center(max_len2 - 3)
6363

6464
else:
65-
support = (" ".center(max_len2 - 2))
65+
support = " ".center(max_len2 - 2)
6666

6767
if "supported_only_for_dpt" in data[k]:
68-
supported_for_dpt = ("✅".center(max_len3 - 3))
68+
supported_for_dpt = "✅".center(max_len3 - 3)
6969

7070
else:
71-
supported_for_dpt = (" ".center(max_len3 - 2))
72-
73-
table += "| " + k.ljust(max_len1 - 2) + " | " + support + " | " + supported_for_dpt + " |\n"
71+
supported_for_dpt = " ".center(max_len3 - 2)
72+
73+
table += (
74+
"| "
75+
+ k.ljust(max_len1 - 2)
76+
+ " | "
77+
+ support
78+
+ " | "
79+
+ supported_for_dpt
80+
+ " |\n"
81+
)
7482
table += l1
7583

7684
return table
@@ -89,11 +97,9 @@ def make_table(data):
8997
except Exception:
9098
try:
9199
if valid_vit_encoder_for_dpt(name):
92-
supported_models[name] = dict(supported_only_for_dpt = True)
93-
except:
100+
supported_models[name] = dict(supported_only_for_dpt=True)
101+
except Exception:
94102
continue
95-
96-
97103

98104
table = make_table(supported_models)
99105
print(table)

tests/models/test_dpt.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,4 @@
1-
import os
21
import pytest
3-
import inspect
4-
import tempfile
5-
from functools import lru_cache
6-
from huggingface_hub import hf_hub_download
72
import torch
83
import segmentation_models_pytorch as smp
94

@@ -28,7 +23,7 @@ class TestDPTModel(base.BaseModelTester):
2823

2924
@property
3025
def hub_checkpoint(self):
31-
return f"vedantdalimkar/DPT"
26+
return "vedantdalimkar/DPT"
3227

3328
@pytest.mark.compile
3429
def test_compile(self):

0 commit comments

Comments
 (0)