13
13
from __future__ import absolute_import
14
14
from urllib import request
15
15
import json
16
- from packaging .version import Version
16
+ from packaging .version import Version
17
+ from enum import Enum
18
+
19
+ class Tasks (str , Enum ):
20
+ """The ML task name as referenced in the infix of the model ID."""
21
+
22
+ IC = "ic"
23
+ OD = "od"
24
+ OD1 = "od1"
25
+ SEMSEG = "semseg"
26
+ IS = "is"
27
+ TC = "tc"
28
+ SPC = "spc"
29
+ EQA = "eqa"
30
+ TEXT_GENERATION = "textgeneration"
31
+ IC_EMBEDDING = "icembedding"
32
+ TC_EMBEDDING = "tcembedding"
33
+ NER = "ner"
34
+ SUMMARIZATION = "summarization"
35
+ TRANSLATION = "translation"
36
+ TABULAR_REGRESSION = "regression"
37
+ TABULAR_CLASSIFICATION = "classification"
38
+
39
+ class ProblemTypes (str , Enum ):
40
+ """Possible problem types for JumpStart models."""
41
+
42
+ IMAGE_CLASSIFICATION = "Image Classification"
43
+ IMAGE_EMBEDDING = "Image Embedding"
44
+ OBJECT_DETECTION = "Object Detection"
45
+ SEMANTIC_SEGMENTATION = "Semantic Segmentation"
46
+ INSTANCE_SEGMENTATION = "Instance Segmentation"
47
+ TEXT_CLASSIFICATION = "Text Classification"
48
+ TEXT_EMBEDDING = "Text Embedding"
49
+ QUESTION_ANSWERING = "Question Answering"
50
+ SENTENCE_PAIR_CLASSIFICATION = "Sentence Pair Classification"
51
+ TEXT_GENERATION = "Text Generation"
52
+ TEXT_SUMMARIZATION = "Text Summarization"
53
+ MACHINE_TRANSLATION = "Machine Translation"
54
+ NAMED_ENTITY_RECOGNITION = "Named Entity Recognition"
55
+ TABULAR_REGRESSION = "Regression"
56
+ TABULAR_CLASSIFICATION = "Classification"
17
57
18
58
JUMPSTART_REGION = "eu-west-2"
19
59
SDK_MANIFEST_FILE = "models_manifest.json"
20
60
JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com" .format (
21
61
JUMPSTART_REGION , JUMPSTART_REGION
22
62
)
63
+ TASK_MAP = {
64
+ Tasks .IC : ProblemTypes .IMAGE_CLASSIFICATION ,
65
+ Tasks .IC_EMBEDDING : ProblemTypes .IMAGE_EMBEDDING ,
66
+ Tasks .OD : ProblemTypes .OBJECT_DETECTION ,
67
+ Tasks .OD1 : ProblemTypes .OBJECT_DETECTION ,
68
+ Tasks .SEMSEG : ProblemTypes .SEMANTIC_SEGMENTATION ,
69
+ Tasks .IS : ProblemTypes .INSTANCE_SEGMENTATION ,
70
+ Tasks .TC : ProblemTypes .TEXT_CLASSIFICATION ,
71
+ Tasks .TC_EMBEDDING : ProblemTypes .TEXT_EMBEDDING ,
72
+ Tasks .EQA : ProblemTypes .QUESTION_ANSWERING ,
73
+ Tasks .SPC : ProblemTypes .SENTENCE_PAIR_CLASSIFICATION ,
74
+ Tasks .TEXT_GENERATION : ProblemTypes .TEXT_GENERATION ,
75
+ Tasks .SUMMARIZATION : ProblemTypes .TEXT_SUMMARIZATION ,
76
+ Tasks .TRANSLATION : ProblemTypes .MACHINE_TRANSLATION ,
77
+ Tasks .NER : ProblemTypes .NAMED_ENTITY_RECOGNITION ,
78
+ Tasks .TABULAR_REGRESSION : ProblemTypes .TABULAR_REGRESSION ,
79
+ Tasks .TABULAR_CLASSIFICATION : ProblemTypes .TABULAR_CLASSIFICATION ,
80
+ }
23
81
24
82
25
83
def get_jumpstart_sdk_manifest ():
@@ -35,6 +93,10 @@ def get_jumpstart_sdk_spec(key):
35
93
model_spec = f .read ().decode ("utf-8" )
36
94
return json .loads (model_spec )
37
95
96
+ def get_model_task (id ):
97
+ task_short = id .split ('-' )[1 ]
98
+ return TASK_MAP [task_short ] if task_short in TASK_MAP else 'Source'
99
+
38
100
39
101
def create_jumpstart_model_table ():
40
102
sdk_manifest = get_jumpstart_sdk_manifest ()
@@ -69,26 +131,29 @@ def create_jumpstart_model_table():
69
131
)
70
132
file_content .append (
71
133
"""
72
- Each model id is linked to an external page that describes the model.\n
134
+ Click on the Problem Type to navigate to the source of the model.\n
73
135
"""
74
136
)
75
137
file_content .append ("\n " )
76
138
file_content .append (".. list-table:: Available Models\n " )
77
- file_content .append (" :widths: 50 20 20 20\n " )
139
+ file_content .append (" :widths: 50 20 20 20 30 \n " )
78
140
file_content .append (" :header-rows: 1\n " )
79
141
file_content .append (" :class: datatable\n " )
80
142
file_content .append ("\n " )
81
143
file_content .append (" * - Model ID\n " )
82
144
file_content .append (" - Fine Tunable?\n " )
83
145
file_content .append (" - Latest Version\n " )
84
146
file_content .append (" - Min SDK Version\n " )
147
+ file_content .append (" - Problem Type/Source\n " )
85
148
86
149
for model in sdk_manifest_top_versions_for_models .values ():
87
150
model_spec = get_jumpstart_sdk_spec (model ["spec_key" ])
88
- file_content .append (" * - `{} <{}>`_\n " .format (model_spec ["model_id" ], model_spec ["url" ]))
151
+ model_task = get_model_task (model_spec ["model_id" ])
152
+ file_content .append (" * - {}\n " .format (model_spec ["model_id" ]))
89
153
file_content .append (" - {}\n " .format (model_spec ["training_supported" ]))
90
154
file_content .append (" - {}\n " .format (model ["version" ]))
91
155
file_content .append (" - {}\n " .format (model ["min_version" ]))
156
+ file_content .append (" - `{} <{}>`__\n " .format (model_task , model_spec ["url" ]))
92
157
93
158
f = open ("doc_utils/jumpstart.rst" , "w" )
94
159
f .writelines (file_content )
0 commit comments