14
14
from urllib import request
15
15
import json
16
16
from packaging .version import Version
17
+ from enum import Enum
18
+
19
+
20
class Tasks(str, Enum):
    """The ML task name as referenced in the infix of the model ID.

    Members also subclass ``str``, so each one compares and hashes equal to
    its plain-string value (``Tasks.IC == "ic"``). A raw token can therefore
    be used directly as a key into a dict keyed by ``Tasks`` members.
    """

    IC = "ic"
    OD = "od"
    OD1 = "od1"
    SEMSEG = "semseg"
    IS = "is"
    TC = "tc"
    SPC = "spc"
    EQA = "eqa"
    TEXT_GENERATION = "textgeneration"
    IC_EMBEDDING = "icembedding"
    TC_EMBEDDING = "tcembedding"
    NER = "ner"
    SUMMARIZATION = "summarization"
    TRANSLATION = "translation"
    TABULAR_REGRESSION = "regression"
    TABULAR_CLASSIFICATION = "classification"
39
+
40
+
41
class ProblemTypes(str, Enum):
    """Possible problem types for JumpStart models.

    Each member's value is a human-readable display name. Members also
    subclass ``str``, so they compare equal to that display name.
    """

    IMAGE_CLASSIFICATION = "Image Classification"
    IMAGE_EMBEDDING = "Image Embedding"
    OBJECT_DETECTION = "Object Detection"
    SEMANTIC_SEGMENTATION = "Semantic Segmentation"
    INSTANCE_SEGMENTATION = "Instance Segmentation"
    TEXT_CLASSIFICATION = "Text Classification"
    TEXT_EMBEDDING = "Text Embedding"
    QUESTION_ANSWERING = "Question Answering"
    SENTENCE_PAIR_CLASSIFICATION = "Sentence Pair Classification"
    TEXT_GENERATION = "Text Generation"
    TEXT_SUMMARIZATION = "Text Summarization"
    MACHINE_TRANSLATION = "Machine Translation"
    NAMED_ENTITY_RECOGNITION = "Named Entity Recognition"
    TABULAR_REGRESSION = "Regression"
    TABULAR_CLASSIFICATION = "Classification"
59
+
17
60
18
61
# Region name substituted into both placeholders of the bucket URL below.
JUMPSTART_REGION = "eu-west-2"
# File name of the SDK models manifest.
SDK_MANIFEST_FILE = "models_manifest.json"
# The region appears twice in the URL: once in the bucket name and once in
# the S3 endpoint host.
JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com".format(
    JUMPSTART_REGION, JUMPSTART_REGION
)
66
# Maps each model-ID task infix to the problem type shown for it.
# Both Tasks.OD and Tasks.OD1 map to ProblemTypes.OBJECT_DETECTION.
# Because Tasks members are also str instances, a plain string such as "ic"
# can be used as a lookup key.
TASK_MAP = {
    Tasks.IC: ProblemTypes.IMAGE_CLASSIFICATION,
    Tasks.IC_EMBEDDING: ProblemTypes.IMAGE_EMBEDDING,
    Tasks.OD: ProblemTypes.OBJECT_DETECTION,
    Tasks.OD1: ProblemTypes.OBJECT_DETECTION,
    Tasks.SEMSEG: ProblemTypes.SEMANTIC_SEGMENTATION,
    Tasks.IS: ProblemTypes.INSTANCE_SEGMENTATION,
    Tasks.TC: ProblemTypes.TEXT_CLASSIFICATION,
    Tasks.TC_EMBEDDING: ProblemTypes.TEXT_EMBEDDING,
    Tasks.EQA: ProblemTypes.QUESTION_ANSWERING,
    Tasks.SPC: ProblemTypes.SENTENCE_PAIR_CLASSIFICATION,
    Tasks.TEXT_GENERATION: ProblemTypes.TEXT_GENERATION,
    Tasks.SUMMARIZATION: ProblemTypes.TEXT_SUMMARIZATION,
    Tasks.TRANSLATION: ProblemTypes.MACHINE_TRANSLATION,
    Tasks.NER: ProblemTypes.NAMED_ENTITY_RECOGNITION,
    Tasks.TABULAR_REGRESSION: ProblemTypes.TABULAR_REGRESSION,
    Tasks.TABULAR_CLASSIFICATION: ProblemTypes.TABULAR_CLASSIFICATION,
}
23
84
24
85
25
86
def get_jumpstart_sdk_manifest ():
@@ -36,6 +97,11 @@ def get_jumpstart_sdk_spec(key):
36
97
return json .loads (model_spec )
37
98
38
99
100
def get_model_task(id):
    """Return the display name of the problem type encoded in a model ID.

    The task is taken from the second ``-``-separated token of ``id`` and
    looked up in ``TASK_MAP``.

    Args:
        id (str): The model ID. The parameter name shadows the ``id``
            builtin; it is kept so existing keyword callers keep working.

    Returns:
        str: The mapped problem type's display string, or ``"Source"`` when
        the ID has no second token or the token is not in ``TASK_MAP``.
    """
    tokens = id.split("-")
    # An ID without any "-" has no task infix; use the same fallback as an
    # unknown task instead of raising IndexError.
    if len(tokens) < 2:
        return "Source"

    problem_type = TASK_MAP.get(tokens[1])
    if problem_type is None:
        return "Source"

    # Return the plain string value rather than the enum member: on
    # Python 3.11+ formatting a (str, Enum) member with str.format() yields
    # "ProblemTypes.X" rather than the display name.
    return problem_type.value
103
+
104
+
39
105
def create_jumpstart_model_table ():
40
106
sdk_manifest = get_jumpstart_sdk_manifest ()
41
107
sdk_manifest_top_versions_for_models = {}
@@ -69,26 +135,29 @@ def create_jumpstart_model_table():
69
135
)
70
136
file_content .append (
71
137
"""
72
- Each model id is linked to an external page that describes the model.\n
138
+ Click on the Problem Type to navigate to the source of the model.\n
73
139
"""
74
140
)
75
141
file_content .append ("\n " )
76
142
file_content .append (".. list-table:: Available Models\n " )
77
- file_content .append (" :widths: 50 20 20 20\n " )
143
+ file_content .append (" :widths: 50 20 20 20 30 \n " )
78
144
file_content .append (" :header-rows: 1\n " )
79
145
file_content .append (" :class: datatable\n " )
80
146
file_content .append ("\n " )
81
147
file_content .append (" * - Model ID\n " )
82
148
file_content .append (" - Fine Tunable?\n " )
83
149
file_content .append (" - Latest Version\n " )
84
150
file_content .append (" - Min SDK Version\n " )
151
+ file_content .append (" - Problem Type/Source\n " )
85
152
86
153
for model in sdk_manifest_top_versions_for_models .values ():
87
154
model_spec = get_jumpstart_sdk_spec (model ["spec_key" ])
88
- file_content .append (" * - `{} <{}>`_\n " .format (model_spec ["model_id" ], model_spec ["url" ]))
155
+ model_task = get_model_task (model_spec ["model_id" ])
156
+ file_content .append (" * - {}\n " .format (model_spec ["model_id" ]))
89
157
file_content .append (" - {}\n " .format (model_spec ["training_supported" ]))
90
158
file_content .append (" - {}\n " .format (model ["version" ]))
91
159
file_content .append (" - {}\n " .format (model ["min_version" ]))
160
+ file_content .append (" - `{} <{}>`__\n " .format (model_task , model_spec ["url" ]))
92
161
93
162
f = open ("doc_utils/jumpstart.rst" , "w" )
94
163
f .writelines (file_content )
0 commit comments