|
9 | 9 |
|
10 | 10 | api = huggingface_hub.HfApi(token=os.getenv("HF_TOKEN"))
|
11 | 11 |
|
12 |
| -for model_name, model_class in smp.MODEL_ARCHITECTURES_MAPPING.items(): |
13 |
| - model = model_class(encoder_name=ENCODER_NAME) |
14 |
| - model = model.eval() |
15 |
| - |
16 |
| - # generate test sample |
17 |
| - torch.manual_seed(423553) |
18 |
| - sample = torch.rand(1, 3, 256, 256) |
19 |
| - |
20 |
| - with torch.no_grad(): |
21 |
| - output = model(sample) |
22 | 12 |
|
| 13 | +def save_and_push(model, inputs, outputs, model_name, encoder_name): |
23 | 14 | with tempfile.TemporaryDirectory() as tmpdir:
|
24 | 15 | # save model
|
25 | 16 | model.save_pretrained(f"{tmpdir}")
|
26 | 17 |
|
27 | 18 | # save input and output
|
28 |
| - torch.save(sample, f"{tmpdir}/input-tensor.pth") |
29 |
| - torch.save(output, f"{tmpdir}/output-tensor.pth") |
| 19 | + torch.save(inputs, f"{tmpdir}/input-tensor.pth") |
| 20 | + torch.save(outputs, f"{tmpdir}/output-tensor.pth") |
30 | 21 |
|
31 | 22 | # create repo
|
32 |
| - repo_id = f"{HUB_REPO}/{model_name}-{ENCODER_NAME}" |
| 23 | + repo_id = f"{HUB_REPO}/{model_name}-{encoder_name}" |
33 | 24 | if not api.repo_exists(repo_id=repo_id):
|
34 | 25 | api.create_repo(repo_id=repo_id, repo_type="model")
|
35 | 26 |
|
36 | 27 | # upload to hub
|
37 | 28 | api.upload_folder(
|
38 | 29 | folder_path=tmpdir,
|
39 |
| - repo_id=f"{HUB_REPO}/{model_name}-{ENCODER_NAME}", |
| 30 | + repo_id=f"{HUB_REPO}/{model_name}-{encoder_name}", |
40 | 31 | repo_type="model",
|
41 | 32 | )
|
| 33 | + |
| 34 | + |
| 35 | +for model_name, model_class in smp.MODEL_ARCHITECTURES_MAPPING.items(): |
| 36 | + if model_name == "dpt": |
| 37 | + encoder_name = "tu-test_vit" |
| 38 | + model = smp.DPT( |
| 39 | + encoder_name=encoder_name, |
| 40 | + decoder_readout="cat", |
| 41 | + decoder_intermediate_channels=(16, 32, 64, 64), |
| 42 | + decoder_fusion_channels=16, |
| 43 | + dynamic_img_size=True, |
| 44 | + ) |
| 45 | + else: |
| 46 | + encoder_name = ENCODER_NAME |
| 47 | + model = model_class(encoder_name=encoder_name) |
| 48 | + |
| 49 | + model = model.eval() |
| 50 | + |
| 51 | + # generate test sample |
| 52 | + torch.manual_seed(423553) |
| 53 | + sample = torch.rand(1, 3, 256, 256) |
| 54 | + |
| 55 | + with torch.no_grad(): |
| 56 | + output = model(sample) |
| 57 | + |
| 58 | + save_and_push(model, sample, output, model_name, encoder_name) |
0 commit comments