Skip to content

Commit 17d3328

Browse files
committed
Update gen test models
1 parent 38cb944 commit 17d3328

File tree

1 file changed

+31
-14
lines changed

1 file changed

+31
-14
lines changed

misc/generate_test_models.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,33 +9,50 @@
99

1010
api = huggingface_hub.HfApi(token=os.getenv("HF_TOKEN"))
1111

12-
for model_name, model_class in smp.MODEL_ARCHITECTURES_MAPPING.items():
13-
model = model_class(encoder_name=ENCODER_NAME)
14-
model = model.eval()
15-
16-
# generate test sample
17-
torch.manual_seed(423553)
18-
sample = torch.rand(1, 3, 256, 256)
19-
20-
with torch.no_grad():
21-
output = model(sample)
2212

13+
def save_and_push(model, inputs, outputs, model_name, encoder_name):
2314
with tempfile.TemporaryDirectory() as tmpdir:
2415
# save model
2516
model.save_pretrained(f"{tmpdir}")
2617

2718
# save input and output
28-
torch.save(sample, f"{tmpdir}/input-tensor.pth")
29-
torch.save(output, f"{tmpdir}/output-tensor.pth")
19+
torch.save(inputs, f"{tmpdir}/input-tensor.pth")
20+
torch.save(outputs, f"{tmpdir}/output-tensor.pth")
3021

3122
# create repo
32-
repo_id = f"{HUB_REPO}/{model_name}-{ENCODER_NAME}"
23+
repo_id = f"{HUB_REPO}/{model_name}-{encoder_name}"
3324
if not api.repo_exists(repo_id=repo_id):
3425
api.create_repo(repo_id=repo_id, repo_type="model")
3526

3627
# upload to hub
3728
api.upload_folder(
3829
folder_path=tmpdir,
39-
repo_id=f"{HUB_REPO}/{model_name}-{ENCODER_NAME}",
30+
repo_id=f"{HUB_REPO}/{model_name}-{encoder_name}",
4031
repo_type="model",
4132
)
33+
34+
35+
for model_name, model_class in smp.MODEL_ARCHITECTURES_MAPPING.items():
36+
if model_name == "dpt":
37+
encoder_name = "tu-test_vit"
38+
model = smp.DPT(
39+
encoder_name=encoder_name,
40+
decoder_readout="cat",
41+
decoder_intermediate_channels=(16, 32, 64, 64),
42+
decoder_fusion_channels=16,
43+
dynamic_img_size=True,
44+
)
45+
else:
46+
encoder_name = ENCODER_NAME
47+
model = model_class(encoder_name=encoder_name)
48+
49+
model = model.eval()
50+
51+
# generate test sample
52+
torch.manual_seed(423553)
53+
sample = torch.rand(1, 3, 256, 256)
54+
55+
with torch.no_grad():
56+
output = model(sample)
57+
58+
save_and_push(model, sample, output, model_name, encoder_name)

0 commit comments

Comments
 (0)